# (Web file-viewer header residue, kept verbatim; not part of the script.)
# ParetoOptimal's picture
# Initial Commit
# 0366d65
# Raw
# History Blame Contribute Delete
# 11.6 kB
"""Score the scheduling fine-tune on the held-out set (training/data/eval.jsonl).
Reuses the deployed contract exactly: server.agent.SYSTEM + build_messages for the
prompt, server.schema.ActionPlan for parsing, and the same OpenAI-compatible
`{INFERENCE_BASE_URL}/chat/completions` call with a json_schema response_format.
Metrics (task-specific — generic LLM benchmarks don't apply here):
- schema validity : model output parses + validates as ActionPlan
- event precision/recall/F1 (matched on exact start datetime, minute precision)
- start-exact recall : did it nail the date AND time (incl. relative dates)
- end-exact rate : on matched events where gold has an end
- title similarity : difflib ratio on matched events (soft)
- location recall/sim : on matched events where gold has a location
- reminder match rate : on matched events where gold sets reminder_minutes.
NOTE: only eval_unstructured.jsonl golds encode the type rules (medical 60 /
party 30 / carpool-school 45); the structured eval.jsonl golds predate them,
so its reminder number is informational only.
- no-event accuracy : on chitchat, did it correctly return zero events
- clarification recall : on ambiguous threads, did it ask instead of inventing
Usage (needs a model serving an OpenAI-compatible endpoint):
INFERENCE_BASE_URL=http://127.0.0.1:8080/v1 python training/eval.py
A/B arms (same metrics, same gold):
PREDICTOR=stub python training/eval.py # arm A: the regex stub, no model/HTTP
VISION=1 python training/eval.py # arm C: the thread is fed ONLY as a
rendered screenshot (training/data/screenshots/<eval-stem>/<id>.png — make
them with training/render_screenshots.py); needs a vision-enabled server.
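EVAL_PATH=training/data/eval_unstructured.jsonl ... # the non-formulaic set
Other environment knobs:
INFERENCE_MODEL model name sent to the server (default "local")
MODEL_LABEL name printed in the report / RESULTS_JSON
TITLE_POLISH=1 also run the server's second, titles-only rewriting pass (text mode only)
MINIMAL_PROMPT=1 drop the system prompt to measure internalized task knowledge
SCREENSHOT_DIR root for VISION screenshots (default training/data/screenshots)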
"""
from __future__ import annotations
import base64
import json
import os
import re
import sys
from datetime import datetime
from difflib import SequenceMatcher
from pathlib import Path
import requests
# Thread text can contain emoji that a Windows cp1252 console cannot encode;
# switch stdout to UTF-8 (replacing anything unencodable) where supported.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")

# Repository root = the parent of this file's directory. It is put first on
# sys.path so the `server` package imports below resolve when run as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from server.agent import ( # noqa: E402 (light import; model is lazy)
TITLES_SCHEMA,
apply_text_rules,
build_messages,
build_title_messages,
merge_titles,
)
from server.schema import ActionPlan # noqa: E402
BASE = os.environ.get("INFERENCE_BASE_URL", "http://127.0.0.1:8080/v1").rstrip("/")
MODEL = os.environ.get("INFERENCE_MODEL", "local")
PREDICTOR = os.environ.get("PREDICTOR", "model") # "model" (HTTP) | "stub" (regex)
VISION = os.environ.get("VISION") == "1" # feed the thread as a screenshot only
TITLE_POLISH = os.environ.get("TITLE_POLISH") == "1" # second pass rewriting titles
LABEL = os.environ.get("MODEL_LABEL",
"stub-regex" if PREDICTOR == "stub"
else (f"{MODEL}+vision" if VISION else MODEL))
EVAL_PATH = Path(os.environ.get("EVAL_PATH", ROOT / "training" / "data" / "eval.jsonl"))
SCREENSHOT_DIR = Path(os.environ.get(
"SCREENSHOT_DIR", ROOT / "training" / "data" / "screenshots")) / EVAL_PATH.stem
VISION_THREAD = "(the conversation is in the attached screenshot)"
SCHEMA = ActionPlan.model_json_schema()
def call_model(messages: list[dict], schema: dict | None = None, name: str = "ActionPlan",
               max_tokens: int = 1024, timeout: float = 180) -> str | None:
    """POST one OpenAI-compatible chat completion and return the message content.

    Args:
        messages: chat messages, sent as-is.
        schema: JSON Schema for the strict ``json_schema`` response_format.
            ``None`` (the default) selects the ActionPlan schema ``SCHEMA``.
            Only ``None`` triggers the default; any other value is sent as given.
        name: schema name reported to the server.
        max_tokens: generation cap.
        timeout: seconds to wait for the HTTP response (default 180).

    Returns:
        ``choices[0].message.content``. This is normally the JSON string, but
        OpenAI-compatible servers may send ``null`` here, so callers should
        tolerate ``None``.

    Raises:
        requests.HTTPError: non-2xx status (via ``raise_for_status``).
        requests.RequestException: connection errors / timeouts.
        KeyError / IndexError: the response lacks ``choices[0].message``.
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.0,  # greedy -> reproducible eval
        "max_tokens": max_tokens,
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": name,
                "schema": SCHEMA if schema is None else schema,
                "strict": True,
            },
        },
        "stream": False,
    }
    r = requests.post(f"{BASE}/chat/completions", json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]
def parse_plan(raw: str | None):
    """Parse raw model output into ``(ActionPlan, schema_valid)``.

    First the whole string is tried as JSON. Failing that, the span from the
    first ``{`` to the last ``}`` is salvaged (e.g. JSON wrapped in prose or
    code fences). Anything that still does not validate yields an empty
    ``ActionPlan()`` with ``schema_valid=False`` instead of raising. That
    includes a missing/``None`` or non-string response.
    """
    # A null message content must score as invalid output, not blow up in the
    # regex fallback (re.search rejects non-str input).
    if not isinstance(raw, str):
        return ActionPlan(), False
    try:
        return ActionPlan(**json.loads(raw)), True
    except Exception:  # noqa: BLE001 - JSON decode or schema validation error
        pass
    m = re.search(r"\{.*\}", raw, re.DOTALL)
    if m:
        try:
            return ActionPlan(**json.loads(m.group(0))), True
        except Exception:  # noqa: BLE001
            pass
    return ActionPlan(), False
def dt_key(s):
    """Normalize a start/end value to a minute-precision match key.

    Returns ``None`` for a missing/empty value. Anything ISO-8601 parseable
    becomes ``"YYYY-MM-DDTHH:MM"``: seconds are truncated and any UTC offset is
    ignored, so the wall-clock time is compared. ``datetime``/``date`` objects
    (e.g. if event fields are datetime-typed and come from ``model_dump()``)
    are normalized through their ISO form. Without that step they would
    stringify with a space separator and never match an ISO string key.
    Unparseable values fall back to their first 16 characters, so malformed
    output can still be compared verbatim.
    """
    if not s:
        return None
    text = s.isoformat() if not isinstance(s, str) and hasattr(s, "isoformat") else s
    try:
        return datetime.fromisoformat(text).strftime("%Y-%m-%dT%H:%M")
    except (TypeError, ValueError):
        return str(s)[:16]
def match_events(gold, pred):
    """Greedily pair gold and predicted events one-to-one on exact start minute.

    Each gold event (in order) takes the earliest still-unused prediction with
    the same ``dt_key(start)``.

    Returns:
        ``(pairs, n_false_positives, n_false_negatives)``. ``pairs`` is a list
        of ``(gold_event, pred_event)``. False positives are the predictions
        left unpaired; false negatives are the gold events with no partner.
    """
    # key -> predictions with that start, in their original order, so popping
    # the front reproduces "first remaining match" selection.
    by_start: dict = {}
    for p in pred:
        by_start.setdefault(dt_key(p.get("start")), []).append(p)

    pairs = []
    misses = 0
    for g in gold:
        candidates = by_start.get(dt_key(g.get("start")))
        if candidates:
            pairs.append((g, candidates.pop(0)))
        else:
            misses += 1

    unmatched = sum(len(rest) for rest in by_start.values())
    return pairs, unmatched, misses
def sim(a, b):
    """Case-insensitive difflib similarity of two optional strings, rounded to 2 places.

    ``None`` counts as ``""``; two empty inputs score 1.0 (difflib's convention).
    """
    left, right = (a or "").lower(), (b or "").lower()
    matcher = SequenceMatcher(isjunk=None, a=left, b=right)
    return round(matcher.ratio(), 2)
def _screenshot_uri(rec_id: str) -> str:
    """Return ``SCREENSHOT_DIR/<rec_id>.png`` as a base64 ``data:image/png`` URI.

    A missing screenshot propagates ``FileNotFoundError`` from ``read_bytes``.
    """
    png_bytes = (SCREENSHOT_DIR / f"{rec_id}.png").read_bytes()
    encoded = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{encoded}"
def predict(rec: dict, now: datetime):
    """Run the configured arm on one eval record -> ``(ActionPlan, schema_valid)``.

    * Stub arm (PREDICTOR=stub): sets USE_STUB_EXTRACTOR=1 and then calls
      ``server.agent.run_agent`` (imported lazily). It is always reported as
      schema-valid.
    * Model arm: builds the server's prompt from either the thread text or,
      with VISION, only a screenshot. MINIMAL_PROMPT=1 drops the system
      message. The model is called and its JSON parsed. In text mode the
      server's post-processing is then mirrored: the optional title-polish
      pass (TITLE_POLISH) followed by ``apply_text_rules``.
    """
    if PREDICTOR == "stub":
        os.environ["USE_STUB_EXTRACTOR"] = "1"
        from server.agent import run_agent  # noqa: PLC0415 lazy, like the server
        # memory_block="" keeps the eval from writing the server-side memory file
        return run_agent(rec["thread"], now=now, memory_block=""), True

    if VISION:
        # The screenshot is the model's only view of the conversation.
        images = [_screenshot_uri(rec["id"])]
        prompt_thread = VISION_THREAD
    else:
        images = None
        prompt_thread = rec["thread"]
    messages = build_messages(prompt_thread, now, [], images)

    # MINIMAL_PROMPT=1: drop the system prompt entirely. This measures how much
    # task knowledge is INTERNALIZED (fine-tune) vs prompted-in (stock). The user
    # content (datetime + thread + "Return the ActionPlan JSON now.") is unchanged.
    if os.environ.get("MINIMAL_PROMPT") == "1":
        messages = [msg for msg in messages if msg["role"] != "system"]

    plan, valid = parse_plan(call_model(messages))

    # Both post-passes below read the thread text, which VISION withholds.
    if VISION:
        return plan, valid

    # TITLE_POLISH=1: the same second pass the server runs; it rewrites titles only.
    if TITLE_POLISH and plan.events:
        polished = call_model(
            build_title_messages(prompt_thread, [ev.model_dump() for ev in plan.events]),
            schema=TITLES_SCHEMA, name="Titles", max_tokens=256,
        )
        plan = merge_titles(plan, polished)

    # The same deterministic logistics post-pass the server applies
    # (arrival shift + reminder rules).
    return apply_text_rules(prompt_thread, plan), valid
def _ratio(num, den):
    """Return ``num / den`` rounded to 3 places, or ``None`` when ``den`` is 0."""
    return round(num / den, 3) if den else None


def main():
    """Score every record in EVAL_PATH with the configured arm and print a report.

    Output:
    * one line per example;
    * (start, title) detail for extraction examples whose events did not pair
      up one-to-one;
    * a summary dict, also printed as a single ``RESULTS_JSON:`` line.

    A record whose prediction raises is reported as ``ERR`` and scored as a
    miss. Its gold events count as false negatives, and a no-event/clarify
    case counts as not handled.
    """
    records = [
        json.loads(line)
        for line in EVAL_PATH.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    where = "(local regex stub)" if PREDICTOR == "stub" else f"@ {BASE}"
    print(f"Scoring {LABEL} on {len(records)} held-out examples {where}"
          f"{' [vision: screenshots only]' if VISION else ''}\n")

    n_valid = n_errors = 0
    tp = fp = fn = 0
    end_ok = end_tot = 0
    loc_ok = loc_tot = 0
    rem_ok = rem_tot = 0
    title_sims = []
    loc_sims = []
    no_event = {"ok": 0, "tot": 0}
    clarify = {"ok": 0, "tot": 0}
    rows = []
    mismatches = []  # gold-vs-pred event detail for failing extraction examples

    for rec in records:
        now = datetime.fromisoformat(rec["now"])
        cat = rec["category"]
        g_ev = rec["gold"]["events"]
        try:
            pred, valid = predict(rec, now)
        except Exception as e:  # noqa: BLE001
            print(f" [{rec['id']}] request failed: {str(e)[:120]}")
            # A failed request is still a failed example. Dropping it would let
            # a flaky server raise recall / no-event / clarify scores by skipping
            # examples. schema_validity already counts it, because n includes it.
            n_errors += 1
            fn += len(g_ev)
            if cat == "no_event":
                no_event["tot"] += 1
            elif cat == "clarify":
                clarify["tot"] += 1
            rows.append((rec["id"], cat, "ERR"))
            continue

        n_valid += int(valid)
        p_ev = [e.model_dump() for e in pred.events]
        pairs, e_fp, e_fn = match_events(g_ev, p_ev)
        tp += len(pairs)
        fp += e_fp
        fn += e_fn

        # Per-field quality, only on events whose start matched.
        for g, p in pairs:
            title_sims.append(sim(g.get("title"), p.get("title")))
            if g.get("end"):
                end_tot += 1
                end_ok += int(dt_key(g["end"]) == dt_key(p.get("end")))
            if g.get("location"):
                loc_tot += 1
                s = sim(g.get("location"), p.get("location"))
                loc_sims.append(s)
                loc_ok += int(bool(p.get("location")) and s >= 0.5)
            if g.get("reminder_minutes") is not None:
                rem_tot += 1
                rem_ok += int(p.get("reminder_minutes") == g["reminder_minutes"])

        if cat == "no_event":
            no_event["tot"] += 1
            good = len(p_ev) == 0
            no_event["ok"] += int(good)
            flag = "no-event OK" if good else f"HALLUCINATED {len(p_ev)}"
        elif cat == "clarify":
            clarify["tot"] += 1
            good = bool(pred.needs_clarification) and len(p_ev) == 0
            clarify["ok"] += int(good)
            flag = "asked OK" if good else "did NOT ask"
        else:
            flag = f"events {len(pairs)}/{len(g_ev)} matched" + (f", +{e_fp} extra" if e_fp else "")
            if e_fp or e_fn:  # capture what the model actually produced, to diagnose
                mismatches.append({
                    "id": rec["id"],
                    "gold": [(e.get("start"), e.get("title")) for e in g_ev],
                    "pred": [(e.get("start"), e.get("title")) for e in p_ev],
                })
        rows.append((rec["id"], cat, ("ok" if valid else "BADJSON") + " | " + flag))

    n = len(records)
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec_ = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * rec_ / (prec + rec_) if (prec + rec_) else 0.0

    print("\nPer-example:")
    for rid, cat, info in rows:
        # !s so non-string ids/categories format instead of raising ValueError.
        print(f" {rid!s:4} {cat!s:9} {info}")
    if mismatches:
        print("\nMismatch detail (start, title) — gold vs what the model produced:")
        for mm in mismatches:
            print(f" [{mm['id']}]")
            print(f" gold: {mm['gold']}")
            print(f" pred: {mm['pred']}")

    summary = {
        "model": LABEL,
        "n_examples": n,
        "n_request_errors": n_errors,
        "schema_validity": _ratio(n_valid, n),  # None (not a crash) on an empty file
        "event_precision": round(prec, 3),
        "event_recall_start_exact": round(rec_, 3),
        "event_f1": round(f1, 3),
        "end_exact_rate": _ratio(end_ok, end_tot),
        "title_similarity_avg": _ratio(sum(title_sims), len(title_sims)),
        "location_recall": _ratio(loc_ok, loc_tot),
        "location_similarity_avg": _ratio(sum(loc_sims), len(loc_sims)),
        "reminder_match_rate": _ratio(rem_ok, rem_tot),
        "no_event_accuracy": _ratio(no_event["ok"], no_event["tot"]),
        "clarification_recall": _ratio(clarify["ok"], clarify["tot"]),
        "events_tp_fp_fn": [tp, fp, fn],
    }
    print("\n==================== SUMMARY ====================")
    for k, v in summary.items():
        print(f" {k:28s} {v}")
    print("RESULTS_JSON:", json.dumps(summary))


if __name__ == "__main__":
    main()