# Provenance (Hugging Face Spaces file page): uploaded by bebechien via
# "Upload folder using huggingface_hub", commit 99bb79f (verified), 6.51 kB.
import gradio as gr
from engine import FunctionGemmaEngine
def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
    """Build the FunctionGemma Modkit Gradio app around a training engine.

    The app has three tabs:

    1. *Preparing Dataset* - edit the tool JSON schema and optionally upload
       a CSV dataset.
    2. *Training* - choose hyperparameters, start/stop fine-tuning, and view
       logs plus a metrics plot.
    3. *Export* - package the trained model as a ZIP for download.

    This function only lays out components and wires events; all work is
    delegated to ``engine``.

    Args:
        engine: Backend whose methods ``get_tools_json``, ``update_tools``,
            ``load_csv``, ``run_training_pipeline``, ``trigger_stop``,
            ``refresh_data_and_model`` and ``get_zip_path`` are used as
            initial values / event handlers.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
        gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")

        with gr.Tabs():
            # --- TAB 1: PREPARING DATASET ---
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**\n\nEdit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            value=engine.get_tools_json(),
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15,
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**\n\nUpload a CSV file to replace the default dataset ([bebechien/SimpleToolCalling](https://huggingface.co/datasets/bebechien/SimpleToolCalling)).")
                        # The example row deliberately has no space after the
                        # commas: with default Python csv / pandas settings a
                        # quote only opens a quoted field at the very start of
                        # the field, so `, "..."` would keep literal quotes and
                        # the comma inside "London, UK" would split the JSON.
                        gr.Markdown("**Example CSV Row:**<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\",\"get_weather\",\"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100,
                        )
                        import_status = gr.Markdown("")

            # --- TAB 2: TRAINING ---
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")
                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes",
                        )
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5",
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation data ratio. Typical value is 0.2 (80% for training, 20% for testing)",
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split",
                        )
                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=2)
                    # Hidden until a run starts; toggled by _set_running below.
                    stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)
                with gr.Row():
                    # Left: text logs
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Ready.",
                        interactive=False,
                        autoscroll=True,
                    )
                    # Right: metrics plot
                    loss_plot = gr.Plot(label="Training Metrics")

            # --- TAB 3: EXPORT ---
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")
                gr.Markdown("Download the fine-tuned LoRA adapters or full model weights (depending on configuration) as a ZIP file.")
                with gr.Row():
                    zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="primary", scale=1)
                    download_file = gr.File(label="Download Archive", interactive=False, scale=2)

        # --- EVENT WIRING ---
        # Listeners are attached inside the Blocks context, which Gradio requires.

        # Tab 1: Tools
        update_tools_btn.click(
            fn=engine.update_tools,
            inputs=[tools_editor],
            outputs=[tools_status],
        )

        # Tab 1: File Import
        import_file.upload(
            fn=engine.load_csv,
            inputs=[import_file],
            outputs=[import_status],
        )

        # Tab 2: Training
        training_controls = [run_training_btn, clear_reload_btn, stop_training_btn]

        def _set_running(running: bool) -> tuple:
            """Return updates for (Run, Reset, Stop) while a run is active or idle."""
            return (
                gr.update(visible=not running),      # Run button
                gr.update(interactive=not running),  # Reset button
                gr.update(visible=running),          # Stop button
            )

        run_training_btn.click(
            fn=lambda: _set_running(True),
            outputs=training_controls,
        ).then(
            fn=engine.run_training_pipeline,
            inputs=[param_epochs, param_lr, param_test_size, param_shuffle],
            outputs=[output_display, loss_plot],
        ).then(
            # `.then` (unlike `.success`) runs even if the previous step failed,
            # so the controls are restored after an error as well.
            fn=lambda: _set_running(False),
            outputs=training_controls,
        )

        # Tab 2: Stop
        def handle_stop() -> None:
            """Request that the engine stop training and notify the user."""
            engine.trigger_stop()
            # This event has no output components, so a returned message would
            # be discarded; show a toast notification instead.
            gr.Info("Stopping...")

        stop_training_btn.click(
            fn=handle_stop,
            outputs=None,
        )

        # Tab 2: Reset
        clear_reload_btn.click(
            fn=engine.refresh_data_and_model,
            outputs=[output_display],
        )

        # Tab 3: Download
        def handle_zip():
            """Show the archive for download if the engine produced one, else hide it."""
            path = engine.get_zip_path()
            if path:
                return gr.update(value=path, visible=True)
            return gr.update(value=None, visible=False)

        zip_btn.click(
            fn=handle_zip,
            outputs=[download_file],
        )

    return demo