"""Gradio demo: supervised fine-tuning of FunctionGemma 270M-IT on mobile actions.

The app downloads the ``google/mobile-actions`` dataset, formats it into
prompt/completion pairs with the model's chat template, runs a short full
fine-tune with TRL's ``SFTTrainer``, and compares exact-match function-call
accuracy of the base model against the tuned model.
"""

import gc
import json
import os
import re
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend for Spaces
import matplotlib.pyplot as plt

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer

# ----------------------------
# Config
# ----------------------------
# Both the model and the dataset are gated. Accept the licenses and set HF_TOKEN
# (a Space "secret" works) before launching:
#   model:   https://huggingface.co/google/functiongemma-270m-it
#   dataset: https://huggingface.co/datasets/google/mobile-actions
MODEL_ID = "google/functiongemma-270m-it"
DATASET_REPO = "google/mobile-actions"
DATASET_FILE = "dataset.jsonl"
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = (
    torch.bfloat16
    if (DEVICE == "cuda" and torch.cuda.is_bf16_supported())
    else torch.float32
)

DEFAULT_DEVELOPER = (
    "Current date and time given in YYYY-MM-DDTHH:MM:SS format: 2024-11-15T05:59:00. "
    "You are a model that can do function calling with the following functions"
)

# Control tokens that delimit function calls in the model's raw output.
# NOTE(review): restored from FunctionGemma's documented output format after
# the markup was stripped from this file; confirm against the tokenizer.
_CALL_START = "<start_function_call>"
_CALL_END = "<end_function_call>"
_END_OF_TURN = "<end_of_turn>"
_CALL_RE = re.compile(
    re.escape(_CALL_START) + r"(.*?)" + re.escape(_CALL_END), re.DOTALL
)
_ARG_RE = re.compile(
    r"(?P<key>[^:,]*?):<escape>(?P<value>.*?)<escape>", re.DOTALL
)

# ----------------------------
# Lazy singletons
# ----------------------------
_TOKENIZER = None
_BASE_MODEL = None
_RAW = None        # raw dataset (each row['text'] is a JSON string)
_TOOLS = None      # shared tool schema from the dataset
_PROCESSED = None  # prompt/completion/split formatted dataset
_MAXTOK = None     # max_length to use for SFT


def get_tokenizer():
    """Return the shared tokenizer, loading it on first use."""
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return _TOKENIZER


def load_fresh_model():
    """Load a new copy of the model from the Hub and move it to ``DEVICE``.

    The tokenizer's pad token id is copied into the model config when the
    tokenizer defines one.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        attn_implementation="eager",  # recommended for Gemma 3
        token=HF_TOKEN,
    )
    tok = get_tokenizer()
    if tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id
    model.to(DEVICE)
    return model


def get_base_model():
    """Return the cached, untouched base model in eval mode."""
    global _BASE_MODEL
    if _BASE_MODEL is None:
        _BASE_MODEL = load_fresh_model()
        _BASE_MODEL.eval()
    return _BASE_MODEL


# ----------------------------
# Dataset: download, format into prompt/completion, split
# ----------------------------
def apply_format(sample):
    """Turn one raw dataset row into a prompt/completion/split record.

    ``sample["text"]`` is a JSON string with ``messages``, ``tools`` and
    ``metadata`` keys. The prompt is the chat template applied to every
    message but the last (with a generation prompt); the completion is the
    remainder of the fully templated conversation.
    """
    tok = get_tokenizer()
    t = json.loads(sample["text"])
    full = tok.apply_chat_template(
        t["messages"], tools=t["tools"], tokenize=False, add_generation_prompt=False
    )
    prompt = tok.apply_chat_template(
        t["messages"][:-1], tools=t["tools"], tokenize=False, add_generation_prompt=True
    )
    # The completion is whatever the full rendering adds after the prompt.
    completion = full[len(prompt):]
    return {"prompt": prompt, "completion": completion, "split": t["metadata"]}


def ensure_dataset():
    """Download + format once; cache raw rows, tools, processed splits, max_length."""
    global _RAW, _TOOLS, _PROCESSED, _MAXTOK
    if _PROCESSED is not None:
        return
    path = hf_hub_download(
        repo_id=DATASET_REPO,
        filename=DATASET_FILE,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    _RAW = load_dataset("text", data_files=path, encoding="utf-8")["train"].shuffle(seed=7)
    # The first row's tool schema is reused for every inference prompt.
    _TOOLS = json.loads(_RAW[0]["text"])["tools"]
    tok = get_tokenizer()
    _PROCESSED = _RAW.map(apply_format)
    # Size max_length from the longest example (by characters), plus headroom.
    longest = max(_PROCESSED, key=lambda e: len(e["prompt"] + e["completion"]))
    longest_tokens = len(tok.tokenize(longest["prompt"] + longest["completion"]))
    _MAXTOK = longest_tokens + 100


def get_tools():
    """Return the tool schema, loading the dataset first if needed."""
    ensure_dataset()
    return _TOOLS


# ----------------------------
# Function-call parsing (from the notebook)
# ----------------------------
def extract_function_call(model_output):
    """Parse function calls out of raw model output.

    Each ``<start_function_call>...<end_function_call>`` span that starts
    with ``call:`` is split into a function name and ``key:<escape>value<escape>``
    arguments. Spans without a ``{`` are skipped.

    Returns:
        A list of ``{"function": {"name": str, "arguments": dict[str, str]}}``.
    """
    results = []
    for raw_call in _CALL_RE.findall(model_output):
        if not raw_call.strip().startswith("call:"):
            continue
        try:
            pre_brace, args_segment = raw_call.split("{", 1)
        except ValueError:
            # No argument block at all: not a well-formed call.
            continue
        function_name = pre_brace.replace("call:", "").strip()
        args_content = args_segment.strip()
        if args_content.endswith("}"):
            args_content = args_content[:-1]
        arguments = {
            m.group("key").strip(): m.group("value")
            for m in _ARG_RE.finditer(args_content)
        }
        results.append({"function": {"name": function_name, "arguments": arguments}})
    return results


def extract_text(model_output):
    """Return plain-text output, or ``None`` if empty or a function call."""
    if not model_output or model_output.startswith(_CALL_START):
        return None
    return model_output.replace(_END_OF_TURN, "").strip()


def pretty_calls(calls):
    """Render parsed calls as one ``name(arg='value', ...)`` line per call."""
    if not calls:
        return "(no function call)"
    lines = []
    for c in calls:
        fn = c["function"]["name"]
        args = ", ".join(f"{k}={v!r}" for k, v in c["function"]["arguments"].items())
        lines.append(f"{fn}({args})")
    return "\n".join(lines)


# ----------------------------
# Generation
# ----------------------------
@torch.no_grad()
def generate_fc(model, user_prompt, developer_content, max_new_tokens=256, temperature=0.0):
    """Generate a response for one developer + user message pair.

    Uses greedy decoding unless ``temperature`` is positive, in which case it
    samples with ``top_p=0.9``.

    Returns:
        The newly generated text only, decoded with special tokens kept (the
        parser needs them) and the EOS token removed.
    """
    tok = get_tokenizer()
    model.eval()
    messages = [
        {"role": "developer", "content": developer_content},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tok.apply_chat_template(
        messages, tools=get_tools(), tokenize=False, add_generation_prompt=True
    )
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens), pad_token_id=tok.pad_token_id)
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    else:
        gen_kwargs.update(do_sample=False)  # greedy: best for function calling
    out = model.generate(**inputs, **gen_kwargs)
    # Drop the prompt tokens; keep only what the model produced.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    raw = tok.decode(new_tokens, skip_special_tokens=False)
    raw = raw.replace(tok.eos_token or "", "").strip()
    return raw


# ----------------------------
# Exact-match scoring on an eval subset
# ----------------------------
def score_model(model, n_examples, progress=None, desc=""):
    """Score exact-match function-call accuracy on the first eval rows.

    Args:
        model: Model to evaluate.
        n_examples: Maximum number of eval rows to score.
        progress: Optional callable ``progress(fraction, desc=...)``.
        desc: Prefix for the progress description.

    Returns:
        ``(accuracy, number_of_rows_scored)``.
    """
    ensure_dataset()
    limit = int(n_examples)
    # Parse each row once and stop as soon as enough eval rows are collected.
    eval_samples = []
    for row in _RAW:
        if len(eval_samples) >= limit:
            break
        sample = json.loads(row["text"])
        if sample["metadata"] == "eval":
            eval_samples.append(sample)

    correct = 0
    total = len(eval_samples)
    for i, sample in enumerate(eval_samples):
        msgs = sample["messages"]
        user_msg = next((m["content"] for m in msgs if m["role"] == "user"), "")
        target = msgs[-1].get("tool_calls", []) or []
        target_names = [fc["function"]["name"] for fc in target]
        target_args = [dict(sorted(fc["function"]["arguments"].items())) for fc in target]

        raw = generate_fc(model, user_msg, DEFAULT_DEVELOPER, max_new_tokens=_MAXTOK)
        pred = extract_function_call(raw)
        pred_names = [fc["function"]["name"] for fc in pred]
        pred_args = [dict(sorted(fc["function"]["arguments"].items())) for fc in pred]

        if target_names == pred_names and target_args == pred_args:
            correct += 1
        if progress is not None:
            progress((i + 1) / total, desc=f"{desc} {i + 1}/{total}")
    return correct / max(1, total), total


# ----------------------------
# Loss plot (train + eval) from trainer log history
# ----------------------------
def make_loss_plot(log_history):
    """Build a matplotlib figure of training and validation loss by step.

    ``log_history`` is a list of dicts; entries with a ``"loss"`` key feed the
    training curve and entries with ``"eval_loss"`` feed the validation curve.
    """
    train_x = [l["step"] for l in log_history if "loss" in l]
    train_y = [l["loss"] for l in log_history if "loss" in l]
    eval_x = [l["step"] for l in log_history if "eval_loss" in l]
    eval_y = [l["eval_loss"] for l in log_history if "eval_loss" in l]

    fig, ax = plt.subplots(figsize=(6, 3.4))
    fig.patch.set_facecolor("#ffffff")
    ax.set_facecolor("#fbfbfd")
    if train_y:
        ax.plot(train_x, train_y, color="#7c3aed", linewidth=2.2, label="Training loss")
    if eval_y:
        ax.plot(
            eval_x, eval_y, color="#db2777", linewidth=2.0,
            marker="o", markersize=4, label="Validation loss",
        )
    ax.set_xlabel("Step", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title("FunctionGemma SFT loss 📉", fontsize=12, fontweight="bold", color="#1f2937")
    ax.grid(True, linestyle="--", alpha=0.35)
    if train_y or eval_y:
        ax.legend(frameon=False)
    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)
    fig.tight_layout()
    return fig


# ----------------------------
# Gradio <-> Trainer progress bridge
# ----------------------------
class GradioCallback(TrainerCallback):
    """Forward trainer step progress to a Gradio progress callable."""

    def __init__(self, progress):
        self.progress = progress

    def on_step_end(self, args, state, control, **kwargs):
        """Report ``global_step / max_steps`` after every optimizer step."""
        total = state.max_steps or 1  # avoid dividing by zero
        self.progress(state.global_step / total, desc=f"SFT step {state.global_step}/{total}")


# ----------------------------
# Actions
# ----------------------------
def base_only(user_prompt, developer_content, output_length, temperature):
    """Run the base model on the prompt; return ``(raw_output, parsed_calls)``."""
    if not user_prompt.strip():
        return "⚠️ Enter a mobile-action request first.", ""
    raw = generate_fc(get_base_model(), user_prompt, developer_content, output_length, temperature)
    return raw, pretty_calls(extract_function_call(raw))


def finetune_and_compare(
    user_prompt,
    developer_content,
    epochs,
    train_subset,
    eval_subset,
    learning_rate,
    batch_size,
    grad_accum,
    output_length,
    temperature,
    progress=gr.Progress(),
):
    """Fine-tune a fresh model copy and compare it against the base model.

    Returns:
        ``(loss_figure, status_markdown, tuned_raw_output, tuned_parsed_calls,
        base_accuracy_text, tuned_accuracy_text)``. When the prompt is blank
        the figure is ``None`` and only the status carries a warning.
    """
    if not user_prompt.strip():
        return None, "⚠️ Enter a mobile-action request first.", "", "", "", ""

    progress(0.0, desc="Downloading + formatting dataset")
    ensure_dataset()

    train_ds = _PROCESSED.filter(lambda e: e["split"] == "train")
    eval_ds = _PROCESSED.filter(lambda e: e["split"] == "eval")
    train_ds = train_ds.select(range(min(int(train_subset), len(train_ds))))
    eval_ds = eval_ds.select(range(min(int(eval_subset), len(eval_ds))))

    # score base model first (re-used for the headline comparison)
    base_acc, n_eval = score_model(get_base_model(), eval_subset, progress, "Scoring base")

    torch.manual_seed(7)
    model = load_fresh_model()
    if DEVICE == "cuda":
        model.gradient_checkpointing_enable()
        model.config.use_cache = False

    # Estimated optimizer steps; used for the eval cadence and the status text.
    total_steps = max(1, (len(train_ds) // (int(batch_size) * int(grad_accum)))) * int(epochs)

    with tempfile.TemporaryDirectory() as out_dir:
        cfg = SFTConfig(
            output_dir=out_dir,
            num_train_epochs=float(epochs),
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=int(grad_accum),
            learning_rate=float(learning_rate),
            lr_scheduler_type="cosine",
            logging_strategy="steps",
            logging_steps=1,
            eval_strategy="steps" if len(eval_ds) else "no",
            eval_steps=max(1, total_steps // 4),
            save_strategy="no",
            max_length=_MAXTOK,
            gradient_checkpointing=(DEVICE == "cuda"),
            packing=False,
            optim="adamw_torch_fused" if DEVICE == "cuda" else "adamw_torch",
            bf16=(DTYPE == torch.bfloat16),
            completion_only_loss=True,  # loss on the assistant turn only
            report_to="none",
            seed=7,
        )
        trainer = SFTTrainer(
            model=model,
            args=cfg,
            train_dataset=train_ds,
            eval_dataset=eval_ds if len(eval_ds) else None,
            callbacks=[GradioCallback(progress)],
        )
        trainer.train()
        log_history = list(trainer.state.log_history)

    # switch back to inference mode
    if DEVICE == "cuda":
        model.gradient_checkpointing_disable()
        model.config.use_cache = True

    fig = make_loss_plot(log_history)

    # tuned model outputs for the user's prompt
    tuned_raw = generate_fc(model, user_prompt, developer_content, output_length, temperature)
    tuned_calls = pretty_calls(extract_function_call(tuned_raw))

    # score tuned model
    tuned_acc, _ = score_model(model, eval_subset, progress, "Scoring tuned")

    losses = [l["loss"] for l in log_history if "loss" in l]
    first_loss = losses[0] if losses else 0.0
    last_loss = losses[-1] if losses else 0.0

    status = (
        f"✅ Full fine-tuned **FunctionGemma 270M-IT** on **{len(train_ds)} train examples** "
        f"for **{epochs} epoch(s)** ({total_steps} steps).\n\n"
        f"Loss **{first_loss:.3f} → {last_loss:.3f}**. "
        f"Exact-match function-call accuracy on {n_eval} eval examples: "
        f"**base {base_acc:.0%} → tuned {tuned_acc:.0%}**.\n\n"
        f"Device: `{DEVICE}` · dtype: `{str(DTYPE).replace('torch.', '')}` · "
        f"max_length: `{_MAXTOK}`."
    )

    # Release the tuned model so repeated runs do not accumulate memory.
    del trainer, model
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    return (
        fig,
        status,
        tuned_raw,
        tuned_calls,
        f"Base accuracy: {base_acc:.0%}",
        f"Tuned accuracy: {tuned_acc:.0%}",
    )


EXPLANATION = """
# 📱 FunctionGemma 270M — Mobile Actions SFT

Fine-tune Google's **FunctionGemma 270M-IT** to turn phone requests ("turn on the flashlight",
"schedule a team meeting tomorrow at 4pm") into **function calls**, using the gated
[`google/mobile-actions`](https://huggingface.co/datasets/google/mobile-actions) dataset and
TRL's `SFTTrainer`.

This is a full fine-tune (no LoRA) in **prompt/completion** format with `completion_only_loss=True`,
so loss is computed only on the assistant's call. The chat template is applied with the dataset's
`tools=` schema. Pick a request, run SFT, and watch the exact-match function-call accuracy go up.

*Omitted from the original notebook: Hugging Face Hub upload and the `.litertlm` / `ai-edge-torch`
on-device conversion (not Space-friendly).*
"""

CUSTOM_CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
#hero {
    background: linear-gradient(135deg, #7c3aed 0%, #2563eb 50%, #06b6d4 100%);
    border-radius: 18px;
    padding: 6px 26px;
    color: white;
    box-shadow: 0 10px 30px rgba(37, 99, 235, 0.25);
    margin-bottom: 8px;
}
#hero h1 { color: white !important; font-size: 2.0rem !important; }
#hero p, #hero li, #hero strong { color: rgba(255,255,255,0.95) !important; }
#hero a { color: #bae6fd !important; }
.panel-card {
    border-radius: 16px !important;
    padding: 16px !important;
    background: var(--block-background-fill);
    box-shadow: 0 4px 18px rgba(0,0,0,0.06);
    border: 1px solid var(--border-color-primary);
}
#train-btn { font-weight: 700 !important; }
footer { visibility: hidden; }
"""

THEME = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="cyan",
    font=[gr.themes.GoogleFont("Quicksand"), "system-ui", "sans-serif"],
)

EXAMPLE_PROMPTS = [
    'Schedule a "team meeting" tomorrow at 4pm.',
    "Turn on the flashlight.",
    "Show me Besançon, France on the map.",
    "Open the WiFi settings.",
    "Create a contact for Alex with number 555-0123.",
]

with gr.Blocks(title="FunctionGemma 270M Mobile Actions SFT", theme=THEME, css=CUSTOM_CSS) as demo:
    with gr.Group(elem_id="hero"):
        gr.Markdown(EXPLANATION)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### ⚙️ Controls")
                user_prompt = gr.Textbox(
                    value=EXAMPLE_PROMPTS[0],
                    lines=2,
                    label="Mobile-action request (user message)",
                )
                gr.Examples(EXAMPLE_PROMPTS, inputs=user_prompt, label="Try one")
                developer_content = gr.Textbox(
                    value=DEFAULT_DEVELOPER,
                    lines=3,
                    label="Developer message (context: date/time + role)",
                )
                with gr.Row():
                    epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
                    train_subset = gr.Slider(
                        50, 1000, value=200, step=50,
                        label="Train subset",
                        info="Fewer = faster.",
                    )
                    eval_subset = gr.Slider(
                        10, 100, value=30, step=10,
                        label="Eval examples (for scoring)",
                    )
                with gr.Accordion("Advanced", open=False):
                    learning_rate = gr.Slider(1e-6, 5e-5, value=1e-5, step=1e-6, label="Learning rate")
                    batch_size = gr.Slider(1, 8, value=4, step=1, label="Batch size")
                    grad_accum = gr.Slider(1, 16, value=8, step=1, label="Grad accumulation")
                    output_length = gr.Slider(64, 512, value=256, step=32, label="Max new tokens")
                    temperature = gr.Slider(
                        0.0, 1.0, value=0.0, step=0.1,
                        label="Temperature (0 = greedy, best for tools)",
                    )
                with gr.Row():
                    base_btn = gr.Button("🎲 Ask base model", variant="secondary")
                    train_btn = gr.Button(
                        "🚀 Fine-tune & Compare", variant="primary", elem_id="train-btn"
                    )

        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### 🔍 Results")
                with gr.Row():
                    base_acc_box = gr.Markdown()
                    tuned_acc_box = gr.Markdown()
                with gr.Tab("Parsed calls"):
                    base_calls = gr.Textbox(lines=4, label="🎲 Base model call(s)")
                    tuned_calls = gr.Textbox(lines=4, label="✨ Fine-tuned call(s)")
                with gr.Tab("Raw output"):
                    tuned_raw = gr.Textbox(lines=8, label="✨ Fine-tuned raw output")
                loss_plot = gr.Plot(label="📉 Training / validation loss")
                status = gr.Markdown()

    # NOTE(review): the base model's raw text is shown in the "Fine-tuned raw
    # output" box because there is no separate base raw-output component.
    base_btn.click(
        base_only,
        inputs=[user_prompt, developer_content, output_length, temperature],
        outputs=[tuned_raw, base_calls],
    )
    train_btn.click(
        finetune_and_compare,
        inputs=[
            user_prompt, developer_content, epochs, train_subset, eval_subset,
            learning_rate, batch_size, grad_accum, output_length, temperature,
        ],
        outputs=[loss_plot, status, tuned_raw, tuned_calls, base_acc_box, tuned_acc_box],
    )

    with gr.Accordion("💬 Notes", open=False):
        gr.Markdown(
            """
- **Greedy decoding** (temperature 0) is best for function calling — you want the single most likely call, not a creative one.
- **Exact-match** accuracy is a lower bound: a call with equivalent arguments (e.g. a slightly reworded `query`) counts as wrong but may still be acceptable.
- A GPU is strongly recommended. On CPU, training and scoring will be slow — shrink the train/eval subsets.
"""
        )

if __name__ == "__main__":
    demo.launch()