Spaces:
Running
Running
| import gradio as gr | |
| from config import AppConfig | |
| from engine import FunctionGemmaEngine | |
| from typing import Optional | |
def build_interface() -> gr.Blocks:
    """Build the FunctionGemma Modkit Gradio app.

    The UI has three tabs: dataset preparation, training and export. Each
    browser session gets its own ``FunctionGemmaEngine``, which is held in a
    ``gr.State`` and passed as the first input to the event handlers below.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """

    # --- State Management Wrappers ---

    # Shown when a control is used before ``init_session`` has stored the engine.
    session_not_ready = "⚠️ Session is still initializing. Please try again in a moment."

    def apply_model_name(engine, model_name):
        """Copy the Base Model dropdown value into ``engine.config.MODEL_NAME``.

        Blank or ``None`` values are ignored, so the previously configured model
        is kept instead of being overwritten with an empty model ID.
        """
        clean_name = (model_name or "").strip()
        if clean_name:
            engine.config.MODEL_NAME = clean_name

    def update_hub_interactive(engine, username: Optional[str] = None):
        """Return updates for (repo-name textbox, push-to-Hub button).

        The textbox is editable whenever a user is signed in. The push button
        additionally requires ``engine.has_model_tuned`` to be truthy.
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and engine.has_model_tuned
        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
        )

    def init_session(profile: Optional[gr.OAuthProfile] = None):
        """Create this session's engine and compute the initial UI values.

        ``profile`` is supplied by Gradio's OAuth integration and is ``None``
        when the user is not signed in.

        Returns:
            A 7-tuple matching the ``demo.load`` outputs, in this order:
            engine, tool-schema JSON, base-model name, status text,
            repo-name update, push-button update, username.
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None
        repo_update, push_update = update_hub_interactive(new_engine, username)
        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            username,
        )

    def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
        """Apply the selected base model, then stream the engine's training output.

        Re-yields everything ``engine.run_training_pipeline`` yields. The event
        wiring routes those values to the log textbox and the metrics plot.
        """
        apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    def handle_reset(engine, model_name):
        """Apply the selected base model and return ``engine.refresh_model()``."""
        apply_model_name(engine, model_name)
        return engine.refresh_model()

    def update_tools_wrapper(engine, json_val):
        """Pass the edited tool-schema JSON to ``engine.update_tools``."""
        if engine is None:
            return session_not_ready
        return engine.update_tools(json_val)

    def import_file_wrapper(engine, file_obj):
        """Pass the uploaded CSV file to ``engine.load_csv``."""
        if engine is None:
            return session_not_ready
        return engine.load_csv(file_obj)

    def stop_wrapper(engine):
        """Call ``engine.trigger_stop()`` to request that training stop."""
        engine.trigger_stop()
        return "Stopping..."

    def zip_wrapper(engine):
        """Show the model archive for download, or hide the file box if there is none."""
        path = engine.get_zip_path()
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    def upload_wrapper(engine, repo_name, oauth_token: gr.OAuthToken | None):
        """Upload the tuned model to the Hub under ``repo_name``.

        ``oauth_token`` is supplied by Gradio's OAuth integration. The repo name
        is stripped so the uploaded name matches the stripped name shown by
        ``update_repo_preview``, and a whitespace-only name is rejected.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    def update_repo_preview(username, repo_name):
        """Updates the markdown preview to show 'username/repo_name'."""
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # Blank or whitespace-only input shows a placeholder instead of "user/".
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    # --- UI Layout ---
    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        with gr.Column():
            gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
            gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
            gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).")
            with gr.Row():
                gr.LoginButton(value="Sign in with Hugging Face")
                # Empty wide column keeps the login button compact.
                with gr.Column(scale=3):
                    gr.Markdown("")

        with gr.Tabs():
            # --- TAB 1: PREPARING DATASET ---
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15,
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                        gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100,
                        )
                        import_status = gr.Markdown("")

            # --- TAB 2: TRAINING ---
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")
                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        default_models = AppConfig().AVAILABLE_MODELS
                        param_model = gr.Dropdown(
                            choices=default_models,
                            allow_custom_value=True,
                            label="Base Model",
                            info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                            interactive=True,
                        )
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes",
                        )
                    with gr.Row():
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5",
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation ratio (0.2 = 20%)",
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split",
                        )
                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
                    stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
                with gr.Row():
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Initializing...",
                        interactive=False,
                        autoscroll=True,
                    )
                    loss_plot = gr.Plot(label="Training Metrics")

            # --- TAB 3: EXPORT ---
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Option A: Download ZIP")
                        gr.Markdown("Download the model weights locally.")
                        zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                        download_file = gr.File(label="Download Archive", interactive=False)
                    with gr.Column():
                        gr.Markdown("#### Option B: Save to Hugging Face Hub")
                        gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                        repo_name_input = gr.Textbox(
                            label="Target Repository Name",
                            value="my-functiongemma-v1",
                            placeholder="e.g., my-functiongemma-v1",
                            interactive=False,
                        )
                        push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                        repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
                        upload_status = gr.Markdown("")

        # --- EVENT WIRING ---
        # Buttons that are disabled together while a long-running action is in flight.
        action_buttons = [
            clear_reload_btn,
            run_training_btn,
            zip_btn,
        ]

        def set_interactivity(interactive: bool):
            """Return one interactivity update per entry in ``action_buttons``."""
            return [gr.update(interactive=interactive) for _ in action_buttons]

        def lock_for_training():
            """Hide Run, show Stop, and disable Reset, ZIP and Push while training.

            Push is disabled as well, so the model cannot be uploaded while the
            engine is mid-training. The chain re-enables it afterwards through
            ``update_hub_interactive``.
            """
            return (
                gr.update(visible=False),      # run_training_btn
                gr.update(interactive=False),  # clear_reload_btn
                gr.update(interactive=False),  # zip_btn
                gr.update(visible=True),       # stop_training_btn
                gr.update(interactive=False),  # push_to_hub_btn
            )

        def unlock_after_training():
            """Restore Run/Reset/ZIP and hide Stop once training has finished."""
            return (
                gr.update(visible=True),       # run_training_btn
                gr.update(interactive=True),   # clear_reload_btn
                gr.update(interactive=True),   # zip_btn
                gr.update(visible=False),      # stop_training_btn
            )

        # On page load, keep actions disabled until this session's engine exists.
        demo.load(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=init_session,
            inputs=None,
            outputs=[engine_state, tools_editor, param_model, output_display, repo_name_input, push_to_hub_btn, username_state],
        ).then(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview],
        ).then(
            # ZIP stays disabled: there is no tuned model yet.
            fn=lambda: [gr.update(interactive=True)] * 2, outputs=[clear_reload_btn, run_training_btn]
        )

        update_tools_btn.click(
            fn=update_tools_wrapper,
            inputs=[engine_state, tools_editor],
            outputs=[tools_status],
        )

        import_file.upload(
            fn=import_file_wrapper,
            inputs=[engine_state, import_file],
            outputs=[import_status],
        )

        # ``.then`` (rather than ``.success``) keeps the unlock steps running
        # even if the training step raises.
        run_training_btn.click(
            fn=lock_for_training,
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn, push_to_hub_btn],
        ).then(
            fn=run_training_wrapper,
            inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
            outputs=[output_display, loss_plot],
        ).then(
            fn=unlock_after_training,
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn],
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

        stop_training_btn.click(
            fn=stop_wrapper,
            inputs=[engine_state],
            outputs=None,
        )

        clear_reload_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=handle_reset,
            inputs=[engine_state, param_model],
            outputs=[output_display],
        ).then(
            # ZIP stays disabled after a reset until the model is trained again.
            fn=lambda: [gr.update(interactive=True)] * 2, outputs=[clear_reload_btn, run_training_btn]
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

        zip_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=zip_wrapper,
            inputs=[engine_state],
            outputs=[download_file],
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        )

        repo_name_input.change(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview],
        )

        push_to_hub_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=upload_wrapper,
            inputs=[engine_state, repo_name_input],
            outputs=[upload_status],
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn],
        )

    return demo