Spaces:
Sleeping
Sleeping
import gc
import json
import os
import re
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend for Spaces; must run before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer
# ----------------------------
# Config
# ----------------------------
# Both the model and the dataset are gated. Accept the licenses and set HF_TOKEN
# (a Space "secret" works) before launching:
#   model:   https://huggingface.co/google/functiongemma-270m-it
#   dataset: https://huggingface.co/datasets/google/mobile-actions
MODEL_ID = "google/functiongemma-270m-it"
DATASET_REPO = "google/mobile-actions"
DATASET_FILE = "dataset.jsonl"  # file downloaded from DATASET_REPO
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset

# Prefer the GPU; use bfloat16 only on a GPU that supports it, float32 otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Default developer turn shown in the UI: a fixed "current" timestamp followed
# by the function-calling instruction.
DEFAULT_DEVELOPER = (
    "Current date and time given in YYYY-MM-DDTHH:MM:SS format: "
    "2024-11-15T05:59:00. You are a model that can do function calling "
    "with the following functions"
)
# ----------------------------
# Lazy singletons
# ----------------------------
# Module-level caches; each starts as None and is filled on first use.
_TOKENIZER = _BASE_MODEL = None  # filled by get_tokenizer() / get_base_model()
# Dataset caches, filled by ensure_dataset():
#   _RAW        raw dataset (each row['text'] is a JSON string)
#   _TOOLS      shared tool schema from the dataset
#   _PROCESSED  prompt/completion/split formatted dataset
#   _MAXTOK     max_length to use for SFT
_RAW = _TOOLS = _PROCESSED = _MAXTOK = None
def get_tokenizer():
    """Return the shared tokenizer for MODEL_ID, loading it on the first call."""
    global _TOKENIZER
    if _TOKENIZER is not None:
        return _TOKENIZER
    _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return _TOKENIZER
def load_fresh_model():
    """Load a brand-new copy of MODEL_ID in DTYPE and move it to DEVICE.

    Each call returns an independent model object. When the tokenizer defines
    a pad token, its id is copied into ``model.config.pad_token_id``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        torch_dtype=DTYPE,
        attn_implementation="eager",  # recommended for Gemma 3
    )
    pad_id = get_tokenizer().pad_token_id
    if pad_id is not None:
        model.config.pad_token_id = pad_id
    model.to(DEVICE)
    return model
def get_base_model():
    """Return the shared, never-trained base model in eval mode (loaded once)."""
    global _BASE_MODEL
    if _BASE_MODEL is None:
        _BASE_MODEL = load_fresh_model()
        _BASE_MODEL.eval()  # the cached copy is only ever used for inference
    return _BASE_MODEL
# ----------------------------
# Dataset: download, format into prompt/completion, split
# ----------------------------
def apply_format(sample):
    """Turn one raw dataset row into SFT ``prompt``/``completion`` strings.

    ``sample["text"]`` is a JSON string whose ``messages``, ``tools`` and
    ``metadata`` keys are read here. The prompt is every message except the
    last, rendered with the chat template plus a generation prompt. The
    completion is whatever the full-conversation rendering has beyond that
    prompt. ``split`` is copied from ``metadata``.

    NOTE(review): the slice assumes the rendered prompt is an exact prefix of
    the rendered full conversation; if the template ever renders them
    differently the completion would be silently wrong.
    """
    tok = get_tokenizer()
    record = json.loads(sample["text"])
    messages, tools = record["messages"], record["tools"]

    def render(msgs, with_generation_prompt):
        return tok.apply_chat_template(
            msgs,
            tools=tools,
            tokenize=False,
            add_generation_prompt=with_generation_prompt,
        )

    full_text = render(messages, False)
    prompt_text = render(messages[:-1], True)
    return {
        "prompt": prompt_text,
        "completion": full_text[len(prompt_text):],
        "split": record["metadata"],
    }
def ensure_dataset():
    """Download and format the dataset once, caching everything in module globals.

    Populates:
        _RAW:       the raw rows, shuffled with seed 7 (each ``row["text"]`` is a JSON string).
        _TOOLS:     the ``tools`` list of the first shuffled row (the shared tool schema).
        _PROCESSED: ``_RAW`` mapped through ``apply_format`` (adds prompt/completion/split).
        _MAXTOK:    the SFT ``max_length``.

    Later calls return immediately. The four globals are assigned together only
    after every step has succeeded. A failure part-way through (download error,
    empty file, ...) therefore leaves the cache empty and the next call retries.
    Otherwise ``_PROCESSED`` could be set while ``_MAXTOK`` is still ``None``.
    """
    global _RAW, _TOOLS, _PROCESSED, _MAXTOK
    if _PROCESSED is not None:
        return
    path = hf_hub_download(
        repo_id=DATASET_REPO, filename=DATASET_FILE, repo_type="dataset", token=HF_TOKEN
    )
    raw = load_dataset("text", data_files=path, encoding="utf-8")["train"].shuffle(seed=7)
    tools = json.loads(raw[0]["text"])["tools"]
    tok = get_tokenizer()
    processed = raw.map(apply_format)
    # Find the longest example by character count (cheap), measure that one in
    # tokens, and add 100 tokens of headroom.
    longest = max(processed, key=lambda e: len(e["prompt"]) + len(e["completion"]))
    max_tokens = len(tok.tokenize(longest["prompt"] + longest["completion"])) + 100
    # Publish all caches at once so callers never observe a partial state.
    _RAW, _TOOLS, _PROCESSED, _MAXTOK = raw, tools, processed, max_tokens
def get_tools():
    """Return the ``tools`` schema taken from the dataset, loading the dataset if needed."""
    ensure_dataset()
    tools = _TOOLS
    return tools
# ----------------------------
# Function-call parsing (from the notebook)
# ----------------------------
def extract_function_call(model_output):
    """Parse every FunctionGemma call found in ``model_output``.

    A call has the form
    ``<start_function_call>call:NAME{key:<escape>value<escape>,...}<end_function_call>``.
    A block is skipped if its body does not start with ``call:`` (after
    stripping whitespace) or contains no ``{``. Only ``<escape>``-delimited
    argument values are captured, as strings; anything else inside the braces
    is ignored.

    Returns:
        A list of ``{"function": {"name": str, "arguments": dict}}`` in the
        order the calls appear; empty when nothing parses.
    """
    parsed = []
    bodies = re.findall(
        r"<start_function_call>(.*?)<end_function_call>", model_output, re.DOTALL
    )
    for body in bodies:
        if not body.strip().startswith("call:"):
            continue
        head, brace, tail = body.partition("{")
        if not brace:
            continue  # no argument block -> not a parseable call
        name = head.replace("call:", "").strip()
        arg_text = tail.strip()
        arg_text = arg_text[:-1] if arg_text.endswith("}") else arg_text
        arguments = {
            match.group("key").strip(): match.group("value")
            for match in re.finditer(
                r"(?P<key>[^:,]*?):<escape>(?P<value>.*?)<escape>", arg_text, re.DOTALL
            )
        }
        parsed.append({"function": {"name": name, "arguments": arguments}})
    return parsed
def extract_text(model_output):
    """Return plain-text model output with ``<end_of_turn>`` removed, or None.

    None is returned for empty/None output and for output that begins with a
    function call.
    """
    if model_output and not model_output.startswith("<start_function_call>"):
        return model_output.replace("<end_of_turn>", "").strip()
    return None
def pretty_calls(calls):
    """Render parsed calls as ``name(key='value', ...)``, one call per line.

    Returns ``"(no function call)"`` when ``calls`` is empty or None.
    """
    if not calls:
        return "(no function call)"

    def render(call):
        fn = call["function"]
        arg_text = ", ".join(f"{key}={value!r}" for key, value in fn["arguments"].items())
        return f"{fn['name']}({arg_text})"

    return "\n".join(render(call) for call in calls)
# ----------------------------
# Generation
# ----------------------------
def generate_fc(model, user_prompt, developer_content, max_new_tokens=256, temperature=0.0):
    """Run one developer + user turn through ``model`` and return its raw reply.

    The prompt is rendered with the tokenizer's chat template, the shared tool
    schema from ``get_tools()`` and a generation prompt. Decoding is greedy
    unless ``temperature`` is a positive number. In that case it samples with
    that temperature and ``top_p=0.9``.

    Returns:
        The newly generated text, decoded with special tokens kept, with the
        tokenizer's EOS token string removed and surrounding whitespace stripped.
    """
    tok = get_tokenizer()
    model.eval()
    prompt_text = tok.apply_chat_template(
        [
            {"role": "developer", "content": developer_content},
            {"role": "user", "content": user_prompt},
        ],
        tools=get_tools(),
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tok(prompt_text, return_tensors="pt").to(model.device)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens), "pad_token_id": tok.pad_token_id}
    if not (temperature and temperature > 0):
        gen_kwargs["do_sample"] = False  # greedy: best for function calling
    else:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    generated = model.generate(**encoded, **gen_kwargs)
    new_tokens = generated[0][encoded["input_ids"].shape[1]:]  # drop the echoed prompt
    reply = tok.decode(new_tokens, skip_special_tokens=False)
    return reply.replace(tok.eos_token or "", "").strip()
# ----------------------------
# Exact-match scoring on an eval subset
# ----------------------------
def score_model(model, n_examples, progress=None, desc=""):
    """Return the exact-match function-call accuracy of ``model`` on eval rows.

    Args:
        model: causal LM passed to ``generate_fc``.
        n_examples: how many rows of the ``"eval"`` split to score, taken in the
            (shuffled) order of ``_RAW``; fewer are scored if the split is smaller.
        progress: optional callable invoked as ``progress(fraction, desc=...)``
            after each row (e.g. a ``gr.Progress``).
        desc: prefix for the progress description.

    Returns:
        ``(accuracy, n_scored)`` with ``accuracy = correct / max(1, n_scored)``.

    A prediction is correct only if it has the same function names, in the same
    order, with equal argument dicts, as the row's reference ``tool_calls``.
    Each prompt uses the row's own developer message, falling back to
    DEFAULT_DEVELOPER when the row has none. ``apply_format`` trains on the
    row's own messages, so this keeps scoring prompts identical to training
    prompts, including the date/time that relative requests are resolved against.
    """
    ensure_dataset()

    def signature(calls):
        # Dict equality ignores key order, so the argument dicts need no sorting.
        return [(c["function"]["name"], c["function"]["arguments"]) for c in calls]

    limit = int(n_examples)
    eval_conversations = []
    for row in _RAW:
        if len(eval_conversations) >= limit:
            break  # stop parsing once enough eval rows are collected
        record = json.loads(row["text"])  # parse each row only once
        if record["metadata"] == "eval":
            eval_conversations.append(record["messages"])

    total = len(eval_conversations)
    correct = 0
    for i, msgs in enumerate(eval_conversations, start=1):
        user_msg = next((m["content"] for m in msgs if m["role"] == "user"), "")
        developer_msg = next(
            (m["content"] for m in msgs if m["role"] == "developer"), DEFAULT_DEVELOPER
        )
        target = msgs[-1].get("tool_calls") or []
        raw = generate_fc(model, user_msg, developer_msg, max_new_tokens=_MAXTOK)
        if signature(extract_function_call(raw)) == signature(target):
            correct += 1
        if progress is not None:
            progress(i / total, desc=f"{desc} {i}/{total}")
    return correct / max(1, total), total
# ----------------------------
# Loss plot (train + eval) from trainer log history
# ----------------------------
def make_loss_plot(log_history):
    """Build a training/validation loss figure from a Trainer ``log_history``.

    Args:
        log_history: list of log dicts. Entries containing ``"loss"`` are
            plotted as training points, entries containing ``"eval_loss"`` as
            validation points, both against their ``"step"``.

    Returns:
        A ``matplotlib.figure.Figure``. Axes, labels and title are drawn even
        when there are no points; the legend only appears when something was plotted.

    The figure is built with ``Figure`` directly instead of ``plt.subplots``, so
    it is never registered with pyplot's global figure manager. Pyplot keeps
    every figure it creates alive until it is explicitly closed, so the old code
    would accumulate one figure per training run in this long-lived server
    process. This also avoids touching pyplot's global state from Gradio's
    worker threads.
    """
    train_x = [entry["step"] for entry in log_history if "loss" in entry]
    train_y = [entry["loss"] for entry in log_history if "loss" in entry]
    eval_x = [entry["step"] for entry in log_history if "eval_loss" in entry]
    eval_y = [entry["eval_loss"] for entry in log_history if "eval_loss" in entry]

    fig = Figure(figsize=(6, 3.4))
    ax = fig.subplots()
    fig.patch.set_facecolor("#ffffff")
    ax.set_facecolor("#fbfbfd")
    if train_y:
        ax.plot(train_x, train_y, color="#7c3aed", linewidth=2.2, label="Training loss")
    if eval_y:
        ax.plot(eval_x, eval_y, color="#db2777", linewidth=2.0,
                marker="o", markersize=4, label="Validation loss")
    ax.set_xlabel("Step", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title("FunctionGemma SFT loss 📉", fontsize=12, fontweight="bold", color="#1f2937")
    ax.grid(True, linestyle="--", alpha=0.35)
    if train_y or eval_y:
        ax.legend(frameon=False)
    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)
    fig.tight_layout()
    return fig
# ----------------------------
# Gradio <-> Trainer progress bridge
# ----------------------------
class GradioCallback(TrainerCallback):
    """Trainer callback that forwards step progress to a Gradio progress tracker."""

    def __init__(self, progress):
        # progress: callable accepting (fraction, desc=...), e.g. a gr.Progress
        self.progress = progress

    def on_step_end(self, args, state, control, **kwargs):
        """Report ``global_step / max_steps`` at the end of each training step."""
        step = state.global_step
        total = state.max_steps or 1  # avoid dividing by zero when max_steps is 0
        self.progress(step / total, desc=f"SFT step {step}/{total}")
# ----------------------------
# Actions
# ----------------------------
def base_only(user_prompt, developer_content, output_length, temperature):
    """Handler for "Ask base model": return ``(raw output, parsed calls)``.

    When the request is blank, returns a warning message and an empty string
    instead of running the model.
    """
    if user_prompt.strip():
        raw = generate_fc(get_base_model(), user_prompt, developer_content,
                          output_length, temperature)
        return raw, pretty_calls(extract_function_call(raw))
    return "⚠️ Enter a mobile-action request first.", ""
def finetune_and_compare(
    user_prompt,
    developer_content,
    epochs,
    train_subset,
    eval_subset,
    learning_rate,
    batch_size,
    grad_accum,
    output_length,
    temperature,
    progress=gr.Progress(),
):
    """Fully fine-tune a fresh copy of the model with SFT and compare it to the base model.

    Steps:
      1. Score the cached base model on the first ``eval_subset`` eval rows.
      2. Train a fresh copy on the first ``train_subset`` train rows.
      3. Plot the loss curves.
      4. Run the tuned model on ``user_prompt`` and score it on the same eval rows.

    Returns the six Gradio outputs ``(loss figure, status markdown, tuned raw
    output, tuned parsed calls, base-accuracy markdown, tuned-accuracy markdown)``.
    The tuned model and trainer are released afterwards, including when training
    or scoring raises, so a failed run does not keep the model in (GPU) memory.
    """
    if not user_prompt.strip():
        return None, "⚠️ Enter a mobile-action request first.", "", "", "", ""

    progress(0.0, desc="Downloading + formatting dataset")
    ensure_dataset()
    train_ds = _PROCESSED.filter(lambda e: e["split"] == "train")
    eval_ds = _PROCESSED.filter(lambda e: e["split"] == "eval")
    train_ds = train_ds.select(range(min(int(train_subset), len(train_ds))))
    eval_ds = eval_ds.select(range(min(int(eval_subset), len(eval_ds))))

    # Score the base model first (re-used for the headline comparison).
    base_acc, n_eval = score_model(get_base_model(), eval_subset, progress, "Scoring base")

    # Rough optimizer-step estimate, used only to space out evaluations. The
    # status message reports the number of steps the trainer actually ran.
    total_steps = max(1, (len(train_ds) // (int(batch_size) * int(grad_accum)))) * int(epochs)

    torch.manual_seed(7)
    model = load_fresh_model()
    trainer = None
    try:
        if DEVICE == "cuda":
            model.gradient_checkpointing_enable()
            model.config.use_cache = False  # no KV cache while training with checkpointing

        with tempfile.TemporaryDirectory() as out_dir:
            cfg = SFTConfig(
                output_dir=out_dir,
                num_train_epochs=float(epochs),
                per_device_train_batch_size=int(batch_size),
                gradient_accumulation_steps=int(grad_accum),
                learning_rate=float(learning_rate),
                lr_scheduler_type="cosine",
                logging_strategy="steps",
                logging_steps=1,
                eval_strategy="steps" if len(eval_ds) else "no",
                eval_steps=max(1, total_steps // 4),
                save_strategy="no",
                max_length=_MAXTOK,
                gradient_checkpointing=(DEVICE == "cuda"),
                packing=False,
                optim="adamw_torch_fused" if DEVICE == "cuda" else "adamw_torch",
                bf16=(DTYPE == torch.bfloat16),
                completion_only_loss=True,  # loss on the assistant turn only
                report_to="none",
                seed=7,
            )
            trainer = SFTTrainer(
                model=model,
                args=cfg,
                train_dataset=train_ds,
                eval_dataset=eval_ds if len(eval_ds) else None,
                callbacks=[GradioCallback(progress)],
            )
            trainer.train()
            log_history = list(trainer.state.log_history)
            steps_run = trainer.state.global_step

        # Switch back to inference mode.
        if DEVICE == "cuda":
            model.gradient_checkpointing_disable()
            model.config.use_cache = True

        fig = make_loss_plot(log_history)
        # Tuned model output for the user's prompt.
        tuned_raw = generate_fc(model, user_prompt, developer_content, output_length, temperature)
        tuned_calls = pretty_calls(extract_function_call(tuned_raw))
        # Score the tuned model on the same eval rows as the base model.
        tuned_acc, _ = score_model(model, eval_subset, progress, "Scoring tuned")
    finally:
        # Drop the only references to the tuned model (and the trainer holding
        # it plus optimizer state) whether or not the run succeeded.
        del trainer, model
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    losses = [entry["loss"] for entry in log_history if "loss" in entry]
    first_loss = losses[0] if losses else 0.0
    last_loss = losses[-1] if losses else 0.0
    status = (
        f"✅ Full fine-tuned **FunctionGemma 270M-IT** on **{len(train_ds)} train examples** "
        f"for **{epochs} epoch(s)** ({steps_run} steps).\n\n"
        f"Loss **{first_loss:.3f} → {last_loss:.3f}**. "
        f"Exact-match function-call accuracy on {n_eval} eval examples: "
        f"**base {base_acc:.0%} → tuned {tuned_acc:.0%}**.\n\n"
        f"Device: `{DEVICE}` · dtype: `{str(DTYPE).replace('torch.', '')}` · "
        f"max_length: `{_MAXTOK}`."
    )
    return (
        fig,
        status,
        tuned_raw,
        tuned_calls,
        f"Base accuracy: {base_acc:.0%}",
        f"Tuned accuracy: {tuned_acc:.0%}",
    )
# Markdown for the hero banner. Blank lines separate the paragraphs; without
# them Markdown joins consecutive lines into a single paragraph.
EXPLANATION = """
# 📱 FunctionGemma 270M — Mobile Actions SFT

Fine-tune Google's **FunctionGemma 270M-IT** to turn phone requests
("turn on the flashlight", "schedule a team meeting tomorrow at 4pm") into
**function calls**, using the gated [`google/mobile-actions`](https://huggingface.co/datasets/google/mobile-actions)
dataset and TRL's `SFTTrainer`.

This is a full fine-tune (no LoRA) in **prompt/completion** format with
`completion_only_loss=True`, so loss is computed only on the assistant's call.
The chat template is applied with the dataset's `tools=` schema. Pick a request,
run SFT, and watch the exact-match function-call accuracy go up.

*Omitted from the original notebook: Hugging Face Hub upload and the
`.litertlm` / `ai-edge-torch` on-device conversion (not Space-friendly).*
"""
# Extra CSS passed to gr.Blocks. It caps and centres the page width, styles the
# gradient "#hero" banner and the ".panel-card" groups, bolds "#train-btn", and
# sets the footer element to visibility: hidden.
CUSTOM_CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
#hero {
background: linear-gradient(135deg, #7c3aed 0%, #2563eb 50%, #06b6d4 100%);
border-radius: 18px; padding: 6px 26px; color: white;
box-shadow: 0 10px 30px rgba(37, 99, 235, 0.25); margin-bottom: 8px;
}
#hero h1 { color: white !important; font-size: 2.0rem !important; }
#hero p, #hero li, #hero strong { color: rgba(255,255,255,0.95) !important; }
#hero a { color: #bae6fd !important; }
.panel-card {
border-radius: 16px !important; padding: 16px !important;
background: var(--block-background-fill);
box-shadow: 0 4px 18px rgba(0,0,0,0.06);
border: 1px solid var(--border-color-primary);
}
#train-btn { font-weight: 700 !important; }
footer { visibility: hidden; }
"""
# Gradio "Soft" theme: blue primary / cyan secondary hues, Quicksand web font
# with system fallbacks.
THEME = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Quicksand"), "system-ui", "sans-serif"],
    primary_hue="blue",
    secondary_hue="cyan",
)
# Sample requests offered as clickable examples; the first one pre-fills the
# request box.
EXAMPLE_PROMPTS = [
    "Schedule a \"team meeting\" tomorrow at 4pm.",
    "Turn on the flashlight.",
    "Show me Besançon, France on the map.",
    "Open the WiFi settings.",
    "Create a contact for Alex with number 555-0123.",
]
# ----------------------------
# UI
# ----------------------------
with gr.Blocks(title="FunctionGemma 270M Mobile Actions SFT", theme=THEME, css=CUSTOM_CSS) as demo:
    with gr.Group(elem_id="hero"):
        gr.Markdown(EXPLANATION)
    with gr.Row():
        # Left column: request, developer message and training controls.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### ⚙️ Controls")
                user_prompt = gr.Textbox(
                    value=EXAMPLE_PROMPTS[0], lines=2,
                    label="Mobile-action request (user message)",
                )
                gr.Examples(EXAMPLE_PROMPTS, inputs=user_prompt, label="Try one")
                developer_content = gr.Textbox(
                    value=DEFAULT_DEVELOPER, lines=3,
                    label="Developer message (context: date/time + role)",
                )
                with gr.Row():
                    epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
                    train_subset = gr.Slider(
                        50, 1000, value=200, step=50, label="Train subset",
                        info="Fewer = faster.",
                    )
                    eval_subset = gr.Slider(
                        10, 100, value=30, step=10, label="Eval examples (for scoring)",
                    )
                with gr.Accordion("Advanced", open=False):
                    learning_rate = gr.Slider(1e-6, 5e-5, value=1e-5, step=1e-6, label="Learning rate")
                    batch_size = gr.Slider(1, 8, value=4, step=1, label="Batch size")
                    grad_accum = gr.Slider(1, 16, value=8, step=1, label="Grad accumulation")
                    output_length = gr.Slider(64, 512, value=256, step=32, label="Max new tokens")
                    temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1,
                                            label="Temperature (0 = greedy, best for tools)")
                with gr.Row():
                    base_btn = gr.Button("🎲 Ask base model", variant="secondary")
                    train_btn = gr.Button("🚀 Fine-tune & Compare", variant="primary", elem_id="train-btn")
        # Right column: accuracies, parsed/raw outputs, loss curve and status.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### 🔍 Results")
                with gr.Row():
                    base_acc_box = gr.Markdown()
                    tuned_acc_box = gr.Markdown()
                with gr.Tab("Parsed calls"):
                    base_calls = gr.Textbox(lines=4, label="🎲 Base model call(s)")
                    tuned_calls = gr.Textbox(lines=4, label="✨ Fine-tuned call(s)")
                with gr.Tab("Raw output"):
                    # Separate boxes so "Ask base model" never overwrites (or is
                    # mislabelled as) the fine-tuned model's raw output.
                    base_raw = gr.Textbox(lines=8, label="🎲 Base model raw output")
                    tuned_raw = gr.Textbox(lines=8, label="✨ Fine-tuned raw output")
                loss_plot = gr.Plot(label="📉 Training / validation loss")
                status = gr.Markdown()

    base_btn.click(
        base_only,
        inputs=[user_prompt, developer_content, output_length, temperature],
        outputs=[base_raw, base_calls],
    )
    train_btn.click(
        finetune_and_compare,
        inputs=[user_prompt, developer_content, epochs, train_subset, eval_subset,
                learning_rate, batch_size, grad_accum, output_length, temperature],
        outputs=[loss_plot, status, tuned_raw, tuned_calls, base_acc_box, tuned_acc_box],
    )

    with gr.Accordion("💬 Notes", open=False):
        gr.Markdown(
            """
- **Greedy decoding** (temperature 0) is best for function calling — you want the
  single most likely call, not a creative one.
- **Exact-match** accuracy is a lower bound: a call with equivalent arguments
  (e.g. a slightly reworded `query`) counts as wrong but may still be acceptable.
- A GPU is strongly recommended. On CPU, training and scoring will be slow —
  shrink the train/eval subsets.
"""
        )

if __name__ == "__main__":
    demo.launch()