| import gradio as gr |
| from engine import FunctionGemmaEngine |
|
|
def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
    """Assemble the FunctionGemma Modkit fine-tuning UI.

    The app has three tabs:

    1. **Preparing Dataset** - edit the tool JSON schema (saved through
       ``engine.update_tools``) and optionally upload a CSV dataset
       (``engine.load_csv``).
    2. **Training** - choose hyperparameters and run
       ``engine.run_training_pipeline``; a Stop button calls
       ``engine.trigger_stop`` and a Reset button calls
       ``engine.refresh_data_and_model``.
    3. **Export** - ask ``engine.get_zip_path`` for a model archive and offer
       it for download.

    Args:
        engine: Backend that owns the tool schema, dataset, model and training
            loop. Every UI callback is delegated to it.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    # The Stop button is relabelled while a stop is pending, so its normal
    # label is kept here to restore it at the start of every run.
    stop_label = "🛑 Stop"

    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
        gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")

        with gr.Tabs():
            # ---- Tab 1: tool schema editing and optional CSV import ----
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**\n\nEdit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            value=engine.get_tools_json(),
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15,
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")

                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**\n\nUpload a CSV file to replace the default dataset ([bebechien/SimpleToolCalling](https://huggingface.co/datasets/bebechien/SimpleToolCalling)).")
                        # No space after the delimiters: a quoted CSV field must
                        # begin with the quote character (RFC 4180); with a
                        # leading space, strict readers split "London, UK" in two.
                        gr.Markdown("**Example CSV Row:**<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\",\"get_weather\",\"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100,
                        )
                        import_status = gr.Markdown("")

            # ---- Tab 2: hyperparameters, run/stop/reset, logs and metrics ----
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")

                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes",
                        )
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5",
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation data ratio. Typical value is 0.2 (80% for training, 20% for testing)",
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split",
                        )

                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=2)
                    # Hidden until a training run starts.
                    stop_training_btn = gr.Button(stop_label, variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)

                with gr.Row():
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Ready.",
                        interactive=False,
                        autoscroll=True,
                    )
                    loss_plot = gr.Plot(label="Training Metrics")

            # ---- Tab 3: package the trained model for download ----
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")
                gr.Markdown("Download the fine-tuned LoRA adapters or full model weights (depending on configuration) as a ZIP file.")

                with gr.Row():
                    zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="primary", scale=1)
                    download_file = gr.File(label="Download Archive", interactive=False, scale=2)

        # ---- Event wiring (must happen inside the Blocks context) ----
        update_tools_btn.click(
            fn=engine.update_tools,
            inputs=[tools_editor],
            outputs=[tools_status],
        )

        import_file.upload(
            fn=engine.load_csv,
            inputs=[import_file],
            outputs=[import_status],
        )

        def set_training_controls(running: bool):
            """Return updates for the (run, reset, stop) buttons.

            While running, Run is hidden, Reset is disabled and Stop is shown;
            afterwards the reverse. Stop is always re-enabled and relabelled
            because ``handle_stop`` disables it, and it would otherwise still
            read "Stopping..." on the next run.
            """
            return (
                gr.update(visible=not running),
                gr.update(interactive=not running),
                gr.update(visible=running, interactive=True, value=stop_label),
            )

        training_controls = [run_training_btn, clear_reload_btn, stop_training_btn]

        # Per the Gradio docs, .then() fires after the previous step whether or
        # not it succeeded, so the controls are restored even if training fails.
        run_training_btn.click(
            fn=lambda: set_training_controls(True),
            outputs=training_controls,
        ).then(
            fn=engine.run_training_pipeline,
            inputs=[param_epochs, param_lr, param_test_size, param_shuffle],
            outputs=[output_display, loss_plot],
        ).then(
            fn=lambda: set_training_controls(False),
            outputs=training_controls,
        )

        def handle_stop():
            """Signal the engine to stop training and show that a stop is pending."""
            engine.trigger_stop()
            # Previously "Stopping..." was returned with no output component and
            # silently dropped. Show it on the button and disable the button so
            # repeated clicks don't send repeated stop requests.
            return gr.update(value="Stopping...", interactive=False)

        stop_training_btn.click(
            fn=handle_stop,
            outputs=[stop_training_btn],
        )

        clear_reload_btn.click(
            fn=engine.refresh_data_and_model,
            outputs=[output_display],
        )

        def handle_zip():
            """Fetch the model archive path from the engine and show it for download.

            The download widget is hidden when ``engine.get_zip_path()`` returns
            a falsy value (presumably no trained model yet - TODO confirm).
            """
            path = engine.get_zip_path()
            if path:
                return gr.update(value=path, visible=True)
            return gr.update(value=None, visible=False)

        zip_btn.click(
            fn=handle_zip,
            outputs=[download_file],
        )

    return demo
|
|