| | import gradio as gr |
| | from typing import Optional, Tuple, Generator, List, Any |
| | from config import AppConfig |
| | from engine import FunctionGemmaEngine |
| |
|
| | |
| |
|
class UIController:
    """
    Handles the business logic and interaction with the Engine.

    All methods are static and stateless: the per-session
    ``FunctionGemmaEngine`` is handed in by each Gradio event from a
    ``gr.State`` component. That state is empty until ``init_session`` has
    run, so a handler may receive ``engine=None``; every handler that touches
    the engine checks for this and reports it instead of raising.
    """

    # Status text returned when a handler runs before the engine exists.
    _NOT_INITIALIZED = "⚠️ Engine not initialized."

    @staticmethod
    def _apply_model_name(engine: FunctionGemmaEngine, model_name: Optional[str]) -> None:
        """Copy the Base Model dropdown value into ``engine.config.MODEL_NAME``.

        Surrounding whitespace is stripped. A ``None`` or blank value (e.g. a
        cleared custom entry) keeps the current setting instead of replacing
        it with an empty model id.
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for a new browser session.

        Args:
            profile: The signed-in Hugging Face profile, or ``None`` when the
                user is not logged in. The event wiring passes no explicit
                inputs, so Gradio supplies this from the annotation.

        Returns:
            A tuple matching the ``demo.load`` outputs, in this order:
            (engine, tools JSON, model name, status text, repo-input update,
            push-button update, zip-button update, username or ``None``).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None

        # A brand-new engine has not been tuned yet, so this mostly reflects login.
        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)

        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream (log text, plot) pairs from the engine's training pipeline.

        Yields a single warning pair and stops if the engine is missing.
        """
        if engine is None:
            yield UIController._NOT_INITIALIZED, None
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream log text from the engine's evaluation run.

        Yields a single warning string and stops if the engine is missing.
        """
        if engine is None:
            yield UIController._NOT_INITIALIZED
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Apply the selected base model and reload it; returns a status string."""
        if engine is None:
            return UIController._NOT_INITIALIZED
        UIController._apply_model_name(engine, model_name)
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Forward the edited tool-schema JSON to the engine; returns a status string."""
        if engine is None:
            return UIController._NOT_INITIALIZED
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Hand an uploaded CSV file to the engine; returns a status string."""
        if engine is None:
            return UIController._NOT_INITIALIZED
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop its current job (no UI output)."""
        if engine is not None:
            engine.trigger_stop()

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Build the model archive and show it in the download widget.

        Returns a ``gr.update`` that shows the file when the engine produced
        a path, and hides the widget otherwise (including when no engine exists).
        """
        path = engine.get_zip_path() if engine is not None else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str, oauth_token: Optional[gr.OAuthToken]) -> str:
        """Push the tuned model to ``<user>/<repo_name>`` on the Hugging Face Hub.

        Args:
            engine: Session engine holding the model.
            repo_name: Target repository name; surrounding whitespace is
                stripped, matching ``update_repo_preview``.
            oauth_token: Supplied by Gradio for signed-in users; ``None`` otherwise.

        Returns:
            A Markdown status / error message.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if engine is None:
            return UIController._NOT_INITIALIZED

        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Render the ``username/repo`` path that an upload would target.

        A missing or whitespace-only repo name is shown as ``...``.
        """
        if not username:
            return "⚠️ Sign in to see the target repository path."
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine], username: Optional[str] = None):
        """Derive Export-tab interactivity from login state and tuning state.

        Returns:
            Three ``gr.update`` objects for (repo-name input, push button,
            zip button):
            - the repo input is editable when logged in;
            - push requires login AND a tuned model;
            - zip requires only a tuned model.
        """
        is_logged_in = username is not None
        # Tolerate engines that don't expose the flag (treated as "not tuned").
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)

        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned)
        )
| |
|
| | |
| |
|
def _render_header():
    """Render the page title, introduction text and the Hugging Face sign-in row."""
    intro_blocks = (
        "# 🤖 FunctionGemma Tuning Lab: Fine-Tuning",
        "Fine-tune FunctionGemma to understand your custom functions.<br>"
        "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details.",
        "(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress).",
    )

    with gr.Column():
        for text in intro_blocks:
            gr.Markdown(text)

        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # Empty, wider column next to the button — presumably a spacer
            # that keeps the login button narrow.
            with gr.Column(scale=3):
                gr.Markdown("")
| |
|
def _render_dataset_tab(engine_state):
    """Build tab 1: tool-schema editor plus optional CSV dataset upload.

    Args:
        engine_state: Session engine ``gr.State`` (not used while building
            the layout; kept for a uniform tab-builder signature).

    Returns:
        A dict of the components the event wiring needs: ``tools_editor``,
        ``update_tools_btn``, ``tools_status``, ``import_file``, ``import_status``.
    """
    notices = (
        "### 🛠️ Tool Schema & Data Import",
        "**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:",
        "* `sum(int a, int b)`\n* `query(string q)`",
        "Ensure that all tools within this specific schema definition share a consistent parameter format.",
    )
    csv_example = '**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n"What is the weather in London?", "get_weather", "{""location"": ""London, UK""}"\n```'

    with gr.TabItem("1. Preparing Dataset"):
        for notice in notices:
            gr.Markdown(notice)

        with gr.Row():
            # Left column: edit and submit the tool JSON schema.
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                schema_code = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                schema_btn = gr.Button("💾 Update Tool Schema")
                schema_msg = gr.Markdown("")

            # Right column: optionally replace the default dataset with a CSV.
            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown(csv_example)
                csv_upload = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                csv_msg = gr.Markdown("")

    return dict(
        tools_editor=schema_code,
        update_tools_btn=schema_btn,
        tools_status=schema_msg,
        import_file=csv_upload,
        import_status=csv_msg,
    )
| |
|
def _render_training_tab(engine_state):
    """Build tab 2: hyperparameters, run/stop/reload controls, logs and plot.

    Args:
        engine_state: Session engine ``gr.State`` (not used while building
            the layout; kept for a uniform tab-builder signature).

    Returns:
        A dict with:
        - ``params``: inputs for training, ordered (epochs, lr, test_size,
          shuffle, model);
        - ``eval_params``: inputs for evaluation, ordered (test_size, shuffle,
          model);
        - ``buttons``: (run, stop, reload, eval);
        - ``outputs``: (log textbox, metrics plot);
        - ``model_input``: the Base Model dropdown.
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")

        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                model_choices = AppConfig().AVAILABLE_MODELS
                model_dd = gr.Dropdown(
                    choices=model_choices,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                lr_box = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                split_slider = gr.Slider(0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)")
                shuffle_box = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")

        with gr.Row():
            eval_btn = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            # Hidden until a training/eval run starts.
            stop_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            train_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)

        with gr.Row():
            log_box = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            metrics_plot = gr.Plot(label="Training Metrics")

    split_inputs = [split_slider, shuffle_box, model_dd]
    return {
        "params": [epochs_slider, lr_box, *split_inputs],
        "eval_params": list(split_inputs),
        "buttons": [train_btn, stop_btn, reload_btn, eval_btn],
        "outputs": [log_box, metrics_plot],
        "model_input": model_dd
    }
| |
|
def _render_export_tab(engine_state, username_state):
    """Build tab 3: download the model as a ZIP or push it to the Hub.

    Every action control starts non-interactive; the event wiring enables
    them later via ``UIController.update_hub_interactive``.

    Args:
        engine_state: Session engine ``gr.State`` (unused during layout).
        username_state: Signed-in username ``gr.State`` (unused during layout).

    Returns:
        A dict with ``zip_controls`` = (zip button, download file) and
        ``hub_controls`` = (repo-name input, push button, repo preview,
        upload status).
    """
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")
        with gr.Row():
            # Option A: local ZIP download.
            with gr.Column():
                for text in ("#### Option A: Download ZIP", "Download the model weights locally."):
                    gr.Markdown(text)
                prepare_zip = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                archive = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")

            # Option B: push to the signed-in user's Hub account.
            with gr.Column():
                for text in ("#### Option B: Save to Hugging Face Hub", "Publish your fine-tuned model to your personal Hugging Face account."):
                    gr.Markdown(text)
                repo_box = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                push_hub = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                preview_md = gr.Markdown("Target Repository: (Waiting for input...)")
                status_md = gr.Markdown("")

    zip_group = [prepare_zip, archive]
    hub_group = [repo_box, push_hub, preview_md, status_md]
    return {"zip_controls": zip_group, "hub_controls": hub_group}
| |
|
| | |
| |
|
def build_interface() -> gr.Blocks:
    """Assemble the Gradio app: layout, per-session state and event wiring.

    Long-running actions (training, evaluation, reload, zip, push) follow
    the same pattern:
    1. lock the buttons they could conflict with;
    2. run the work;
    3. unlock the general action buttons;
    4. re-derive the Export-tab controls from the engine's tuned state and
       the login via ``UIController.update_hub_interactive``.

    Returns:
        The configured ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        # Per-session values, filled by UIController.init_session on page load.
        engine_state = gr.State()
        username_state = gr.State()

        _render_header()

        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # --- Component handles used by the wiring below ---
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]

        repo_input = export_ui["hub_controls"][0]
        push_btn = export_ui["hub_controls"][1]
        repo_preview = export_ui["hub_controls"][2]
        upload_status = export_ui["hub_controls"][3]
        zip_btn = export_ui["zip_controls"][0]
        download_file = export_ui["zip_controls"][1]
        log_box = train_ui["outputs"][0]

        # Everything lock_ui disables, in the same order as its updates.
        lockable = action_buttons + [push_btn, zip_btn]

        # Shared final step: recompute repo-input / push / zip interactivity.
        refresh_hub = dict(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn],
        )

        def lock_ui():
            """Locks all buttons (including Zip/Push) during processing"""
            return [gr.update(interactive=False) for _ in lockable]

        def unlock_ui():
            """Unlocks general action buttons only. Zip/Push are handled by update_hub_interactive"""
            return [gr.update(interactive=True) for _ in action_buttons]

        # --- Session bootstrap ---
        demo.load(lock_ui, outputs=lockable).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                log_box,
                repo_input,
                push_btn,
                zip_btn,
                username_state
            ]
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview]
        ).then(unlock_ui, outputs=action_buttons)

        # --- Tab 1: dataset ---
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]]
        )

        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]]
        )

        # --- Tab 2: training ---
        # Push is locked alongside zip so the model cannot be exported while
        # training is still modifying it.
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # run_btn
                gr.update(interactive=False),  # reload_btn
                gr.update(interactive=False),  # eval_btn
                gr.update(interactive=False),  # zip_btn
                gr.update(interactive=False),  # push_btn
                gr.update(visible=True)        # stop_btn
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn]
        )
        train_run_event = train_run_event.then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=train_ui["outputs"],
        ).then(
            fn=lambda: (
                gr.update(visible=True),       # run_btn
                gr.update(interactive=True),   # reload_btn
                gr.update(interactive=True),   # eval_btn
                gr.update(visible=False)       # stop_btn
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        ).then(**refresh_hub)

        # --- Tab 2: evaluation ---
        # Zip/push are locked for the same reason as in training, and restored
        # from the engine's state once evaluation finishes.
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # run_btn
                gr.update(interactive=False),  # reload_btn
                gr.update(visible=False),      # eval_btn
                gr.update(interactive=False),  # zip_btn
                gr.update(interactive=False),  # push_btn
                gr.update(visible=True)        # stop_btn
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn]
        )
        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[log_box]
        ).then(
            fn=lambda: (
                gr.update(interactive=True),   # run_btn
                gr.update(interactive=True),   # reload_btn
                gr.update(visible=True),       # eval_btn
                gr.update(visible=False)       # stop_btn
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        ).then(**refresh_hub)

        # NOTE(review): each chain is reassigned, so these handles are the
        # *last* step of each chain, not the running generator. Stopping
        # therefore relies on engine.trigger_stop(). Confirm Gradio's cancel
        # semantics before retargeting: cancelling the generator itself may
        # skip the button-restore steps.
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False
        )

        reload_btn.click(lock_ui, outputs=lockable).then(
            fn=UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[log_box]
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

        # --- Tab 3: export ---
        zip_btn.click(lock_ui, outputs=lockable).then(
            fn=UIController.zip_model,
            inputs=[engine_state],
            outputs=[download_file]
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview]
        )

        push_btn.click(lock_ui, outputs=lockable).then(
            fn=UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[upload_status]
        ).then(unlock_ui, outputs=action_buttons).then(**refresh_hub)

    return demo