bebechien's picture
Add warning on sign in with hf
1ceda48 verified
raw
history blame
16.2 kB
import gradio as gr
from typing import Optional, Tuple, Generator, List, Any
from config import AppConfig
from engine import FunctionGemmaEngine
# --- Controller / Logic Layer ---
class UIController:
    """
    Handles the business logic and interaction with the Engine.

    All methods are static and stateless: the per-session engine lives in a
    ``gr.State`` and is passed in as an argument. Handlers that need an engine
    return a warning message (or a hidden/no-op update) when the session has
    not been initialized yet, instead of raising ``AttributeError``.
    """

    # Shared message for handlers invoked before init_session has run.
    _NO_ENGINE_MSG = "⚠️ Engine not initialized."

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for a new browser session.

        Args:
            profile: Hugging Face OAuth profile. Gradio fills this from the type
                hint; it is ``None`` when the user is not signed in.

        Returns:
            A tuple matching the ``demo.load`` outputs: (engine, tools JSON,
            model name, status message, repo-input update, push-button update,
            zip-button update, username or ``None``).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None
        # Calculate initial interactivity state
        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)
        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream training progress from the engine.

        Sets the engine's base model from the UI field, then re-yields whatever
        ``engine.run_training_pipeline`` yields. These values are wired to the
        (log, plot) outputs.
        """
        if not engine:
            yield UIController._NO_ENGINE_MSG, None
            return
        engine.config.MODEL_NAME = model_name.strip()
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream evaluation output from the engine. The values are wired to the log box only."""
        if not engine:
            yield UIController._NO_ENGINE_MSG
            return
        engine.config.MODEL_NAME = model_name.strip()
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Set the base model name and reload the model. Returns the engine's status text."""
        if not engine:
            return UIController._NO_ENGINE_MSG
        engine.config.MODEL_NAME = model_name.strip()
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Forward the edited tool-schema JSON to the engine. Returns its status text."""
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Hand an uploaded CSV file to the engine. Returns its status text."""
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop the running process.

        The Stop button is wired with ``outputs=None``, so nothing is returned.
        Does nothing when the session has no engine.
        """
        if engine:
            engine.trigger_stop()

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Build the model archive and show it in the download widget.

        Returns:
            A ``gr.update`` that shows the file when the engine produced a path.
            Otherwise the update clears and hides the widget.
        """
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str, oauth_token: Optional[gr.OAuthToken]) -> str:
        """Push the tuned model to the signed-in user's Hub namespace.

        Args:
            engine: Session engine.
            repo_name: Repository name typed by the user. Surrounding whitespace
                is stripped so the upload matches the preview shown by
                ``update_repo_preview``.
            oauth_token: Gradio fills this from the type hint; ``None`` when
                not signed in.

        Returns:
            A status/error message for the UI.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = repo_name.strip() if repo_name else ""
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController._NO_ENGINE_MSG
        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Render the ``username/repo`` preview shown under the repo-name box.

        An empty or whitespace-only name is shown as ``...``.
        """
        if not username:
            return "⚠️ Sign in to see the target repository path."
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine],
                               username: Optional[str] = None) -> Tuple[Any, Any, Any]:
        """Compute interactivity for (repo input, push button, zip button).

        - The repo input needs a signed-in user.
        - Push needs a signed-in user and a tuned model.
        - Zip needs only a tuned model.
        """
        # Truthiness, to agree with update_repo_preview's sign-in check.
        is_logged_in = bool(username)
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)
        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )
# --- View / Layout Layer ---
def _render_header():
    """Render the page title, intro text and the Hugging Face sign-in row.

    Nothing is returned. The login button needs no explicit wiring, because
    the handlers take OAuth arguments that Gradio fills in.
    """
    title_md = "# 🤖 FunctionGemma Tuning Lab: Fine-Tuning"
    readme_url = "https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md"
    intro_md = (
        "Fine-tune FunctionGemma to understand your custom functions.<br>"
        f"See [README]({readme_url}) for more details."
    )
    signin_hint_md = "(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export)."
    warning_md = "⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress)."

    with gr.Column():
        for text in (title_md, intro_md, signin_hint_md):
            gr.Markdown(text)
        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # The warning takes the wider share of the row, next to the button.
            with gr.Column(scale=3):
                gr.Markdown(warning_md)
def _render_dataset_tab(engine_state):
    """Build tab 1: the tool-schema editor and the optional CSV upload.

    ``engine_state`` is accepted for parity with the other tab builders but is
    not used here. Events are wired in ``build_interface``.

    Returns:
        dict mapping control names ("tools_editor", "update_tools_btn",
        "tools_status", "import_file", "import_status") to their components.
    """
    controls = {}
    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### 🛠️ Tool Schema & Data Import")
        with gr.Row():
            # Left: edit and apply the JSON schema of the tools to learn.
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                controls["tools_editor"] = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                controls["update_tools_btn"] = gr.Button("💾 Update Tool Schema")
                controls["tools_status"] = gr.Markdown("")
            # Right: optional CSV upload that replaces the default dataset.
            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                controls["import_file"] = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                controls["import_status"] = gr.Markdown("")
    return controls
def _render_training_tab(engine_state):
    """Build tab 2: hyperparameters, run/eval/stop/reload buttons, logs and plot.

    ``engine_state`` is unused here. Events are wired in ``build_interface``.

    Returns:
        dict with:
          - "params": training inputs in ``UIController.run_training`` order
            (epochs, lr, test_size, shuffle, model).
          - "eval_params": evaluation inputs (test_size, shuffle, model).
          - "buttons": (run, stop, reload, eval).
          - "outputs": (log textbox, metrics plot).
          - "model_input": the base-model dropdown, filled on session load.
    """
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### 🚀 Fine-Tuning Configuration")
        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                model_choices = AppConfig().AVAILABLE_MODELS
                model_dd = gr.Dropdown(
                    choices=model_choices,
                    allow_custom_value=True,
                    label="Base Model",
                    info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                    interactive=True,
                )
                epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                lr_box = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                split_slider = gr.Slider(0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)")
                shuffle_box = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")
        with gr.Row():
            eval_button = gr.Button("🧪 Run Evaluation", variant="secondary", scale=1)
            # Hidden until a training/eval run starts.
            stop_button = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
            run_button = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
            reload_button = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
        with gr.Row():
            log_box = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            metrics_plot = gr.Plot(label="Training Metrics")

    shared_eval_inputs = [split_slider, shuffle_box, model_dd]
    return {
        "params": [epochs_slider, lr_box] + shared_eval_inputs,
        "eval_params": list(shared_eval_inputs),
        "buttons": [run_button, stop_button, reload_button, eval_button],
        "outputs": [log_box, metrics_plot],
        "model_input": model_dd,
    }
def _render_export_tab(engine_state, username_state):
    """Build tab 3: ZIP download (option A) and Hub upload (option B).

    ``engine_state`` and ``username_state`` are unused here. Both export
    buttons start disabled and are enabled by ``build_interface``'s wiring.

    Returns:
        dict with:
          - "zip_controls": (zip button, download file).
          - "hub_controls": (repo-name input, push button, repo preview,
            upload status).
    """
    with gr.TabItem("3. Export"):
        gr.Markdown("### 📦 Export Trained Model")
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                prepare_zip = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                archive_out = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")
            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                repo_box = gr.Textbox(
                    label="Target Repository Name",
                    value="functiongemma-270m-it-tuning-lab",
                    placeholder="e.g., functiongemma-270m-it-tuned",
                    interactive=False,
                )
                publish = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                preview_md = gr.Markdown("Target Repository: (Waiting for input...)")
                status_md = gr.Markdown("")

    zip_group = [prepare_zip, archive_out]
    hub_group = [repo_box, publish, preview_md, status_md]
    return dict(zip_controls=zip_group, hub_controls=hub_group)
# --- Main Build Function ---
def build_interface() -> gr.Blocks:
    """Assemble the FunctionGemma Tuning Lab UI and wire every event.

    Layout is delegated to the ``_render_*`` helpers. This function owns the
    per-session state (``engine_state``, ``username_state``) and the event
    chains that lock and unlock buttons around long-running work.

    Every long-running action (training, evaluation, reload, zip, push) locks
    all action buttons, including Zip and Push, for its whole duration. This
    stops another action's ``unlock_ui`` step from re-enabling Run/Reload/Eval
    while work is still running.

    Returns:
        The configured ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        _render_header()
        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # Helpers for UI State
        # 'action_buttons' only contains buttons that should always be enabled
        # after a process. Zip and Push are excluded because their state depends
        # on has_model_tuned (restored via UIController.update_hub_interactive).
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        action_buttons = [reload_btn, run_btn, eval_btn]
        repo_input = export_ui["hub_controls"][0]
        push_btn = export_ui["hub_controls"][1]
        repo_preview = export_ui["hub_controls"][2]
        upload_status = export_ui["hub_controls"][3]
        zip_btn = export_ui["zip_controls"][0]
        download_file = export_ui["zip_controls"][1]
        log_output = train_ui["outputs"][0]
        hub_outputs = [repo_input, push_btn, zip_btn]

        def lock_ui():
            """Lock all buttons (including Zip/Push) during processing."""
            return [gr.update(interactive=False) for _ in action_buttons] + \
                [gr.update(interactive=False), gr.update(interactive=False)]

        def unlock_ui():
            """Unlock the general action buttons only. Zip/Push are handled by update_hub_interactive."""
            return [gr.update(interactive=True) for _ in action_buttons]

        # --- Event Wiring ---
        # 1. Initialization
        demo.load(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                log_output,
                repo_input,
                push_btn,
                zip_btn,  # Zip state comes from the initial engine state
                username_state,
            ],
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview],
        ).then(unlock_ui, outputs=action_buttons)

        # 2. Data Tab
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]],
        )
        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]],
        )

        # 3. Training & Eval Tab
        # 3a. Training
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # Hide Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(interactive=False),  # Lock Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )
        train_run_event = train_run_event.then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=train_ui["outputs"],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False),
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        ).then(
            # Final check determines whether Zip/Push should unlock
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=hub_outputs,
        )

        # 3b. Evaluation
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # Lock Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(visible=False),      # Hide Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        )
        eval_run_event = eval_run_event.then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[log_output],  # Output only to log, not plot
        ).then(
            fn=lambda: (
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=True),
                gr.update(visible=False),
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        )
        # Restore Zip/Push (locked above) from the engine's actual state.
        # Chained without reassigning eval_run_event, so the Stop button's
        # `cancels` target stays the same event as before.
        eval_run_event.then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=hub_outputs,
        )

        # NOTE(review): train_run_event / eval_run_event reference the LAST step
        # of each chain, not the run_training / run_evaluation step itself.
        # Stopping therefore relies on engine.trigger_stop() (a cooperative
        # stop) rather than on `cancels`. Confirm this is intended.
        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False,
        )

        reload_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[log_output],
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=hub_outputs,
        )

        # 4. Export Tab
        zip_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.zip_model,
            inputs=[engine_state],
            outputs=[download_file],
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=hub_outputs,
        )

        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_preview],
        )

        # The OAuth token argument of upload_model is filled by Gradio from the
        # type hint; only the engine and repo name are explicit inputs.
        push_btn.click(lock_ui, outputs=action_buttons + [push_btn, zip_btn]).then(
            fn=UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[upload_status],
        ).then(unlock_ui, outputs=action_buttons).then(
            fn=UIController.update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=hub_outputs,
        )

    return demo