|
|
import gradio as gr |
|
|
from typing import Optional, Tuple, Generator, List, Any |
|
|
from config import AppConfig |
|
|
from engine import FunctionGemmaEngine |
|
|
|
|
|
|
|
|
|
|
|
class UIController:
    """
    Handles the business logic and interaction with the Engine.

    Stateless methods that operate on the passed Engine state. Handlers that
    receive the engine tolerate a missing engine (the session state is empty
    until ``init_session`` has run) and report it instead of raising.
    """

    # Shared message for handlers invoked before the session engine exists.
    _ENGINE_MISSING_MSG = "⚠️ Engine not initialized."

    @staticmethod
    def _apply_model_name(engine: FunctionGemmaEngine, model_name: Optional[str]) -> None:
        """Store the trimmed model name on the engine config.

        A ``None`` or blank value leaves the currently configured model
        untouched, so an emptied dropdown neither raises on ``.strip()`` nor
        replaces a valid model id with an empty string.
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for this session and the initial UI values.

        Returns an 8-tuple, in order: the new engine, its tools JSON, the
        configured model name, a "Ready" status line containing the session
        id, the three updates produced by ``update_hub_interactive``
        (repo input, push button, zip button) and the username (``None`` when
        no profile was supplied).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None

        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)

        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: Optional[FunctionGemmaEngine], epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream the engine's training pipeline output.

        Yields whatever ``engine.run_training_pipeline`` yields; when the
        engine is missing, yields a single ``(warning, None)`` pair instead.
        """
        if not engine:
            yield UIController._ENGINE_MISSING_MSG, None
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: Optional[FunctionGemmaEngine], test_size: float,
                       shuffle: bool, model_name: str) -> Generator:
        """Stream the engine's evaluation output, or a warning if no engine exists."""
        if not engine:
            yield UIController._ENGINE_MISSING_MSG
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: Optional[FunctionGemmaEngine], model_name: str) -> str:
        """Apply the selected model name and return ``engine.refresh_model()``'s result."""
        if not engine:
            return UIController._ENGINE_MISSING_MSG
        UIController._apply_model_name(engine, model_name)
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: Optional[FunctionGemmaEngine], json_val: str) -> str:
        """Forward the edited tool schema text to ``engine.update_tools``."""
        if not engine:
            return UIController._ENGINE_MISSING_MSG
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: Optional[FunctionGemmaEngine], file_obj: Any) -> str:
        """Forward the uploaded file object to ``engine.load_csv``."""
        if not engine:
            return UIController._ENGINE_MISSING_MSG
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: Optional[FunctionGemmaEngine]) -> None:
        """Ask the engine to stop its current run via ``engine.trigger_stop()``.

        Returns nothing: the handler is wired without output components.
        """
        if not engine:
            return
        engine.trigger_stop()

    @staticmethod
    def zip_model(engine: Optional[FunctionGemmaEngine]) -> Any:
        """Return an update for the download component.

        Shows the archive when ``engine.get_zip_path()`` yields a path and
        hides the component otherwise (including when no engine exists).
        """
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: Optional[FunctionGemmaEngine], repo_name: str,
                     oauth_token: Optional[gr.OAuthToken]) -> str:
        """Upload the model to the Hub and return the resulting status text.

        Returns an error message instead of uploading when the user is not
        logged in, the repository name is empty or whitespace-only, or the
        engine is missing. The repository name is passed on trimmed, matching
        what ``update_repo_preview`` displays.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController._ENGINE_MISSING_MSG

        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Build the Markdown line showing ``username/repo`` as upload target.

        Without a username a sign-in hint is returned; an empty or
        whitespace-only repository name is rendered as ``...``.
        """
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # Strip first so a whitespace-only name also falls back to the placeholder.
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine],
                               username: Optional[str] = None) -> Tuple[Any, Any, Any]:
        """Compute interactivity updates for the export controls.

        Returns three updates, in order: repo-name input (enabled when logged
        in), push button (enabled when logged in and the engine reports a
        tuned model) and zip button (enabled when a tuned model exists).
        """
        # Truthiness keeps this consistent with update_repo_preview, which
        # also treats an empty username as "not signed in".
        is_logged_in = bool(username)
        has_model_tuned = bool(engine is not None and getattr(engine, 'has_model_tuned', False))

        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )
|
|
|
|
|
|
|
|
|
|
|
def _render_header():
    """Render the page title, the introductory text and the Hugging Face login button."""
    intro_text = (
        "Fine-tune FunctionGemma to understand your custom functions.<br>"
        "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details."
    )
    sign_in_notice = "(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress)."

    with gr.Column():
        gr.Markdown("# 🤖 FunctionGemma Tuning Lab: Fine-Tuning")
        gr.Markdown(intro_text)
        gr.Markdown(sign_in_notice)
    with gr.Row():
        gr.LoginButton(value="Sign in with Hugging Face")
        # Empty column with a larger scale; presumably a spacer next to the button.
        with gr.Column(scale=3):
            gr.Markdown("")
|
|
|
|
|
def _render_dataset_tab(engine_state):
    """Render the "1. Preparing Dataset" tab.

    Returns a dict with the components the event wiring needs:
    ``tools_editor``, ``update_tools_btn``, ``tools_status``, ``import_file``
    and ``import_status``. ``engine_state`` is accepted but not used here.
    """
    widgets = {}

    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### 🛠️ Tool Schema & Data Import")
        gr.Markdown("**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:")
        gr.Markdown("* `sum(int a, int b)`\n* `query(string q)`")
        gr.Markdown("Ensure that all tools within this specific schema definition share a consistent parameter format.")

        with gr.Row():
            # Left column: tool schema editor.
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                widgets["tools_editor"] = gr.Code(
                    language="json",
                    label="Tool Definitions (JSON Schema)",
                    lines=15,
                )
                widgets["update_tools_btn"] = gr.Button("💾 Update Tool Schema")
                widgets["tools_status"] = gr.Markdown("")

            # Right column: optional CSV dataset upload.
            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                widgets["import_file"] = gr.File(
                    label="Upload Dataset (.csv)",
                    file_types=[".csv"],
                    height=100,
                )
                widgets["import_status"] = gr.Markdown("")

    return widgets
|
|
|
|
|
def _render_training_tab(engine_state):
    """Render the "2. Training & Eval" tab.

    Returns a dict with:
      * ``params``      - inputs for training: epochs, lr, test size, shuffle, model
      * ``eval_params`` - inputs for evaluation: test size, shuffle, model
      * ``buttons``     - [run training, stop, reload, run evaluation]
      * ``outputs``     - [log textbox, metrics plot]
      * ``model_input`` - the base-model dropdown

    ``engine_state`` is accepted but not used here.
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")

        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=AppConfig().AVAILABLE_MODELS,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                lr_number = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                split_slider = gr.Slider(0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)")
                shuffle_box = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")

        with gr.Row():
            # Creation order fixes the on-screen order: Eval, Stop, Train, Reload.
            evaluate_button = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            stop_button = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            train_button = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            reload_button = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)

        with gr.Row():
            log_box = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            metrics_plot = gr.Plot(label="Training Metrics")

    training_inputs = [epochs_slider, lr_number, split_slider, shuffle_box, model_dropdown]
    evaluation_inputs = [split_slider, shuffle_box, model_dropdown]

    return {
        "params": training_inputs,
        "eval_params": evaluation_inputs,
        "buttons": [train_button, stop_button, reload_button, evaluate_button],
        "outputs": [log_box, metrics_plot],
        "model_input": model_dropdown,
    }
|
|
|
|
|
def _render_export_tab(engine_state, username_state):
    """Render the "3. Export" tab.

    Returns a dict with:
      * ``zip_controls`` - [prepare-zip button, download file component]
      * ``hub_controls`` - [repo name input, push button, repo preview, upload status]

    Both state arguments are accepted but not used here.
    """
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")

        with gr.Row():
            # Option A: local download.
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                prepare_zip_button = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                archive_file = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")

            # Option B: push to the Hugging Face Hub.
            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                repo_textbox = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                hub_push_button = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                repo_preview_md = gr.Markdown("Target Repository: (Waiting for input...)")
                upload_status_md = gr.Markdown("")

    zip_controls = [prepare_zip_button, archive_file]
    hub_controls = [repo_textbox, hub_push_button, repo_preview_md, upload_status_md]
    return {"zip_controls": zip_controls, "hub_controls": hub_controls}
|
|
|
|
|
|
|
|
|
|
|
def build_interface() -> gr.Blocks:
    """Assemble the Gradio app: layout, per-session state and event wiring.

    Creates two session states (engine and username), renders the header and
    the three tabs, then connects the components to ``UIController`` handlers.
    Long-running actions disable the action buttons first and re-enable them
    afterwards; the export controls are re-evaluated through
    ``UIController.update_hub_interactive``.

    Returns:
        The constructed ``gr.Blocks`` application (not launched).
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        _render_header()

        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # --- Component handles used by the wiring below ---
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]

        repo_input = export_ui["hub_controls"][0]
        push_btn = export_ui["hub_controls"][1]
        zip_btn = export_ui["zip_controls"][0]

        def lock_ui():
            """Locks all buttons (including Zip/Push) during processing"""
            return [gr.update(interactive=False) for _ in action_buttons] + \
                   [gr.update(interactive=False), gr.update(interactive=False)]

        def unlock_ui():
            """Unlocks general action buttons only. Zip/Push are handled by update_hub_interactive"""
            return [gr.update(interactive=True) for _ in action_buttons]

        # --- Session start: lock, create the engine, fill the UI, unlock ---
        demo.load(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                train_ui["outputs"][0],
                repo_input,
                push_btn,
                zip_btn,
                username_state
            ]
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]]
        ).then(unlock_ui, outputs=action_buttons)

        # --- Dataset tab ---
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]]
        )

        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]]
        )

        # --- Training: swap Run for Stop and lock the other controls ---
        # Push is locked together with Zip so a model cannot be exported or
        # uploaded while training is in progress; both are restored by the
        # update_hub_interactive step that ends this chain.
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(visible=True)
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn]
        )
        # NOTE(review): the name is rebound to the result of the chained
        # .then() calls, so the `cancels` list on the Stop button appears to
        # reference the last step of the chain rather than the training step;
        # stopping then relies on UIController.stop_process. Confirm intended.
        train_run_event = train_run_event.then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=train_ui["outputs"],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False)
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        ).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

        # --- Evaluation: swap Eval for Stop and lock Run/Reload ---
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(visible=False),
                gr.update(visible=True)
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        )
        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[train_ui["outputs"][0]]
        ).then(
            fn=lambda: (
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=True),
                gr.update(visible=False)
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn]
        )

        # --- Stop: signal the engine; bypasses the queue ---
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False
        )

        # --- Reload model / reset data ---
        reload_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[train_ui["outputs"][0]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

        # --- Export: ZIP download ---
        zip_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.zip_model,
            inputs=[engine_state],
            outputs=[export_ui["zip_controls"][1]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

        # Keep the "Target Repository" preview in sync with the typed name.
        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[export_ui["hub_controls"][2]]
        )

        # --- Export: push to the Hub ---
        # upload_model's oauth_token argument is not listed in `inputs`; it is
        # expected to be supplied by Gradio from the type hint.
        push_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[export_ui["hub_controls"][3]]
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_input, push_btn, zip_btn]
        )

    return demo