areddydev commited on
Commit
ddfb147
·
verified ·
1 Parent(s): 9894e1a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +502 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import json
3
+ import os
4
+ import re
5
+ import tempfile
6
+
7
+ import matplotlib
8
+
9
+ matplotlib.use("Agg") # headless backend for Spaces
10
+ import matplotlib.pyplot as plt
11
+
12
+ import gradio as gr
13
+ import torch
14
+
15
+ from datasets import load_dataset
16
+ from huggingface_hub import hf_hub_download
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
18
+ from trl import SFTConfig, SFTTrainer
19
+
20
+
21
+ # ----------------------------
22
+ # Config
23
+ # ----------------------------
24
+ # Both the model and the dataset are gated. Accept the licenses and set HF_TOKEN
25
+ # (a Space "secret" works) before launching:
26
+ # model: https://huggingface.co/google/functiongemma-270m-it
27
+ # dataset: https://huggingface.co/datasets/google/mobile-actions
28
+ MODEL_ID = "google/functiongemma-270m-it"
29
+ DATASET_REPO = "google/mobile-actions"
30
+ DATASET_FILE = "dataset.jsonl"
31
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
32
+
33
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
+ DTYPE = torch.bfloat16 if (DEVICE == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
35
+
36
+ DEFAULT_DEVELOPER = (
37
+ "Current date and time given in YYYY-MM-DDTHH:MM:SS format: 2024-11-15T05:59:00. "
38
+ "You are a model that can do function calling with the following functions"
39
+ )
40
+
41
+
42
+ # ----------------------------
43
+ # Lazy singletons
44
+ # ----------------------------
45
+ _TOKENIZER = None
46
+ _BASE_MODEL = None
47
+ _RAW = None # raw dataset (each row['text'] is a JSON string)
48
+ _TOOLS = None # shared tool schema from the dataset
49
+ _PROCESSED = None # prompt/completion/split formatted dataset
50
+ _MAXTOK = None # max_length to use for SFT
51
+
52
+
53
+ def get_tokenizer():
54
+ global _TOKENIZER
55
+ if _TOKENIZER is None:
56
+ _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
57
+ return _TOKENIZER
58
+
59
+
60
+ def load_fresh_model():
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ MODEL_ID,
63
+ torch_dtype=DTYPE,
64
+ attn_implementation="eager", # recommended for Gemma 3
65
+ token=HF_TOKEN,
66
+ )
67
+ tok = get_tokenizer()
68
+ if tok.pad_token_id is not None:
69
+ model.config.pad_token_id = tok.pad_token_id
70
+ model.to(DEVICE)
71
+ return model
72
+
73
+
74
+ def get_base_model():
75
+ global _BASE_MODEL
76
+ if _BASE_MODEL is None:
77
+ _BASE_MODEL = load_fresh_model()
78
+ _BASE_MODEL.eval()
79
+ return _BASE_MODEL
80
+
81
+
82
+ # ----------------------------
83
+ # Dataset: download, format into prompt/completion, split
84
+ # ----------------------------
85
+ def apply_format(sample):
86
+ tok = get_tokenizer()
87
+ t = json.loads(sample["text"])
88
+ full = tok.apply_chat_template(
89
+ t["messages"], tools=t["tools"], tokenize=False, add_generation_prompt=False
90
+ )
91
+ prompt = tok.apply_chat_template(
92
+ t["messages"][:-1], tools=t["tools"], tokenize=False, add_generation_prompt=True
93
+ )
94
+ completion = full[len(prompt):]
95
+ return {"prompt": prompt, "completion": completion, "split": t["metadata"]}
96
+
97
+
98
+ def ensure_dataset():
99
+ """Download + format once; cache raw rows, tools, processed splits, max_length."""
100
+ global _RAW, _TOOLS, _PROCESSED, _MAXTOK
101
+ if _PROCESSED is not None:
102
+ return
103
+ path = hf_hub_download(repo_id=DATASET_REPO, filename=DATASET_FILE,
104
+ repo_type="dataset", token=HF_TOKEN)
105
+ _RAW = load_dataset("text", data_files=path, encoding="utf-8")["train"].shuffle(seed=7)
106
+ _TOOLS = json.loads(_RAW[0]["text"])["tools"]
107
+
108
+ tok = get_tokenizer()
109
+ _PROCESSED = _RAW.map(apply_format)
110
+ longest = max(_PROCESSED, key=lambda e: len(e["prompt"] + e["completion"]))
111
+ longest_tokens = len(tok.tokenize(longest["prompt"] + longest["completion"]))
112
+ _MAXTOK = longest_tokens + 100
113
+
114
+
115
+ def get_tools():
116
+ ensure_dataset()
117
+ return _TOOLS
118
+
119
+
120
+ # ----------------------------
121
+ # Function-call parsing (from the notebook)
122
+ # ----------------------------
123
+ def extract_function_call(model_output):
124
+ results = []
125
+ call_pattern = r"<start_function_call>(.*?)<end_function_call>"
126
+ for raw_call in re.findall(call_pattern, model_output, re.DOTALL):
127
+ if not raw_call.strip().startswith("call:"):
128
+ continue
129
+ try:
130
+ pre_brace, args_segment = raw_call.split("{", 1)
131
+ function_name = pre_brace.replace("call:", "").strip()
132
+ args_content = args_segment.strip()
133
+ if args_content.endswith("}"):
134
+ args_content = args_content[:-1]
135
+ arguments = {}
136
+ arg_pattern = r"(?P<key>[^:,]*?):<escape>(?P<value>.*?)<escape>"
137
+ for m in re.finditer(arg_pattern, args_content, re.DOTALL):
138
+ arguments[m.group("key").strip()] = m.group("value")
139
+ results.append({"function": {"name": function_name, "arguments": arguments}})
140
+ except ValueError:
141
+ continue
142
+ return results
143
+
144
+
145
+ def extract_text(model_output):
146
+ if not model_output or model_output.startswith("<start_function_call>"):
147
+ return None
148
+ return model_output.replace("<end_of_turn>", "").strip()
149
+
150
+
151
+ def pretty_calls(calls):
152
+ if not calls:
153
+ return "(no function call)"
154
+ lines = []
155
+ for c in calls:
156
+ fn = c["function"]["name"]
157
+ args = ", ".join(f"{k}={v!r}" for k, v in c["function"]["arguments"].items())
158
+ lines.append(f"{fn}({args})")
159
+ return "\n".join(lines)
160
+
161
+
162
+ # ----------------------------
163
+ # Generation
164
+ # ----------------------------
165
+ @torch.no_grad()
166
+ def generate_fc(model, user_prompt, developer_content, max_new_tokens=256, temperature=0.0):
167
+ tok = get_tokenizer()
168
+ model.eval()
169
+ messages = [
170
+ {"role": "developer", "content": developer_content},
171
+ {"role": "user", "content": user_prompt},
172
+ ]
173
+ prompt = tok.apply_chat_template(
174
+ messages, tools=get_tools(), tokenize=False, add_generation_prompt=True
175
+ )
176
+ inputs = tok(prompt, return_tensors="pt").to(model.device)
177
+ gen_kwargs = dict(max_new_tokens=int(max_new_tokens), pad_token_id=tok.pad_token_id)
178
+ if temperature and temperature > 0:
179
+ gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
180
+ else:
181
+ gen_kwargs.update(do_sample=False) # greedy: best for function calling
182
+ out = model.generate(**inputs, **gen_kwargs)
183
+ raw = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
184
+ raw = raw.replace(tok.eos_token or "", "").strip()
185
+ return raw
186
+
187
+
188
+ # ----------------------------
189
+ # Exact-match scoring on an eval subset
190
+ # ----------------------------
191
+ def score_model(model, n_examples, progress=None, desc=""):
192
+ ensure_dataset()
193
+ eval_rows = [r for r in _RAW if json.loads(r["text"])["metadata"] == "eval"]
194
+ eval_rows = eval_rows[: int(n_examples)]
195
+ correct = 0
196
+ for i, row in enumerate(eval_rows):
197
+ msgs = json.loads(row["text"])["messages"]
198
+ user_msg = next((m["content"] for m in msgs if m["role"] == "user"), "")
199
+ target = msgs[-1].get("tool_calls", []) or []
200
+ target_names = [fc["function"]["name"] for fc in target]
201
+ target_args = [dict(sorted(fc["function"]["arguments"].items())) for fc in target]
202
+
203
+ raw = generate_fc(model, user_msg, DEFAULT_DEVELOPER, max_new_tokens=_MAXTOK)
204
+ pred = extract_function_call(raw)
205
+ pred_names = [fc["function"]["name"] for fc in pred]
206
+ pred_args = [dict(sorted(fc["function"]["arguments"].items())) for fc in pred]
207
+
208
+ if target_names == pred_names and target_args == pred_args:
209
+ correct += 1
210
+ if progress is not None:
211
+ progress((i + 1) / len(eval_rows), desc=f"{desc} {i + 1}/{len(eval_rows)}")
212
+ return correct / max(1, len(eval_rows)), len(eval_rows)
213
+
214
+
215
+ # ----------------------------
216
+ # Loss plot (train + eval) from trainer log history
217
+ # ----------------------------
218
+ def make_loss_plot(log_history):
219
+ train_x = [l["step"] for l in log_history if "loss" in l]
220
+ train_y = [l["loss"] for l in log_history if "loss" in l]
221
+ eval_x = [l["step"] for l in log_history if "eval_loss" in l]
222
+ eval_y = [l["eval_loss"] for l in log_history if "eval_loss" in l]
223
+
224
+ fig, ax = plt.subplots(figsize=(6, 3.4))
225
+ fig.patch.set_facecolor("#ffffff")
226
+ ax.set_facecolor("#fbfbfd")
227
+ if train_y:
228
+ ax.plot(train_x, train_y, color="#7c3aed", linewidth=2.2, label="Training loss")
229
+ if eval_y:
230
+ ax.plot(eval_x, eval_y, color="#db2777", linewidth=2.0,
231
+ marker="o", markersize=4, label="Validation loss")
232
+ ax.set_xlabel("Step", fontsize=11)
233
+ ax.set_ylabel("Loss", fontsize=11)
234
+ ax.set_title("FunctionGemma SFT loss 📉", fontsize=12, fontweight="bold", color="#1f2937")
235
+ ax.grid(True, linestyle="--", alpha=0.35)
236
+ if train_y or eval_y:
237
+ ax.legend(frameon=False)
238
+ for spine in ["top", "right"]:
239
+ ax.spines[spine].set_visible(False)
240
+ fig.tight_layout()
241
+ return fig
242
+
243
+
244
+ # ----------------------------
245
+ # Gradio <-> Trainer progress bridge
246
+ # ----------------------------
247
+ class GradioCallback(TrainerCallback):
248
+ def __init__(self, progress):
249
+ self.progress = progress
250
+
251
+ def on_step_end(self, args, state, control, **kwargs):
252
+ total = state.max_steps or 1
253
+ self.progress(state.global_step / total,
254
+ desc=f"SFT step {state.global_step}/{total}")
255
+
256
+
257
+ # ----------------------------
258
+ # Actions
259
+ # ----------------------------
260
+ def base_only(user_prompt, developer_content, output_length, temperature):
261
+ if not user_prompt.strip():
262
+ return "⚠️ Enter a mobile-action request first.", ""
263
+ raw = generate_fc(get_base_model(), user_prompt, developer_content,
264
+ output_length, temperature)
265
+ return raw, pretty_calls(extract_function_call(raw))
266
+
267
+
268
+ def finetune_and_compare(
269
+ user_prompt,
270
+ developer_content,
271
+ epochs,
272
+ train_subset,
273
+ eval_subset,
274
+ learning_rate,
275
+ batch_size,
276
+ grad_accum,
277
+ output_length,
278
+ temperature,
279
+ progress=gr.Progress(),
280
+ ):
281
+ if not user_prompt.strip():
282
+ return None, "⚠️ Enter a mobile-action request first.", "", "", "", ""
283
+
284
+ progress(0.0, desc="Downloading + formatting dataset")
285
+ ensure_dataset()
286
+
287
+ train_ds = _PROCESSED.filter(lambda e: e["split"] == "train")
288
+ eval_ds = _PROCESSED.filter(lambda e: e["split"] == "eval")
289
+ train_ds = train_ds.select(range(min(int(train_subset), len(train_ds))))
290
+ eval_ds = eval_ds.select(range(min(int(eval_subset), len(eval_ds))))
291
+
292
+ # score base model first (re-used for the headline comparison)
293
+ base_acc, n_eval = score_model(get_base_model(), eval_subset, progress, "Scoring base")
294
+
295
+ torch.manual_seed(7)
296
+ model = load_fresh_model()
297
+ if DEVICE == "cuda":
298
+ model.gradient_checkpointing_enable()
299
+ model.config.use_cache = False
300
+
301
+ total_steps = max(1, (len(train_ds) // (int(batch_size) * int(grad_accum)))) * int(epochs)
302
+
303
+ with tempfile.TemporaryDirectory() as out_dir:
304
+ cfg = SFTConfig(
305
+ output_dir=out_dir,
306
+ num_train_epochs=float(epochs),
307
+ per_device_train_batch_size=int(batch_size),
308
+ gradient_accumulation_steps=int(grad_accum),
309
+ learning_rate=float(learning_rate),
310
+ lr_scheduler_type="cosine",
311
+ logging_strategy="steps",
312
+ logging_steps=1,
313
+ eval_strategy="steps" if len(eval_ds) else "no",
314
+ eval_steps=max(1, total_steps // 4),
315
+ save_strategy="no",
316
+ max_length=_MAXTOK,
317
+ gradient_checkpointing=(DEVICE == "cuda"),
318
+ packing=False,
319
+ optim="adamw_torch_fused" if DEVICE == "cuda" else "adamw_torch",
320
+ bf16=(DTYPE == torch.bfloat16),
321
+ completion_only_loss=True, # loss on the assistant turn only
322
+ report_to="none",
323
+ seed=7,
324
+ )
325
+ trainer = SFTTrainer(
326
+ model=model,
327
+ args=cfg,
328
+ train_dataset=train_ds,
329
+ eval_dataset=eval_ds if len(eval_ds) else None,
330
+ callbacks=[GradioCallback(progress)],
331
+ )
332
+ trainer.train()
333
+ log_history = list(trainer.state.log_history)
334
+
335
+ # switch back to inference mode
336
+ if DEVICE == "cuda":
337
+ model.gradient_checkpointing_disable()
338
+ model.config.use_cache = True
339
+
340
+ fig = make_loss_plot(log_history)
341
+
342
+ # tuned model outputs for the user's prompt
343
+ tuned_raw = generate_fc(model, user_prompt, developer_content, output_length, temperature)
344
+ tuned_calls = pretty_calls(extract_function_call(tuned_raw))
345
+
346
+ # score tuned model
347
+ tuned_acc, _ = score_model(model, eval_subset, progress, "Scoring tuned")
348
+
349
+ losses = [l["loss"] for l in log_history if "loss" in l]
350
+ first_loss = losses[0] if losses else 0.0
351
+ last_loss = losses[-1] if losses else 0.0
352
+ status = (
353
+ f"✅ Full fine-tuned **FunctionGemma 270M-IT** on **{len(train_ds)} train examples** "
354
+ f"for **{epochs} epoch(s)** ({total_steps} steps).\n\n"
355
+ f"Loss **{first_loss:.3f} → {last_loss:.3f}**. "
356
+ f"Exact-match function-call accuracy on {n_eval} eval examples: "
357
+ f"**base {base_acc:.0%} → tuned {tuned_acc:.0%}**.\n\n"
358
+ f"Device: `{DEVICE}` · dtype: `{str(DTYPE).replace('torch.', '')}` · "
359
+ f"max_length: `{_MAXTOK}`."
360
+ )
361
+
362
+ del trainer, model
363
+ gc.collect()
364
+ if DEVICE == "cuda":
365
+ torch.cuda.empty_cache()
366
+
367
+ return fig, status, tuned_raw, tuned_calls, f"Base accuracy: {base_acc:.0%}", \
368
+ f"Tuned accuracy: {tuned_acc:.0%}"
369
+
370
+
371
+ EXPLANATION = """
372
+ # 📱 FunctionGemma 270M — Mobile Actions SFT
373
+
374
+ Fine-tune Google's **FunctionGemma 270M-IT** to turn phone requests
375
+ ("turn on the flashlight", "schedule a team meeting tomorrow at 4pm") into
376
+ **function calls**, using the gated [`google/mobile-actions`](https://huggingface.co/datasets/google/mobile-actions)
377
+ dataset and TRL's `SFTTrainer`.
378
+
379
+ This is a full fine-tune (no LoRA) in **prompt/completion** format with
380
+ `completion_only_loss=True`, so loss is computed only on the assistant's call.
381
+ The chat template is applied with the dataset's `tools=` schema. Pick a request,
382
+ run SFT, and watch the exact-match function-call accuracy go up.
383
+
384
+ *Omitted from the original notebook: Hugging Face Hub upload and the
385
+ `.litertlm` / `ai-edge-torch` on-device conversion (not Space-friendly).*
386
+ """
387
+
388
+ CUSTOM_CSS = """
389
+ .gradio-container { max-width: 1100px !important; margin: auto !important; }
390
+ #hero {
391
+ background: linear-gradient(135deg, #7c3aed 0%, #2563eb 50%, #06b6d4 100%);
392
+ border-radius: 18px; padding: 6px 26px; color: white;
393
+ box-shadow: 0 10px 30px rgba(37, 99, 235, 0.25); margin-bottom: 8px;
394
+ }
395
+ #hero h1 { color: white !important; font-size: 2.0rem !important; }
396
+ #hero p, #hero li, #hero strong { color: rgba(255,255,255,0.95) !important; }
397
+ #hero a { color: #bae6fd !important; }
398
+ .panel-card {
399
+ border-radius: 16px !important; padding: 16px !important;
400
+ background: var(--block-background-fill);
401
+ box-shadow: 0 4px 18px rgba(0,0,0,0.06);
402
+ border: 1px solid var(--border-color-primary);
403
+ }
404
+ #train-btn { font-weight: 700 !important; }
405
+ footer { visibility: hidden; }
406
+ """
407
+
408
+ THEME = gr.themes.Soft(
409
+ primary_hue="blue",
410
+ secondary_hue="cyan",
411
+ font=[gr.themes.GoogleFont("Quicksand"), "system-ui", "sans-serif"],
412
+ )
413
+
414
+ EXAMPLE_PROMPTS = [
415
+ 'Schedule a "team meeting" tomorrow at 4pm.',
416
+ "Turn on the flashlight.",
417
+ "Show me Besançon, France on the map.",
418
+ "Open the WiFi settings.",
419
+ "Create a contact for Alex with number 555-0123.",
420
+ ]
421
+
422
+
423
+ with gr.Blocks(title="FunctionGemma 270M Mobile Actions SFT", theme=THEME, css=CUSTOM_CSS) as demo:
424
+ with gr.Group(elem_id="hero"):
425
+ gr.Markdown(EXPLANATION)
426
+
427
+ with gr.Row():
428
+ with gr.Column(scale=1):
429
+ with gr.Group(elem_classes="panel-card"):
430
+ gr.Markdown("### ⚙️ Controls")
431
+ user_prompt = gr.Textbox(
432
+ value=EXAMPLE_PROMPTS[0], lines=2,
433
+ label="Mobile-action request (user message)",
434
+ )
435
+ gr.Examples(EXAMPLE_PROMPTS, inputs=user_prompt, label="Try one")
436
+ developer_content = gr.Textbox(
437
+ value=DEFAULT_DEVELOPER, lines=3,
438
+ label="Developer message (context: date/time + role)",
439
+ )
440
+ with gr.Row():
441
+ epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
442
+ train_subset = gr.Slider(
443
+ 50, 1000, value=200, step=50, label="Train subset",
444
+ info="Fewer = faster.",
445
+ )
446
+ eval_subset = gr.Slider(
447
+ 10, 100, value=30, step=10, label="Eval examples (for scoring)",
448
+ )
449
+ with gr.Accordion("Advanced", open=False):
450
+ learning_rate = gr.Slider(1e-6, 5e-5, value=1e-5, step=1e-6, label="Learning rate")
451
+ batch_size = gr.Slider(1, 8, value=4, step=1, label="Batch size")
452
+ grad_accum = gr.Slider(1, 16, value=8, step=1, label="Grad accumulation")
453
+ output_length = gr.Slider(64, 512, value=256, step=32, label="Max new tokens")
454
+ temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1,
455
+ label="Temperature (0 = greedy, best for tools)")
456
+
457
+ with gr.Row():
458
+ base_btn = gr.Button("🎲 Ask base model", variant="secondary")
459
+ train_btn = gr.Button("🚀 Fine-tune & Compare", variant="primary", elem_id="train-btn")
460
+
461
+ with gr.Column(scale=1):
462
+ with gr.Group(elem_classes="panel-card"):
463
+ gr.Markdown("### 🔍 Results")
464
+ with gr.Row():
465
+ base_acc_box = gr.Markdown()
466
+ tuned_acc_box = gr.Markdown()
467
+ with gr.Tab("Parsed calls"):
468
+ base_calls = gr.Textbox(lines=4, label="🎲 Base model call(s)")
469
+ tuned_calls = gr.Textbox(lines=4, label="✨ Fine-tuned call(s)")
470
+ with gr.Tab("Raw output"):
471
+ tuned_raw = gr.Textbox(lines=8, label="✨ Fine-tuned raw output")
472
+ loss_plot = gr.Plot(label="📉 Training / validation loss")
473
+ status = gr.Markdown()
474
+
475
+ base_btn.click(
476
+ base_only,
477
+ inputs=[user_prompt, developer_content, output_length, temperature],
478
+ outputs=[tuned_raw, base_calls],
479
+ )
480
+
481
+ train_btn.click(
482
+ finetune_and_compare,
483
+ inputs=[user_prompt, developer_content, epochs, train_subset, eval_subset,
484
+ learning_rate, batch_size, grad_accum, output_length, temperature],
485
+ outputs=[loss_plot, status, tuned_raw, tuned_calls, base_acc_box, tuned_acc_box],
486
+ )
487
+
488
+ with gr.Accordion("💬 Notes", open=False):
489
+ gr.Markdown(
490
+ """
491
+ - **Greedy decoding** (temperature 0) is best for function calling — you want the
492
+ single most likely call, not a creative one.
493
+ - **Exact-match** accuracy is a lower bound: a call with equivalent arguments
494
+ (e.g. a slightly reworded `query`) counts as wrong but may still be acceptable.
495
+ - A GPU is strongly recommended. On CPU, training and scoring will be slow —
496
+ shrink the train/eval subsets.
497
+ """
498
+ )
499
+
500
+
501
+ if __name__ == "__main__":
502
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers==4.57.1
3
+ trl==0.25.1
4
+ datasets==4.4.1
5
+ accelerate
6
+ sentencepiece
7
+ matplotlib
8
+ gradio