Spaces:
Sleeping
Sleeping
| """Score the scheduling fine-tune on the held-out set (training/data/eval.jsonl). | |
| Reuses the deployed contract exactly: server.agent.SYSTEM + build_messages for the | |
| prompt, server.schema.ActionPlan for parsing, and the same OpenAI-compatible | |
| `{INFERENCE_BASE_URL}/chat/completions` call with a json_schema response_format. | |
| Metrics (task-specific — generic LLM benchmarks don't apply here): | |
| - schema validity : model output parses + validates as ActionPlan | |
| - event precision/recall/F1 (matched on exact start datetime, minute precision) | |
| - start-exact recall : did it nail the date AND time (incl. relative dates) | |
| - end-exact rate : on matched events where gold has an end | |
| - title similarity : difflib ratio on matched events (soft) | |
| - location recall/sim : on matched events where gold has a location | |
| - reminder match rate : on matched events where gold sets reminder_minutes. | |
| NOTE: only eval_unstructured.jsonl golds encode the type rules (medical 60 / | |
| party 30 / carpool-school 45); the structured eval.jsonl golds predate them, | |
| so its reminder number is informational only. | |
| - no-event accuracy : on chitchat, did it correctly return zero events | |
| - clarification recall : on ambiguous threads, did it ask instead of inventing | |
| Usage (needs a model serving an OpenAI-compatible endpoint): | |
| INFERENCE_BASE_URL=http://127.0.0.1:8080/v1 python training/eval.py | |
| A/B arms (same metrics, same gold): | |
| PREDICTOR=stub python training/eval.py # arm A: the regex stub, no model/HTTP | |
| VISION=1 python training/eval.py # arm C: the thread is fed ONLY as a | |
| rendered screenshot (training/data/screenshots/<eval-stem>/<id>.png — make | |
| them with training/render_screenshots.py); needs a vision-enabled server. | |
| EVAL_PATH=training/data/eval_unstructured.jsonl ... # the non-formulaic set | |
| """ | |
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import os | |
| import re | |
| import sys | |
| from datetime import datetime | |
| from difflib import SequenceMatcher | |
| from pathlib import Path | |
| import requests | |
# Threads contain emoji; a Windows cp1252 console would fail to print them.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")

# Parent of this file's directory: put it first on sys.path so the `server`
# package imports that follow resolve when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
| from server.agent import ( # noqa: E402 (light import; model is lazy) | |
| TITLES_SCHEMA, | |
| apply_text_rules, | |
| build_messages, | |
| build_title_messages, | |
| merge_titles, | |
| ) | |
| from server.schema import ActionPlan # noqa: E402 | |
# Run configuration — every knob comes from an environment variable.
BASE = os.getenv("INFERENCE_BASE_URL", "http://127.0.0.1:8080/v1").rstrip("/")
MODEL = os.getenv("INFERENCE_MODEL", "local")
PREDICTOR = os.getenv("PREDICTOR", "model")  # "model" -> HTTP call, "stub" -> regex extractor
VISION = os.getenv("VISION") == "1"  # the thread is given only as a rendered screenshot
TITLE_POLISH = os.getenv("TITLE_POLISH") == "1"  # extra pass that rewrites event titles only

# Human-readable arm name for the report; MODEL_LABEL overrides the derived default.
if PREDICTOR == "stub":
    LABEL = os.getenv("MODEL_LABEL", "stub-regex")
elif VISION:
    LABEL = os.getenv("MODEL_LABEL", f"{MODEL}+vision")
else:
    LABEL = os.getenv("MODEL_LABEL", MODEL)

EVAL_PATH = Path(os.getenv("EVAL_PATH", ROOT.joinpath("training", "data", "eval.jsonl")))
# Screenshots live in a per-eval-set subfolder named after the eval file's stem.
SCREENSHOT_DIR = Path(os.getenv(
    "SCREENSHOT_DIR", ROOT.joinpath("training", "data", "screenshots"))).joinpath(EVAL_PATH.stem)
# Stand-in thread text sent in vision mode, where the real text is withheld.
VISION_THREAD = "(the conversation is in the attached screenshot)"
# JSON schema used for the structured-output response_format.
SCHEMA = ActionPlan.model_json_schema()
def call_model(messages: list[dict], schema: dict | None = None, name: str = "ActionPlan",
               max_tokens: int = 1024, timeout: float = 180) -> str:
    """Send one greedy chat-completions request and return the assistant text.

    Mirrors the server's call: POST ``{BASE}/chat/completions`` with a strict
    ``json_schema`` response_format.

    Args:
        messages: OpenAI-style chat messages.
        schema: JSON schema the output must follow; None means the module's
            ActionPlan schema (SCHEMA).
        name: schema name sent inside ``response_format``.
        max_tokens: generation cap.
        timeout: seconds to wait for the HTTP response.

    Returns:
        The message content, or "" when the server returns a null content.

    Raises:
        requests.HTTPError: on a non-2xx response (``raise_for_status``).
        requests.RequestException: on connection failures / timeouts.
    """
    if schema is None:  # an explicitly passed schema (even {}) is respected
        schema = SCHEMA
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.0,  # greedy -> reproducible eval
        "max_tokens": max_tokens,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": name, "schema": schema, "strict": True},
        },
        "stream": False,
    }
    r = requests.post(f"{BASE}/chat/completions", json=payload, timeout=timeout)
    r.raise_for_status()
    content = r.json()["choices"][0]["message"]["content"]
    # The API allows a null content; keep the declared `-> str` contract.
    return content if content is not None else ""
def parse_plan(raw: str | None):
    """Parse raw model output into ``(ActionPlan, schema_valid)``.

    Tries the whole string first, then the outermost ``{...}`` span (tolerates
    code fences or preamble around the JSON). Any failure — including a
    non-string ``raw`` such as a null message content — yields
    ``(ActionPlan(), False)`` so it is scored as a schema failure instead of
    crashing the example.
    """
    if not isinstance(raw, str):
        return ActionPlan(), False
    try:
        return ActionPlan(**json.loads(raw)), True
    except Exception:  # noqa: BLE001 - best-effort: bad JSON / wrong shape / validation
        pass
    m = re.search(r"\{.*\}", raw, re.DOTALL)
    if m:
        try:
            return ActionPlan(**json.loads(m.group(0))), True
        except Exception:  # noqa: BLE001
            pass
    return ActionPlan(), False
def dt_key(s):
    """Normalize a start/end value to a minute-precision ``YYYY-MM-DDTHH:MM`` key.

    Accepts ISO-8601 strings or ``datetime`` objects (e.g. a pydantic
    ``model_dump()`` of a datetime field). Falsy input -> None. Anything that
    does not parse falls back to its first 16 characters. A timezone offset,
    if present, is ignored: only the wall-clock date and time are compared.
    """
    if not s:
        return None
    if isinstance(s, datetime):
        # Without this, str(dt)[:16] gives "YYYY-MM-DD HH:MM" (space, not "T")
        # and never matches keys derived from ISO strings.
        return s.strftime("%Y-%m-%dT%H:%M")
    try:
        return datetime.fromisoformat(s).strftime("%Y-%m-%dT%H:%M")
    except (TypeError, ValueError):  # non-str input / not ISO-8601
        return str(s)[:16]
def match_events(gold, pred):
    """Pair gold and predicted events one-to-one on their start minute.

    Greedy: each gold event, in order, claims the first still-unclaimed
    prediction whose ``dt_key(start)`` equals its own.

    Returns ``(matched_pairs, n_fp, n_fn)``: the (gold, pred) pairs, the number
    of predictions left unclaimed, and the number of gold events with no match.
    """
    unclaimed = list(pred)
    matched = []
    misses = 0
    for gold_ev in gold:
        target = dt_key(gold_ev.get("start"))
        for idx, candidate in enumerate(unclaimed):
            if dt_key(candidate.get("start")) == target:
                matched.append((gold_ev, candidate))
                del unclaimed[idx]
                break
        else:
            misses += 1
    return matched, len(unclaimed), misses
def sim(a, b):
    """Case-insensitive difflib ratio of two optional strings, rounded to 2 places.

    None is treated as "" (two empty strings compare as 1.0).
    """
    left = (a or "").lower()
    right = (b or "").lower()
    ratio = SequenceMatcher(None, left, right).ratio()
    return round(ratio, 2)
| def _screenshot_uri(rec_id: str) -> str: | |
| png = SCREENSHOT_DIR / f"{rec_id}.png" | |
| data = base64.b64encode(png.read_bytes()).decode() | |
| return f"data:image/png;base64,{data}" | |
def predict(rec: dict, now: datetime):
    """Run the configured arm on one eval record -> ``(ActionPlan, schema_valid)``.

    The stub arm always reports ``schema_valid=True``. The model arm reports
    whether the raw response parsed as an ActionPlan. HTTP errors propagate to
    the caller.
    """
    if PREDICTOR == "stub":
        # Arm A: no HTTP. The flag is set before the lazy import below;
        # presumably server.agent reads it to pick its regex extractor.
        os.environ["USE_STUB_EXTRACTOR"] = "1"
        from server.agent import run_agent  # noqa: PLC0415 lazy, like the server
        # memory_block="" keeps the eval from writing the server-side memory file
        return run_agent(rec["thread"], now=now, memory_block=""), True

    if VISION:
        # Arm C: only the screenshot carries the conversation.
        messages = build_messages(VISION_THREAD, now, [], [_screenshot_uri(rec["id"])])
    else:
        messages = build_messages(rec["thread"], now, [], None)

    # MINIMAL_PROMPT=1: drop the system prompt entirely — measures how much task
    # knowledge is INTERNALIZED (fine-tune) vs prompted-in (stock). User content
    # (datetime + thread + "Return the ActionPlan JSON now.") stays identical.
    if os.environ.get("MINIMAL_PROMPT") == "1":
        messages = [msg for msg in messages if msg["role"] != "system"]

    plan, valid = parse_plan(call_model(messages))

    if VISION:
        # Both post-passes below read the thread text, which vision mode withholds.
        return plan, valid

    thread_text = rec["thread"]
    # TITLE_POLISH=1: same second pass the server runs — rewrite titles only.
    if TITLE_POLISH and plan.events:
        polished = call_model(
            build_title_messages(thread_text, [ev.model_dump() for ev in plan.events]),
            schema=TITLES_SCHEMA, name="Titles", max_tokens=256,
        )
        plan = merge_titles(plan, polished)

    # Same deterministic logistics post-pass the server applies (arrival shift +
    # reminder rules).
    return apply_text_rules(thread_text, plan), valid
def _ratio(num, den, ndigits: int = 3):
    """Return ``num / den`` rounded to *ndigits*, or None when *den* is 0.

    None marks a metric that had no applicable examples in this run.
    """
    return round(num / den, ndigits) if den else None


def main():
    """Score every record in EVAL_PATH with the configured arm and print a report.

    Output: one row per example, (start, title) detail for extraction examples
    with missed or extra events, a summary table, and a ``RESULTS_JSON:`` line.

    An example whose prediction raises is shown as ``ERR``. It still counts in
    the schema-validity denominator but is skipped by every other tally, so
    ``n_errors`` is reported alongside the metrics.
    """
    lines = EVAL_PATH.read_text(encoding="utf-8").splitlines()
    records = [json.loads(line) for line in lines if line.strip()]
    where = "(local regex stub)" if PREDICTOR == "stub" else f"@ {BASE}"
    print(f"Scoring {LABEL} on {len(records)} held-out examples {where}"
          f"{' [vision: screenshots only]' if VISION else ''}\n")

    n_valid = n_err = 0
    tp = fp = fn = 0
    end_ok = end_tot = 0
    loc_ok = loc_tot = 0
    rem_ok = rem_tot = 0
    title_sims = []
    loc_sims = []
    no_event = {"ok": 0, "tot": 0}
    clarify = {"ok": 0, "tot": 0}
    rows = []
    mismatches = []  # gold-vs-pred event detail for failing extraction examples

    for rec in records:
        now = datetime.fromisoformat(rec["now"])
        try:
            pred, valid = predict(rec, now)
        except Exception as e:  # noqa: BLE001 - one failed example must not abort the run
            print(f" [{rec['id']}] request failed: {str(e)[:120]}")
            n_err += 1
            rows.append((rec["id"], rec["category"], "ERR"))
            continue
        n_valid += int(valid)
        g_ev = rec["gold"]["events"]
        p_ev = [e.model_dump() for e in pred.events]
        pairs, e_fp, e_fn = match_events(g_ev, p_ev)
        tp += len(pairs)
        fp += e_fp
        fn += e_fn
        for g, p in pairs:
            title_sims.append(sim(g.get("title"), p.get("title")))
            if g.get("end"):
                end_tot += 1
                end_ok += int(dt_key(g["end"]) == dt_key(p.get("end")))
            if g.get("location"):
                loc_tot += 1
                s = sim(g.get("location"), p.get("location"))
                loc_sims.append(s)
                # credited only when a location was predicted and is >= 0.5 similar
                loc_ok += int(bool(p.get("location")) and s >= 0.5)
            if g.get("reminder_minutes") is not None:
                rem_tot += 1
                rem_ok += int(p.get("reminder_minutes") == g["reminder_minutes"])

        cat = rec["category"]
        if cat == "no_event":
            no_event["tot"] += 1
            good = len(p_ev) == 0
            no_event["ok"] += int(good)
            flag = "no-event OK" if good else f"HALLUCINATED {len(p_ev)}"
        elif cat == "clarify":
            clarify["tot"] += 1
            # asking only counts if it also refrained from inventing events
            good = bool(pred.needs_clarification) and len(p_ev) == 0
            clarify["ok"] += int(good)
            flag = "asked OK" if good else "did NOT ask"
        else:
            flag = f"events {len(pairs)}/{len(g_ev)} matched" + (f", +{e_fp} extra" if e_fp else "")
            if e_fp or e_fn:  # capture what the model actually produced, to diagnose
                mismatches.append({
                    "id": rec["id"],
                    "gold": [(e.get("start"), e.get("title")) for e in g_ev],
                    "pred": [(e.get("start"), e.get("title")) for e in p_ev],
                })
        rows.append((rec["id"], cat, ("ok" if valid else "BADJSON") + " | " + flag))

    n = len(records)
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec_ = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * rec_ / (prec + rec_) if (prec + rec_) else 0.0

    print("\nPer-example:")
    for rid, cat, info in rows:
        # !s: a non-string id must not crash the report after a full (slow) run
        print(f" {rid!s:4} {cat!s:9} {info}")
    if mismatches:
        print("\nMismatch detail (start, title) — gold vs what the model produced:")
        for mm in mismatches:
            print(f" [{mm['id']}]")
            print(f" gold: {mm['gold']}")
            print(f" pred: {mm['pred']}")

    summary = {
        "model": LABEL,
        "n_examples": n,
        "schema_validity": _ratio(n_valid, n),  # None on an empty eval file
        "event_precision": round(prec, 3),
        "event_recall_start_exact": round(rec_, 3),
        "event_f1": round(f1, 3),
        "end_exact_rate": _ratio(end_ok, end_tot),
        "title_similarity_avg": _ratio(sum(title_sims), len(title_sims)),
        "location_recall": _ratio(loc_ok, loc_tot),
        "location_similarity_avg": _ratio(sum(loc_sims), len(loc_sims)),
        "reminder_match_rate": _ratio(rem_ok, rem_tot),
        "no_event_accuracy": _ratio(no_event["ok"], no_event["tot"]),
        "clarification_recall": _ratio(clarify["ok"], clarify["tot"]),
        "events_tp_fp_fn": [tp, fp, fn],
        # examples skipped by every tally except schema validity
        "n_errors": n_err,
    }
    print("\n==================== SUMMARY ====================")
    for k, v in summary.items():
        print(f" {k:28s} {v}")
    print("RESULTS_JSON:", json.dumps(summary))


if __name__ == "__main__":
    main()