| """Score the scheduling fine-tune on the held-out set (training/data/eval.jsonl). |
| |
| Reuses the deployed contract exactly: server.agent.SYSTEM + build_messages for the |
| prompt, server.schema.ActionPlan for parsing, and the same OpenAI-compatible |
| `{INFERENCE_BASE_URL}/chat/completions` call with a json_schema response_format. |
| |
| Metrics (task-specific — generic LLM benchmarks don't apply here): |
| - schema validity : model output parses + validates as ActionPlan |
| - event precision/recall/F1 (matched on exact start datetime, minute precision) |
| - start-exact recall : did it nail the date AND time (incl. relative dates) |
| - end-exact rate : on matched events where gold has an end |
| - title similarity : difflib ratio on matched events (soft) |
| - location recall/sim : on matched events where gold has a location |
| - reminder match rate : on matched events where gold sets reminder_minutes. |
| NOTE: only eval_unstructured.jsonl golds encode the type rules (medical 60 / |
| party 30 / carpool-school 45); the structured eval.jsonl golds predate them, |
| so its reminder number is informational only. |
| - no-event accuracy : on chitchat, did it correctly return zero events |
| - clarification recall : on ambiguous threads, did it ask instead of inventing |
| |
| Usage (needs a model serving an OpenAI-compatible endpoint): |
| INFERENCE_BASE_URL=http://127.0.0.1:8080/v1 python training/eval.py |
| |
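A/B arms (same metrics, same gold):
    python training/eval.py                  # default: the text model at INFERENCE_BASE_URL
    PREDICTOR=stub python training/eval.py   # arm A: the regex stub, no model/HTTP
    VISION=1 python training/eval.py         # arm C: the thread is fed ONLY as a
        rendered screenshot (training/data/screenshots/<eval-stem>/<id>.png — make
        them with training/render_screenshots.py); needs a vision-enabled server.
    EVAL_PATH=training/data/eval_unstructured.jsonl ...  # the non-formulaic set

Other environment knobs: INFERENCE_MODEL (model name sent in each request),
MODEL_LABEL (name shown in the report and in RESULTS_JSON), TITLE_POLISH=1 (a
second request that rewrites event titles; text-model arm only), MINIMAL_PROMPT=1
(drop the system message from the prompt), SCREENSHOT_DIR (root folder of the
per-eval screenshot directories).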
| """ |
| from __future__ import annotations |
|
|
| import base64 |
| import json |
| import os |
| import re |
| import sys |
| from datetime import datetime |
| from difflib import SequenceMatcher |
| from pathlib import Path |
|
|
| import requests |
|
|
# Best-effort: switch stdout to UTF-8 (unencodable characters replaced) so the
# non-ASCII text this script prints cannot fail on a narrow console encoding.
# Streams without reconfigure() — or no stdout at all (None) — are left alone.
try:
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
except AttributeError:
    pass
|
|
# Repository root: this script lives at <root>/training/eval.py. Prepend it to
# sys.path *before* the `server` imports that follow, so they resolve against
# this checkout without the package being installed.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
| from server.agent import ( |
| TITLES_SCHEMA, |
| apply_text_rules, |
| build_messages, |
| build_title_messages, |
| merge_titles, |
| ) |
| from server.schema import ActionPlan |
|
|
# OpenAI-compatible server root; "/chat/completions" is appended per request.
BASE = os.environ.get("INFERENCE_BASE_URL", "http://127.0.0.1:8080/v1").rstrip("/")
# Model name sent in each request's "model" field.
MODEL = os.environ.get("INFERENCE_MODEL", "local")
# Which arm makes the predictions: "model" (HTTP to BASE) or "stub" (the local
# regex extractor, no server needed).
PREDICTOR = os.environ.get("PREDICTOR", "model")
if PREDICTOR not in ("model", "stub"):
    # Fail fast: any other value (e.g. a typo like PREDICTOR=stubs) would
    # otherwise silently run the model arm and report it under the model's label.
    raise SystemExit(f"PREDICTOR must be 'model' or 'stub', got {PREDICTOR!r}")
VISION = os.environ.get("VISION") == "1"
TITLE_POLISH = os.environ.get("TITLE_POLISH") == "1"
# Name shown in the report header and in RESULTS_JSON["model"]; MODEL_LABEL wins.
LABEL = os.environ.get(
    "MODEL_LABEL",
    "stub-regex" if PREDICTOR == "stub" else (f"{MODEL}+vision" if VISION else MODEL),
)
EVAL_PATH = Path(os.environ.get("EVAL_PATH", ROOT / "training" / "data" / "eval.jsonl"))
# One screenshot folder per eval file, named after that file's stem.
SCREENSHOT_DIR = Path(os.environ.get(
    "SCREENSHOT_DIR", ROOT / "training" / "data" / "screenshots")) / EVAL_PATH.stem
# Stand-in thread text for the vision arm, where the conversation is only in the image.
VISION_THREAD = "(the conversation is in the attached screenshot)"
# JSON schema sent as the json_schema response_format for the main extraction call.
SCHEMA = ActionPlan.model_json_schema()
|
|
|
|
def call_model(messages: list[dict], schema: dict | None = None, name: str = "ActionPlan",
               max_tokens: int = 1024) -> str:
    """POST one chat-completion request to ``{BASE}/chat/completions``; return the reply text.

    Mirrors the deployed call: temperature 0 and a strict ``json_schema``
    response_format asking the server to constrain the output shape.

    Args:
        messages: OpenAI-style chat messages (e.g. from ``build_messages``).
        schema: JSON schema for the reply; ``None`` means ``SCHEMA`` (ActionPlan's).
        name: schema name sent alongside it in ``response_format``.
        max_tokens: generation cap for this request.

    Returns:
        The assistant message ``content``, or ``""`` when the server returned a
        null content, so callers score it as invalid JSON rather than crashing.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection problems or the 180 s timeout.
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": max_tokens,
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": name,
                "schema": SCHEMA if schema is None else schema,
                "strict": True,
            },
        },
        "stream": False,
    }
    r = requests.post(f"{BASE}/chat/completions", json=payload, timeout=180)
    r.raise_for_status()
    content = r.json()["choices"][0]["message"]["content"]
    # The chat-completions API allows content to be null (e.g. a refusal).
    return "" if content is None else content
|
|
|
|
def parse_plan(raw: str | None):
    """Parse raw model output into ``(ActionPlan, schema_valid)``.

    First tries the whole string as JSON. Failing that, it tries the span from the
    first ``{`` to the last ``}``, which recovers an object wrapped in prose or a
    code fence. If neither parses and validates — or ``raw`` is None — returns an
    empty ``ActionPlan()`` with ``schema_valid=False`` so the example is scored as
    a miss instead of aborting the run.
    """
    if raw is None:
        return ActionPlan(), False
    # JSONDecodeError and pydantic's ValidationError are ValueErrors; TypeError
    # covers JSON that is not an object (cannot be ** -expanded).
    try:
        return ActionPlan(**json.loads(raw)), True
    except (ValueError, TypeError):
        pass
    m = re.search(r"\{.*\}", raw, re.DOTALL)
    if m:
        try:
            return ActionPlan(**json.loads(m.group(0))), True
        except (ValueError, TypeError):
            pass
    return ActionPlan(), False
|
|
|
|
def dt_key(s):
    """Normalize a start/end value to a minute-precision ``YYYY-MM-DDTHH:MM`` key.

    Events are matched on these keys, so seconds are ignored and any UTC offset
    is dropped (not converted). ``datetime`` objects are formatted directly:
    ``str(datetime)`` uses a space separator, so slicing it would never equal the
    ``T``-separated key of an ISO string. (If ActionPlan's fields are
    datetime-typed, ``model_dump()`` hands such objects in — worth confirming.)
    Unparseable values fall back to their first 16 characters; empty/None -> None.
    """
    fmt = "%Y-%m-%dT%H:%M"
    if not s:
        return None
    if isinstance(s, datetime):
        return s.strftime(fmt)
    try:
        return datetime.fromisoformat(s).strftime(fmt)
    except (TypeError, ValueError):
        return str(s)[:16]
|
|
|
|
def match_events(gold, pred):
    """Pair gold events with predicted events that share the same start minute.

    Greedy and one-to-one: gold events are visited in order, and each claims the
    earliest still-unclaimed prediction whose ``dt_key(start)`` equals its own.
    Two events that both lack a start (key None) do pair with each other.

    Returns:
        ``(pairs, n_fp, n_fn)`` — the ``(gold, pred)`` pairs, the number of
        predictions left unclaimed, and the number of gold events left unmatched.
    """
    unclaimed = [(dt_key(p.get("start")), p) for p in pred]
    pairs = []
    n_missed = 0
    for g in gold:
        want = dt_key(g.get("start"))
        slot = next((i for i, (key, _) in enumerate(unclaimed) if key == want), None)
        if slot is None:
            n_missed += 1
        else:
            pairs.append((g, unclaimed.pop(slot)[1]))
    return pairs, len(unclaimed), n_missed
|
|
|
|
def sim(a, b):
    """Case-insensitive difflib similarity of two strings, rounded to 2 places.

    None is treated as the empty string; two empty inputs score 1.0.
    """
    left = (a or "").lower()
    right = (b or "").lower()
    ratio = SequenceMatcher(None, left, right).ratio()
    return round(ratio, 2)
|
|
|
|
def _screenshot_uri(rec_id: str, directory: Path | None = None) -> str:
    """Return record ``rec_id``'s rendered screenshot as a ``data:image/png`` URI.

    Args:
        rec_id: eval record id; the file read is ``<directory>/<rec_id>.png``.
        directory: folder holding the PNGs; ``None`` means ``SCREENSHOT_DIR``
            (the screenshot folder for the current eval file).

    Raises:
        FileNotFoundError: the PNG is missing. The message leads with the fix,
            because the eval loop prints only the first 120 characters of it.
    """
    png = (SCREENSHOT_DIR if directory is None else directory) / f"{rec_id}.png"
    try:
        data = png.read_bytes()
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"no screenshot for {rec_id!r} - render it with "
            f"training/render_screenshots.py (expected {png})"
        ) from err
    return "data:image/png;base64," + base64.b64encode(data).decode("ascii")
|
|
|
|
def predict(rec: dict, now: datetime):
    """Produce one arm's prediction for a single eval record.

    Arms:
      * stub   - ``run_agent`` with USE_STUB_EXTRACTOR=1, in-process, no HTTP.
      * vision - the model sees only the rendered screenshot of the thread.
      * text   - the model sees the raw thread; with TITLE_POLISH a second request
                 rewrites event titles, then ``apply_text_rules`` post-processes.

    Returns ``(ActionPlan, schema_valid)``; the stub arm always reports valid.
    Exceptions (HTTP errors, a missing screenshot, ...) propagate to the caller.
    """
    if PREDICTOR == "stub":
        os.environ["USE_STUB_EXTRACTOR"] = "1"
        # server.agent is already imported at module load, so this is a cached
        # lookup; presumably run_agent reads the env var per call — confirm.
        from server.agent import run_agent
        return run_agent(rec["thread"], now=now, memory_block=""), True

    if VISION:
        prompt_thread, images = VISION_THREAD, [_screenshot_uri(rec["id"])]
    else:
        prompt_thread, images = rec["thread"], None

    messages = build_messages(prompt_thread, now, [], images)
    if os.environ.get("MINIMAL_PROMPT") == "1":
        messages = [msg for msg in messages if msg["role"] != "system"]

    plan, valid = parse_plan(call_model(messages))

    # The vision arm has no text to post-process against.
    if VISION:
        return plan, valid

    if TITLE_POLISH and plan.events:
        drafts = [event.model_dump() for event in plan.events]
        title_raw = call_model(build_title_messages(rec["thread"], drafts),
                               schema=TITLES_SCHEMA, name="Titles", max_tokens=256)
        plan = merge_titles(plan, title_raw)

    return apply_text_rules(rec["thread"], plan), valid
|
|
|
|
def _ratio(num, den):
    """Return ``num / den`` rounded to 3 places, or None when ``den`` is 0."""
    return round(num / den, 3) if den else None


def main():
    """Score every record in EVAL_PATH with the configured arm and print a report.

    Output: a per-example table, the (start, title) detail of every event-bearing
    example whose events did not line up with gold, a human-readable summary, and
    one ``RESULTS_JSON:`` line carrying the same summary as JSON for scripts.

    A record whose prediction raises (HTTP error, timeout, missing screenshot, ...)
    is reported as ``ERR`` and scored as a miss. Its gold events count as false
    negatives, and it counts against no-event accuracy and clarification recall.
    Skipping it would quietly drop it from those denominators and flatter an arm
    whose server fails on the hard examples. ``n_errors`` reports how many records
    were affected.
    """
    records = [
        json.loads(line)
        for line in EVAL_PATH.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    where = "(local regex stub)" if PREDICTOR == "stub" else f"@ {BASE}"
    print(f"Scoring {LABEL} on {len(records)} held-out examples {where}"
          f"{' [vision: screenshots only]' if VISION else ''}\n")

    n_valid = n_err = 0
    tp = fp = fn = 0
    end_ok = end_tot = 0
    loc_ok = loc_tot = 0
    rem_ok = rem_tot = 0
    title_sims = []
    loc_sims = []
    no_event = {"ok": 0, "tot": 0}
    clarify = {"ok": 0, "tot": 0}
    rows = []
    mismatches = []

    for rec in records:
        now = datetime.fromisoformat(rec["now"])
        cat = rec["category"]
        g_ev = rec["gold"]["events"]
        try:
            pred, valid = predict(rec, now)
        except Exception as e:  # deliberate: one bad example must not abort the run
            print(f" [{rec['id']}] request failed: {str(e)[:120]}")
            n_err += 1
            fn += len(g_ev)  # every gold event of this example was missed
            if cat == "no_event":
                no_event["tot"] += 1
            elif cat == "clarify":
                clarify["tot"] += 1
            rows.append((rec["id"], cat, "ERR"))
            continue

        n_valid += int(valid)
        p_ev = [e.model_dump() for e in pred.events]

        pairs, e_fp, e_fn = match_events(g_ev, p_ev)
        tp += len(pairs)
        fp += e_fp
        fn += e_fn
        # Field-level metrics are only defined on start-matched pairs.
        for g, p in pairs:
            title_sims.append(sim(g.get("title"), p.get("title")))
            if g.get("end"):
                end_tot += 1
                end_ok += int(dt_key(g["end"]) == dt_key(p.get("end")))
            if g.get("location"):
                loc_tot += 1
                s = sim(g.get("location"), p.get("location"))
                loc_sims.append(s)
                loc_ok += int(bool(p.get("location")) and s >= 0.5)
            if g.get("reminder_minutes") is not None:
                rem_tot += 1
                rem_ok += int(p.get("reminder_minutes") == g["reminder_minutes"])

        if cat == "no_event":
            no_event["tot"] += 1
            good = len(p_ev) == 0
            no_event["ok"] += int(good)
            flag = "no-event OK" if good else f"HALLUCINATED {len(p_ev)}"
        elif cat == "clarify":
            clarify["tot"] += 1
            good = bool(pred.needs_clarification) and len(p_ev) == 0
            clarify["ok"] += int(good)
            flag = "asked OK" if good else "did NOT ask"
        else:
            flag = f"events {len(pairs)}/{len(g_ev)} matched" + (f", +{e_fp} extra" if e_fp else "")
            if e_fp or e_fn:
                mismatches.append({
                    "id": rec["id"],
                    "gold": [(e.get("start"), e.get("title")) for e in g_ev],
                    "pred": [(e.get("start"), e.get("title")) for e in p_ev],
                })
        rows.append((rec["id"], cat, ("ok" if valid else "BADJSON") + " | " + flag))

    n = len(records)
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec_ = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * rec_ / (prec + rec_) if (prec + rec_) else 0.0

    print("\nPer-example:")
    for rid, cat, info in rows:
        print(f" {rid:4s} {cat:9s} {info}")

    if mismatches:
        print("\nMismatch detail (start, title) — gold vs what the model produced:")
        for mm in mismatches:
            print(f" [{mm['id']}]")
            print(f" gold: {mm['gold']}")
            print(f" pred: {mm['pred']}")

    summary = {
        "model": LABEL,
        "n_examples": n,
        "n_errors": n_err,
        "schema_validity": _ratio(n_valid, n),
        "event_precision": round(prec, 3),
        "event_recall_start_exact": round(rec_, 3),
        "event_f1": round(f1, 3),
        "end_exact_rate": _ratio(end_ok, end_tot),
        "title_similarity_avg": _ratio(sum(title_sims), len(title_sims)),
        "location_recall": _ratio(loc_ok, loc_tot),
        "location_similarity_avg": _ratio(sum(loc_sims), len(loc_sims)),
        "reminder_match_rate": _ratio(rem_ok, rem_tot),
        "no_event_accuracy": _ratio(no_event["ok"], no_event["tot"]),
        "clarification_recall": _ratio(clarify["ok"], clarify["tot"]),
        "events_tp_fp_fn": [tp, fp, fn],
    }
    print("\n==================== SUMMARY ====================")
    for k, v in summary.items():
        print(f" {k:28s} {v}")
    print("RESULTS_JSON:", json.dumps(summary))


if __name__ == "__main__":
    main()
|
|