# Provenance (from the Hugging Face Hub file view): uploaded by areddydev,
# commit ddfb147 ("Upload 2 files", verified), file size 19.1 kB.
import gc
import json
import os
import re
import tempfile
import matplotlib
matplotlib.use("Agg") # headless backend for Spaces
import matplotlib.pyplot as plt
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer
# ----------------------------
# Config
# ----------------------------
# Both the model and the dataset are gated. Accept the licenses and set HF_TOKEN
# (a Space "secret" works) before launching:
# model: https://huggingface.co/google/functiongemma-270m-it
# dataset: https://huggingface.co/datasets/google/mobile-actions
# Hub repo id handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MODEL_ID = "google/functiongemma-270m-it"
# Dataset repo and the file inside it fetched with hf_hub_download.
DATASET_REPO = "google/mobile-actions"
DATASET_FILE = "dataset.jsonl"
# Access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 only on a CUDA device that reports bf16 support; float32 everywhere else.
DTYPE = torch.bfloat16 if (DEVICE == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
# Default "developer" role message: pre-fills the UI textbox and is used when scoring.
DEFAULT_DEVELOPER = (
"Current date and time given in YYYY-MM-DDTHH:MM:SS format: 2024-11-15T05:59:00. "
"You are a model that can do function calling with the following functions"
)
# ----------------------------
# Lazy singletons
# ----------------------------
# Module-level caches, all None until first use.
# _TOKENIZER is filled by get_tokenizer(); _BASE_MODEL by get_base_model();
# the remaining four are filled together by ensure_dataset().
_TOKENIZER = None  # shared tokenizer instance
_BASE_MODEL = None  # untouched pretrained model kept for baseline generation/scoring
_RAW = None  # raw dataset (each row['text'] is a JSON string)
_TOOLS = None  # shared tool schema, taken from the first row of the shuffled dataset
_PROCESSED = None  # prompt/completion/split formatted dataset
_MAXTOK = None  # max_length to use for SFT (token count of the longest example + 100)
def get_tokenizer():
    """Return the shared tokenizer for ``MODEL_ID``, loading it on first call.

    The instance is cached in the module-level ``_TOKENIZER`` so later calls
    do not hit ``from_pretrained`` again.
    """
    global _TOKENIZER
    if _TOKENIZER is not None:
        return _TOKENIZER
    _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return _TOKENIZER
def load_fresh_model():
    """Load a brand-new copy of ``MODEL_ID`` and move it to ``DEVICE``.

    Nothing is cached here: every call returns an independent model, which is
    what fine-tuning needs so the cached base model stays untouched.

    Returns:
        The loaded causal-LM model, already placed on ``DEVICE``.
    """
    fresh = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        attn_implementation="eager",  # recommended for Gemma 3
        token=HF_TOKEN,
    )
    # Mirror the tokenizer's pad token onto the model config when one exists.
    pad_id = get_tokenizer().pad_token_id
    if pad_id is not None:
        fresh.config.pad_token_id = pad_id
    fresh.to(DEVICE)
    return fresh
def get_base_model():
    """Return the cached, eval-mode base model, loading it on first call."""
    global _BASE_MODEL
    if _BASE_MODEL is not None:
        return _BASE_MODEL
    _BASE_MODEL = load_fresh_model()
    _BASE_MODEL.eval()
    return _BASE_MODEL
# ----------------------------
# Dataset: download, format into prompt/completion, split
# ----------------------------
def apply_format(sample):
    """Turn one raw dataset row into prompt/completion/split fields.

    Args:
        sample: Mapping whose ``"text"`` value is a JSON string providing the
            ``"messages"``, ``"tools"`` and ``"metadata"`` keys read below.

    Returns:
        Dict with ``"prompt"`` (chat template of every message but the last,
        with the generation prompt appended), ``"completion"`` (the remainder
        of the fully templated conversation) and ``"split"`` (the row's
        ``metadata`` value).
    """
    tokenizer = get_tokenizer()
    record = json.loads(sample["text"])
    messages = record["messages"]
    shared_kwargs = {"tools": record["tools"], "tokenize": False}

    prompt_text = tokenizer.apply_chat_template(
        messages[:-1], add_generation_prompt=True, **shared_kwargs
    )
    full_text = tokenizer.apply_chat_template(
        messages, add_generation_prompt=False, **shared_kwargs
    )
    # The completion is whatever the full conversation adds after the prompt.
    completion_text = full_text[len(prompt_text):]
    return {
        "prompt": prompt_text,
        "completion": completion_text,
        "split": record["metadata"],
    }
def ensure_dataset():
    """Download + format once; cache raw rows, tools, processed splits, max_length.

    Populates the module-level ``_RAW``, ``_TOOLS``, ``_PROCESSED`` and
    ``_MAXTOK``. Everything is computed into locals first and published in a
    single assignment at the end, so a failure part-way through (download,
    formatting, tokenization) leaves all four caches unset and the next call
    simply retries instead of returning early with ``_MAXTOK`` still ``None``.
    """
    global _RAW, _TOOLS, _PROCESSED, _MAXTOK
    if _PROCESSED is not None:
        return

    path = hf_hub_download(
        repo_id=DATASET_REPO,
        filename=DATASET_FILE,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    raw = load_dataset("text", data_files=path, encoding="utf-8")["train"].shuffle(seed=7)
    # The tool schema is read from the first row and reused for every generation.
    tools = json.loads(raw[0]["text"])["tools"]
    processed = raw.map(apply_format)

    # Pick the longest example by character count and tokenize only that one;
    # the +100 is headroom on top of its token count.
    longest = max(processed, key=lambda e: len(e["prompt"]) + len(e["completion"]))
    longest_tokens = len(get_tokenizer().tokenize(longest["prompt"] + longest["completion"]))

    _RAW, _TOOLS, _PROCESSED, _MAXTOK = raw, tools, processed, longest_tokens + 100
def get_tools():
    """Return the cached tool schema, preparing the dataset first if needed."""
    ensure_dataset()
    tool_schema = _TOOLS
    return tool_schema
# ----------------------------
# Function-call parsing (from the notebook)
# ----------------------------
def extract_function_call(model_output):
    """Parse ``<start_function_call>...<end_function_call>`` spans from model output.

    Each span must start with ``call:`` and contain a ``{``; the text before
    the brace (minus ``call:``) is the function name. Arguments are read only
    in the ``key:<escape>value<escape>`` form; anything else inside the braces
    is ignored.

    Args:
        model_output: Raw decoded model text.

    Returns:
        List of ``{"function": {"name": str, "arguments": dict[str, str]}}``
        entries, in the order the calls appear. Spans that do not start with
        ``call:`` or have no ``{`` are skipped.
    """
    call_re = r"<start_function_call>(.*?)<end_function_call>"
    arg_re = r"(?P<key>[^:,]*?):<escape>(?P<value>.*?)<escape>"

    parsed = []
    for body in re.findall(call_re, model_output, re.DOTALL):
        if not body.strip().startswith("call:"):
            continue
        head, brace, tail = body.partition("{")
        if not brace:
            # No argument block at all: this span is not a usable call.
            continue
        name = head.replace("call:", "").strip()
        arg_text = tail.strip()
        if arg_text.endswith("}"):
            arg_text = arg_text[:-1]
        arguments = {
            match.group("key").strip(): match.group("value")
            for match in re.finditer(arg_re, arg_text, re.DOTALL)
        }
        parsed.append({"function": {"name": name, "arguments": arguments}})
    return parsed
def extract_text(model_output):
    """Return the plain-text reply, or ``None`` when there is none.

    ``None`` is returned for empty/falsy output and for output that begins
    with a function call. Otherwise ``<end_of_turn>`` markers are removed and
    surrounding whitespace is stripped.
    """
    if model_output and not model_output.startswith("<start_function_call>"):
        return model_output.replace("<end_of_turn>", "").strip()
    return None
def pretty_calls(calls):
    """Render parsed calls as ``name(key='value', ...)`` text, one call per line.

    Args:
        calls: List of ``{"function": {"name": ..., "arguments": {...}}}``
            entries; may be empty or ``None``.

    Returns:
        The rendered calls joined with newlines, or ``"(no function call)"``
        when ``calls`` is empty/falsy.
    """
    if not calls:
        return "(no function call)"
    rendered = []
    for call in calls:
        name = call["function"]["name"]
        arg_parts = [
            f"{key}={value!r}" for key, value in call["function"]["arguments"].items()
        ]
        joined_args = ", ".join(arg_parts)
        rendered.append(f"{name}({joined_args})")
    return "\n".join(rendered)
# ----------------------------
# Generation
# ----------------------------
@torch.no_grad()
def generate_fc(model, user_prompt, developer_content, max_new_tokens=256, temperature=0.0):
    """Generate the model's reply to one developer + user message pair.

    The two messages are rendered through the chat template together with the
    cached tool schema, and only the newly generated tokens are decoded.

    Args:
        model: Model to generate with; it is switched to eval mode.
        user_prompt: Content of the ``user`` message.
        developer_content: Content of the ``developer`` message.
        max_new_tokens: Generation budget (coerced with ``int``).
        temperature: Sampling temperature; a falsy or non-positive value
            selects greedy decoding, otherwise sampling with ``top_p=0.9``.

    Returns:
        The decoded continuation (special tokens kept) with the tokenizer's
        EOS string removed and surrounding whitespace stripped.
    """
    tokenizer = get_tokenizer()
    model.eval()

    chat = [
        {"role": "developer", "content": developer_content},
        {"role": "user", "content": user_prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tools=get_tools(), tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    sampling = bool(temperature and temperature > 0)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.pad_token_id,
        "do_sample": sampling,  # greedy when False: best for function calling
    }
    if sampling:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = 0.9

    generated = model.generate(**encoded, **gen_kwargs)
    # Drop the echoed prompt tokens; keep only what the model added.
    prompt_len = encoded["input_ids"].shape[1]
    decoded = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=False)
    return decoded.replace(tokenizer.eos_token or "", "").strip()
# ----------------------------
# Exact-match scoring on an eval subset
# ----------------------------
def _call_signatures(calls):
    """Return ``[(name, arguments_dict), ...]`` for a list of function-call dicts."""
    return [
        (c["function"]["name"], dict(sorted(c["function"]["arguments"].items())))
        for c in calls
    ]


def score_model(model, n_examples, progress=None, desc=""):
    """Compute exact-match function-call accuracy on the first eval rows.

    Rows whose ``metadata`` is ``"eval"`` are taken in dataset order until
    ``n_examples`` have been collected. Each row is JSON-decoded once and
    the scan stops as soon as enough rows are found, rather than decoding the
    whole dataset up front.

    Args:
        model: Model to evaluate (passed to ``generate_fc``).
        n_examples: Maximum number of eval rows to score (coerced with
            ``int``); a non-positive value scores nothing.
        progress: Optional callable ``progress(fraction, desc=...)`` invoked
            after every scored row.
        desc: Prefix for the progress description.

    Returns:
        ``(accuracy, n_scored)`` where ``accuracy`` is ``0.0`` when no row was
        scored.
    """
    ensure_dataset()
    limit = int(n_examples)

    eval_records = []
    for row in _RAW:
        if len(eval_records) >= limit:
            break
        record = json.loads(row["text"])
        if record["metadata"] == "eval":
            eval_records.append(record)

    total = len(eval_records)
    correct = 0
    for done, record in enumerate(eval_records, start=1):
        msgs = record["messages"]
        user_msg = next((m["content"] for m in msgs if m["role"] == "user"), "")
        # Prefer the row's own developer message when it has one, so the model
        # sees the same context the target call was written against
        # (presumably it carries the row's date/time — confirm against the
        # dataset); otherwise fall back to DEFAULT_DEVELOPER.
        developer_msg = next(
            (m.get("content") for m in msgs if m["role"] == "developer"), None
        ) or DEFAULT_DEVELOPER
        target = msgs[-1].get("tool_calls", []) or []

        raw = generate_fc(model, user_msg, developer_msg, max_new_tokens=_MAXTOK)
        predicted = extract_function_call(raw)

        # Exact match: same calls, same order, same names and same arguments.
        if _call_signatures(target) == _call_signatures(predicted):
            correct += 1
        if progress is not None:
            progress(done / total, desc=f"{desc} {done}/{total}")

    return correct / max(1, total), total
# ----------------------------
# Loss plot (train + eval) from trainer log history
# ----------------------------
def make_loss_plot(log_history):
    """Build a matplotlib figure of training and validation loss per step.

    Args:
        log_history: Iterable of log dicts; entries containing ``"loss"`` feed
            the training curve, entries containing ``"eval_loss"`` feed the
            validation curve. Both kinds are expected to carry ``"step"``.

    Returns:
        The matplotlib ``Figure``; curves and the legend are drawn only for
        the series that have data.
    """
    train_points = [(e["step"], e["loss"]) for e in log_history if "loss" in e]
    eval_points = [(e["step"], e["eval_loss"]) for e in log_history if "eval_loss" in e]

    fig, ax = plt.subplots(figsize=(6, 3.4))
    fig.patch.set_facecolor("#ffffff")
    ax.set_facecolor("#fbfbfd")

    if train_points:
        steps, values = zip(*train_points)
        ax.plot(steps, values, color="#7c3aed", linewidth=2.2, label="Training loss")
    if eval_points:
        steps, values = zip(*eval_points)
        ax.plot(
            steps, values, color="#db2777", linewidth=2.0,
            marker="o", markersize=4, label="Validation loss",
        )

    ax.set_xlabel("Step", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title("FunctionGemma SFT loss 📉", fontsize=12, fontweight="bold", color="#1f2937")
    ax.grid(True, linestyle="--", alpha=0.35)
    if train_points or eval_points:
        ax.legend(frameon=False)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    fig.tight_layout()
    return fig
# ----------------------------
# Gradio <-> Trainer progress bridge
# ----------------------------
class GradioCallback(TrainerCallback):
    """Trainer callback that forwards step progress to a Gradio progress callable."""

    def __init__(self, progress):
        # Callable invoked as progress(fraction, desc=...).
        self.progress = progress

    def on_step_end(self, args, state, control, **kwargs):
        """Report ``global_step / max_steps`` after every optimizer step."""
        step = state.global_step
        # Guard against a zero/unset max_steps so the division is always valid.
        max_steps = state.max_steps or 1
        self.progress(step / max_steps, desc=f"SFT step {step}/{max_steps}")
# ----------------------------
# Actions
# ----------------------------
def base_only(user_prompt, developer_content, output_length, temperature):
    """Run the untouched base model on the user's request.

    Returns:
        ``(raw_output, parsed_calls_text)``; for a blank prompt, a warning
        message and an empty string instead.
    """
    if not user_prompt.strip():
        return "⚠️ Enter a mobile-action request first.", ""
    base_model = get_base_model()
    raw_output = generate_fc(
        base_model, user_prompt, developer_content, output_length, temperature
    )
    parsed = extract_function_call(raw_output)
    return raw_output, pretty_calls(parsed)
def finetune_and_compare(
    user_prompt,
    developer_content,
    epochs,
    train_subset,
    eval_subset,
    learning_rate,
    batch_size,
    grad_accum,
    output_length,
    temperature,
    progress=gr.Progress(),
):
    """Fine-tune a fresh model with SFT and compare it against the base model.

    Steps: prepare the dataset, score the base model, train a freshly loaded
    model on a train subset, generate with the tuned model for the user's
    prompt, score the tuned model, and build the loss plot and status text.

    Args:
        user_prompt: Request to run through the tuned model.
        developer_content: Developer message used for that generation.
        epochs: Number of training epochs.
        train_subset: Maximum number of train rows to use.
        eval_subset: Maximum number of eval rows for validation and scoring.
        learning_rate: Optimizer learning rate.
        batch_size: Per-device train batch size.
        grad_accum: Gradient accumulation steps.
        output_length: Max new tokens for the user-prompt generation.
        temperature: Sampling temperature for the user-prompt generation.
        progress: Gradio progress reporter.

    Returns:
        ``(figure, status_markdown, tuned_raw, tuned_calls, base_accuracy_text,
        tuned_accuracy_text)``; for a blank prompt, ``None``, a warning and
        four empty strings.
    """
    if not user_prompt.strip():
        return None, "⚠️ Enter a mobile-action request first.", "", "", "", ""

    progress(0.0, desc="Downloading + formatting dataset")
    ensure_dataset()
    train_ds = _PROCESSED.filter(lambda e: e["split"] == "train")
    eval_ds = _PROCESSED.filter(lambda e: e["split"] == "eval")
    train_ds = train_ds.select(range(min(int(train_subset), len(train_ds))))
    eval_ds = eval_ds.select(range(min(int(eval_subset), len(eval_ds))))

    # score base model first (re-used for the headline comparison)
    base_acc, n_eval = score_model(get_base_model(), eval_subset, progress, "Scoring base")

    torch.manual_seed(7)
    model = load_fresh_model()
    trainer = None
    try:
        if DEVICE == "cuda":
            model.gradient_checkpointing_enable()
            model.config.use_cache = False

        # Rough step estimate, used only to space out the evaluation runs.
        total_steps = max(1, (len(train_ds) // (int(batch_size) * int(grad_accum)))) * int(epochs)

        with tempfile.TemporaryDirectory() as out_dir:
            cfg = SFTConfig(
                output_dir=out_dir,
                num_train_epochs=float(epochs),
                per_device_train_batch_size=int(batch_size),
                gradient_accumulation_steps=int(grad_accum),
                learning_rate=float(learning_rate),
                lr_scheduler_type="cosine",
                logging_strategy="steps",
                logging_steps=1,
                eval_strategy="steps" if len(eval_ds) else "no",
                eval_steps=max(1, total_steps // 4),
                save_strategy="no",
                max_length=_MAXTOK,
                gradient_checkpointing=(DEVICE == "cuda"),
                packing=False,
                optim="adamw_torch_fused" if DEVICE == "cuda" else "adamw_torch",
                bf16=(DTYPE == torch.bfloat16),
                completion_only_loss=True,  # loss on the assistant turn only
                report_to="none",
                seed=7,
            )
            trainer = SFTTrainer(
                model=model,
                args=cfg,
                train_dataset=train_ds,
                eval_dataset=eval_ds if len(eval_ds) else None,
                callbacks=[GradioCallback(progress)],
            )
            trainer.train()
            log_history = list(trainer.state.log_history)
            # Report the step count the trainer actually ran, not the estimate.
            steps_run = trainer.state.global_step

        # switch back to inference mode
        if DEVICE == "cuda":
            model.gradient_checkpointing_disable()
            model.config.use_cache = True

        fig = make_loss_plot(log_history)

        # tuned model outputs for the user's prompt
        tuned_raw = generate_fc(model, user_prompt, developer_content, output_length, temperature)
        tuned_calls = pretty_calls(extract_function_call(tuned_raw))

        # score tuned model
        tuned_acc, _ = score_model(model, eval_subset, progress, "Scoring tuned")

        losses = [entry["loss"] for entry in log_history if "loss" in entry]
        first_loss = losses[0] if losses else 0.0
        last_loss = losses[-1] if losses else 0.0
        status = (
            f"✅ Full fine-tuned **FunctionGemma 270M-IT** on **{len(train_ds)} train examples** "
            f"for **{epochs} epoch(s)** ({steps_run} steps).\n\n"
            f"Loss **{first_loss:.3f} → {last_loss:.3f}**. "
            f"Exact-match function-call accuracy on {n_eval} eval examples: "
            f"**base {base_acc:.0%} → tuned {tuned_acc:.0%}**.\n\n"
            f"Device: `{DEVICE}` · dtype: `{str(DTYPE).replace('torch.', '')}` · "
            f"max_length: `{_MAXTOK}`."
        )
        return (
            fig,
            status,
            tuned_raw,
            tuned_calls,
            f"Base accuracy: {base_acc:.0%}",
            f"Tuned accuracy: {tuned_acc:.0%}",
        )
    finally:
        # Runs on success and on failure: drop the references first so the
        # collector and the CUDA cache can actually release the tuned model.
        del trainer, model
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
# Markdown rendered inside the "hero" banner at the top of the page.
EXPLANATION = """
# 📱 FunctionGemma 270M — Mobile Actions SFT
Fine-tune Google's **FunctionGemma 270M-IT** to turn phone requests
("turn on the flashlight", "schedule a team meeting tomorrow at 4pm") into
**function calls**, using the gated [`google/mobile-actions`](https://huggingface.co/datasets/google/mobile-actions)
dataset and TRL's `SFTTrainer`.
This is a full fine-tune (no LoRA) in **prompt/completion** format with
`completion_only_loss=True`, so loss is computed only on the assistant's call.
The chat template is applied with the dataset's `tools=` schema. Pick a request,
run SFT, and watch the exact-match function-call accuracy go up.
*Omitted from the original notebook: Hugging Face Hub upload and the
`.litertlm` / `ai-edge-torch` on-device conversion (not Space-friendly).*
"""
# Stylesheet passed to gr.Blocks(css=...): page width, the #hero gradient banner,
# the .panel-card containers, the #train-btn weight, and a hidden footer.
CUSTOM_CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
#hero {
background: linear-gradient(135deg, #7c3aed 0%, #2563eb 50%, #06b6d4 100%);
border-radius: 18px; padding: 6px 26px; color: white;
box-shadow: 0 10px 30px rgba(37, 99, 235, 0.25); margin-bottom: 8px;
}
#hero h1 { color: white !important; font-size: 2.0rem !important; }
#hero p, #hero li, #hero strong { color: rgba(255,255,255,0.95) !important; }
#hero a { color: #bae6fd !important; }
.panel-card {
border-radius: 16px !important; padding: 16px !important;
background: var(--block-background-fill);
box-shadow: 0 4px 18px rgba(0,0,0,0.06);
border: 1px solid var(--border-color-primary);
}
#train-btn { font-weight: 700 !important; }
footer { visibility: hidden; }
"""
# Gradio theme passed to gr.Blocks(theme=...): Soft base with blue/cyan hues and
# the Quicksand Google font, falling back to system sans-serif fonts.
THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="cyan",
font=[gr.themes.GoogleFont("Quicksand"), "system-ui", "sans-serif"],
)
# Sample requests offered through gr.Examples; the first entry is also the
# initial value of the request textbox.
EXAMPLE_PROMPTS = [
'Schedule a "team meeting" tomorrow at 4pm.',
"Turn on the flashlight.",
"Show me Besançon, France on the map.",
"Open the WiFi settings.",
"Create a contact for Alex with number 555-0123.",
]
# UI definition: a hero banner, a controls column, a results column, the two
# button handlers, and a collapsible notes section.
with gr.Blocks(title="FunctionGemma 270M Mobile Actions SFT", theme=THEME, css=CUSTOM_CSS) as demo:
    with gr.Group(elem_id="hero"):
        gr.Markdown(EXPLANATION)

    with gr.Row():
        # Left column: request, developer message and training hyper-parameters.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### ⚙️ Controls")
                user_prompt = gr.Textbox(
                    value=EXAMPLE_PROMPTS[0], lines=2,
                    label="Mobile-action request (user message)",
                )
                gr.Examples(EXAMPLE_PROMPTS, inputs=user_prompt, label="Try one")
                developer_content = gr.Textbox(
                    value=DEFAULT_DEVELOPER, lines=3,
                    label="Developer message (context: date/time + role)",
                )
                with gr.Row():
                    epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
                    train_subset = gr.Slider(
                        50, 1000, value=200, step=50, label="Train subset",
                        info="Fewer = faster.",
                    )
                eval_subset = gr.Slider(
                    10, 100, value=30, step=10, label="Eval examples (for scoring)",
                )
                with gr.Accordion("Advanced", open=False):
                    learning_rate = gr.Slider(1e-6, 5e-5, value=1e-5, step=1e-6, label="Learning rate")
                    batch_size = gr.Slider(1, 8, value=4, step=1, label="Batch size")
                    grad_accum = gr.Slider(1, 16, value=8, step=1, label="Grad accumulation")
                    output_length = gr.Slider(64, 512, value=256, step=32, label="Max new tokens")
                    temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1,
                                            label="Temperature (0 = greedy, best for tools)")
                with gr.Row():
                    base_btn = gr.Button("🎲 Ask base model", variant="secondary")
                    train_btn = gr.Button("🚀 Fine-tune & Compare", variant="primary", elem_id="train-btn")

        # Right column: accuracies, parsed/raw outputs, loss plot and status.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-card"):
                gr.Markdown("### 🔍 Results")
                with gr.Row():
                    base_acc_box = gr.Markdown()
                    tuned_acc_box = gr.Markdown()
                with gr.Tab("Parsed calls"):
                    base_calls = gr.Textbox(lines=4, label="🎲 Base model call(s)")
                    tuned_calls = gr.Textbox(lines=4, label="✨ Fine-tuned call(s)")
                with gr.Tab("Raw output"):
                    # Separate box for the base model so its raw text is not
                    # shown under the "Fine-tuned" label.
                    base_raw = gr.Textbox(lines=8, label="🎲 Base model raw output")
                    tuned_raw = gr.Textbox(lines=8, label="✨ Fine-tuned raw output")
                loss_plot = gr.Plot(label="📉 Training / validation loss")
                status = gr.Markdown()

    # base_only returns (raw_output, parsed_calls_text).
    base_btn.click(
        base_only,
        inputs=[user_prompt, developer_content, output_length, temperature],
        outputs=[base_raw, base_calls],
    )
    # finetune_and_compare returns six values, in the order of `outputs` below.
    train_btn.click(
        finetune_and_compare,
        inputs=[user_prompt, developer_content, epochs, train_subset, eval_subset,
                learning_rate, batch_size, grad_accum, output_length, temperature],
        outputs=[loss_plot, status, tuned_raw, tuned_calls, base_acc_box, tuned_acc_box],
    )

    with gr.Accordion("💬 Notes", open=False):
        gr.Markdown(
            """
- **Greedy decoding** (temperature 0) is best for function calling — you want the
single most likely call, not a creative one.
- **Exact-match** accuracy is a lower bound: a call with equivalent arguments
(e.g. a slightly reworded `query`) counts as wrong but may still be acceptable.
- A GPU is strongly recommended. On CPU, training and scoring will be slow —
shrink the train/eval subsets.
"""
        )

# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()