bebechien's picture
Upload folder using huggingface_hub
9aee162 verified
raw
history blame
12.9 kB
import gradio as gr
from config import AppConfig
from engine import FunctionGemmaEngine
from typing import Optional
def build_interface() -> gr.Blocks:
    """Build the FunctionGemma Modkit Gradio app.

    The UI has three tabs: dataset preparation, training, and export. Each
    page load creates its own ``FunctionGemmaEngine`` (see ``init_session``).
    That engine is held in a ``gr.State``, and every event handler receives it
    as its first input.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here; the caller
        decides how to serve it.
    """

    # --- State Management Wrappers ---
    def _apply_model_name(engine, model_name):
        """Point the engine's config at ``model_name`` when a non-blank value is given.

        A ``None`` or whitespace-only value (e.g. the user cleared the Base
        Model dropdown) keeps the currently configured model. This avoids
        crashing on ``None.strip()`` or storing an empty model ID.
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    def init_session(profile: Optional[gr.OAuthProfile] = None):
        """Create a fresh engine for this page load and seed the UI from it.

        ``profile`` is filled in by Gradio from the type hint when the user is
        signed in; otherwise it is ``None``.

        Returns, in order, values for: engine state, tool-schema editor,
        base-model dropdown, log box, repo-name textbox, push button, and
        username state.
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None
        repo_update, push_update = update_hub_interactive(new_engine, username)
        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            username,
        )

    def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
        """Apply the selected base model, then stream the engine's training updates.

        Re-yields everything ``engine.run_training_pipeline`` yields. The event
        wiring routes those values to the log box and the metrics plot.
        """
        _apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    def handle_reset(engine, model_name):
        """Apply the selected base model and return ``engine.refresh_model()``'s result."""
        _apply_model_name(engine, model_name)
        return engine.refresh_model()

    def update_tools_wrapper(engine, json_val):
        """Forward the edited tool-schema JSON to ``engine.update_tools``."""
        return engine.update_tools(json_val)

    def import_file_wrapper(engine, file_obj):
        """Forward an uploaded CSV file to ``engine.load_csv``."""
        return engine.load_csv(file_obj)

    def stop_wrapper(engine):
        """Ask the engine to stop the running training loop.

        The returned text is not displayed, because the Stop click handler is
        wired with no outputs.
        """
        engine.trigger_stop()
        return "Stopping..."

    def zip_wrapper(engine):
        """Show the engine's model archive for download, or hide the file widget if there is none."""
        path = engine.get_zip_path()
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    def upload_wrapper(engine, repo_name, oauth_token: gr.OAuthToken | None):
        """Push the fine-tuned model to the signed-in user's Hub account.

        ``oauth_token`` is filled in by Gradio from the type hint. It is
        ``None`` when the user is not signed in. ``repo_name`` is stripped so
        the name actually uploaded matches the one shown in the preview, and
        blank input is rejected.

        Returns a status message (error text here, or whatever
        ``engine.upload_model_to_hub`` returns).
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        repo_name = (repo_name or "").strip()
        if not repo_name:
            return "❌ Error: Please enter a repository name."
        return engine.upload_model_to_hub(
            repo_name=repo_name,
            oauth_token=oauth_token.token,
        )

    def update_repo_preview(username, repo_name):
        """Updates the markdown preview to show 'username/repo_name'."""
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # A blank or whitespace-only name shows a placeholder rather than "user/".
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    def update_hub_interactive(engine, username: Optional[str] = None):
        """Return updates for (repo-name textbox, push-to-hub button).

        The textbox is editable whenever a user is signed in. The push button
        also requires that the engine reports a tuned model.
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and engine.has_model_tuned
        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
        )

    # --- UI Layout ---
    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        with gr.Column():
            gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
            gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
            gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).")
            with gr.Row():
                gr.LoginButton(value="Sign in with Hugging Face")
                # Empty spacer column keeps the login button compact.
                with gr.Column(scale=3):
                    gr.Markdown("")

        with gr.Tabs():
            # --- TAB 1: PREPARING DATASET ---
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15,
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                        gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100,
                        )
                        import_status = gr.Markdown("")

            # --- TAB 2: TRAINING ---
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")
                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        default_models = AppConfig().AVAILABLE_MODELS
                        param_model = gr.Dropdown(
                            choices=default_models,
                            allow_custom_value=True,
                            label="Base Model",
                            info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                            interactive=True,
                        )
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes",
                        )
                    with gr.Row():
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5",
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation ratio (0.2 = 20%)",
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split",
                        )
                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
                    # Hidden until training starts; it replaces the Run button while training is active.
                    stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
                with gr.Row():
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Initializing...",
                        interactive=False,
                        autoscroll=True,
                    )
                    loss_plot = gr.Plot(label="Training Metrics")

            # --- TAB 3: EXPORT ---
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Option A: Download ZIP")
                        gr.Markdown("Download the model weights locally.")
                        zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                        download_file = gr.File(label="Download Archive", interactive=False)
                    with gr.Column():
                        gr.Markdown("#### Option B: Save to Hugging Face Hub")
                        gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                        repo_name_input = gr.Textbox(
                            label="Target Repository Name",
                            value="my-functiongemma-v1",
                            placeholder="e.g., my-functiongemma-v1",
                            interactive=False,
                        )
                        push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                        repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
                        upload_status = gr.Markdown("")

        # --- EVENT WIRING ---
        # Buttons that start long-running engine work. They are locked together
        # so two such operations cannot overlap.
        action_buttons = [
            clear_reload_btn,
            run_training_btn,
            zip_btn,
        ]

        def set_interactivity(interactive: bool):
            """Return one interactivity update per entry in ``action_buttons``."""
            return [gr.update(interactive=interactive) for _ in action_buttons]

        # Page load: lock the buttons, build the session engine, render the repo
        # preview, then unlock Run and Reload. ZIP stays locked until a model
        # has been trained.
        demo.load(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=init_session,
            inputs=None,
            outputs=[engine_state, tools_editor, param_model, output_display, repo_name_input, push_to_hub_btn, username_state],
        ).then(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview],
        ).then(
            fn=lambda: [gr.update(interactive=True)] * 2, outputs=[clear_reload_btn, run_training_btn]
        )

        update_tools_btn.click(
            fn=update_tools_wrapper,
            inputs=[engine_state, tools_editor],
            outputs=[tools_status],
        )

        import_file.upload(
            fn=import_file_wrapper,
            inputs=[engine_state, import_file],
            outputs=[import_status],
        )

        # Training: swap Run for Stop and lock every other action, including
        # the Hub push, so a partially trained model cannot be uploaded.
        # Afterwards, restore the buttons and recompute Hub availability.
        run_training_btn.click(
            fn=lambda: (
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(visible=True),
                gr.update(interactive=False),
            ),
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn, push_to_hub_btn],
        ).then(
            fn=run_training_wrapper,
            inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
            outputs=[output_display, loss_plot],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False),
            ),
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn],
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

        stop_training_btn.click(
            fn=stop_wrapper,
            inputs=[engine_state],
            outputs=None,
        )

        # Reset: lock everything, reload the base model, then unlock Run and
        # Reload. ZIP stays locked because the reloaded model is untrained.
        clear_reload_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=handle_reset,
            inputs=[engine_state, param_model],
            outputs=[output_display],
        ).then(
            fn=lambda: [gr.update(interactive=True)] * 2, outputs=[clear_reload_btn, run_training_btn]
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

        zip_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=zip_wrapper,
            inputs=[engine_state],
            outputs=[download_file],
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        )

        repo_name_input.change(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview],
        )

        # Hub upload: lock all actions while pushing. ``oauth_token`` is
        # injected by Gradio, so only engine and repo name are passed as inputs.
        push_to_hub_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=upload_wrapper,
            inputs=[engine_state, repo_name_input],
            outputs=[upload_status],
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

    return demo