"""Score the scheduling fine-tune on the held-out set (training/data/eval.jsonl).
Reuses the deployed contract exactly: server.agent.SYSTEM + build_messages for the
prompt, server.schema.ActionPlan for parsing, and the same OpenAI-compatible
`{INFERENCE_BASE_URL}/chat/completions` call with a json_schema response_format.
Metrics (task-specific — generic LLM benchmarks don't apply here):
- schema validity : model output parses + validates as ActionPlan
- event precision/recall/F1 (matched on exact start datetime, minute precision)
- start-exact recall : did it nail the date AND time (incl. relative dates)
- end-exact rate : on matched events where gold has an end
- title similarity : difflib ratio on matched events (soft)
- location recall/sim : on matched events where gold has a location
- reminder match rate : on matched events where gold sets reminder_minutes.
NOTE: only eval_unstructured.jsonl golds encode the type rules (medical 60 /
party 30 / carpool-school 45); the structured eval.jsonl golds predate them,
so its reminder number is informational only.
- no-event accuracy : on chitchat, did it correctly return zero events
- clarification recall : on ambiguous threads, did it ask instead of inventing
Usage (needs a model serving an OpenAI-compatible endpoint):
INFERENCE_BASE_URL=http://127.0.0.1:8080/v1 python training/eval.py
A/B arms (same metrics, same gold):
PREDICTOR=stub python training/eval.py # arm A: the regex stub, no model/HTTP
VISION=1 python training/eval.py # arm C: the thread is fed ONLY as a
rendered screenshot (training/data/screenshots/<eval-stem>/<id>.png — make
them with training/render_screenshots.py); needs a vision-enabled server.
EVAL_PATH=training/data/eval_unstructured.jsonl ... # the non-formulaic set
"""
from __future__ import annotations
import base64
import json
import os
import re
import sys
from datetime import datetime
from difflib import SequenceMatcher
from pathlib import Path
import requests
# Eval threads contain emoji; a Windows cp1252 console cannot encode them, so
# switch stdout to UTF-8 (substituting anything still unencodable) when the
# stream supports reconfiguration.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(errors="replace", encoding="utf-8")

# Repository root = two directories above this file. It is put first on
# sys.path so the `server` package imports below resolve when this script is
# run directly.
_HERE = Path(__file__).resolve()
ROOT = _HERE.parent.parent
sys.path.insert(0, str(ROOT))
from server.agent import ( # noqa: E402 (light import; model is lazy)
TITLES_SCHEMA,
apply_text_rules,
build_messages,
build_title_messages,
merge_titles,
)
from server.schema import ActionPlan # noqa: E402
# --- Run configuration; every value can be overridden via environment. ---
# OpenAI-compatible endpoint root (trailing "/" stripped) and model name.
BASE = os.environ.get("INFERENCE_BASE_URL", "http://127.0.0.1:8080/v1").rstrip("/")
MODEL = os.environ.get("INFERENCE_MODEL", "local")
# Which arm to score: "model" (HTTP endpoint) or "stub" (regex extractor).
PREDICTOR = os.environ.get("PREDICTOR", "model")
# Boolean switches are on only for the exact value "1".
VISION = os.environ.get("VISION") == "1"  # thread is sent only as a screenshot
TITLE_POLISH = os.environ.get("TITLE_POLISH") == "1"  # second, title-rewriting pass

# Label printed in the report; MODEL_LABEL overrides the derived default.
if PREDICTOR == "stub":
    _DEFAULT_LABEL = "stub-regex"
elif VISION:
    _DEFAULT_LABEL = f"{MODEL}+vision"
else:
    _DEFAULT_LABEL = MODEL
LABEL = os.environ.get("MODEL_LABEL", _DEFAULT_LABEL)

_DATA_DIR = ROOT / "training" / "data"
EVAL_PATH = Path(os.environ.get("EVAL_PATH", _DATA_DIR / "eval.jsonl"))
# Screenshots are looked up in a subdirectory named after the eval file's stem.
SCREENSHOT_DIR = Path(os.environ.get("SCREENSHOT_DIR", _DATA_DIR / "screenshots")) / EVAL_PATH.stem
# Placeholder thread text used in vision mode, where the real thread is only
# present in the attached image.
VISION_THREAD = "(the conversation is in the attached screenshot)"
# JSON schema of the deployed ActionPlan; call_model's default response_format schema.
SCHEMA = ActionPlan.model_json_schema()
def call_model(messages: list[dict], schema: dict | None = None, name: str = "ActionPlan",
               max_tokens: int = 1024, timeout: float = 180) -> str:
    """POST one chat completion to the OpenAI-compatible endpoint; return its text.

    Args:
        messages: Chat messages in OpenAI format.
        schema: JSON schema enforced through ``response_format`` (strict mode).
            ``None`` (the default) means the ActionPlan schema ``SCHEMA``; any
            schema passed explicitly is used as-is.
        name: Schema name reported inside ``response_format``.
        max_tokens: Generation cap for the reply.
        timeout: Seconds to wait for the HTTP response (previously fixed at 180).

    Returns:
        ``choices[0].message.content`` from the response body, unmodified.
        NOTE(review): a server could return null content; callers should not
        assume a non-empty string.

    Raises:
        requests.HTTPError: on a non-2xx status.
        requests.RequestException: on connection failure or timeout.
        KeyError / IndexError: if the response body lacks the expected fields.
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.0,  # greedy -> reproducible eval
        "max_tokens": max_tokens,
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": name,
                # `is None` rather than `or`, so an explicitly passed (even
                # empty) schema is never silently swapped for ActionPlan's.
                "schema": SCHEMA if schema is None else schema,
                "strict": True,
            },
        },
        "stream": False,
    }
    r = requests.post(f"{BASE}/chat/completions", json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]
def parse_plan(raw: str | None):
    """Parse a model reply into ``(ActionPlan, schema_valid)``.

    The whole reply is tried as JSON first. If that fails, the greedy
    ``{...}`` span (first ``{`` to last ``}``) is tried, which tolerates prose
    or code fences around the JSON. Any failure, including a missing
    (``None``) reply, yields an empty ``ActionPlan()`` with
    ``schema_valid=False`` rather than raising.
    """
    if raw is None:
        # Null message content: json.loads' TypeError would be swallowed below,
        # but re.search would then raise TypeError and abort the example.
        return ActionPlan(), False
    try:
        return ActionPlan(**json.loads(raw)), True
    except Exception:  # noqa: BLE001 -- best-effort parse; failures are scored, not raised
        pass
    m = re.search(r"\{.*\}", raw, re.DOTALL)
    if m:
        try:
            return ActionPlan(**json.loads(m.group(0))), True
        except Exception:  # noqa: BLE001
            pass
    return ActionPlan(), False
def dt_key(s):
    """Normalize a datetime-ish value to a ``YYYY-MM-DDTHH:MM`` comparison key.

    Falsy input gives ``None``. Values ``datetime.fromisoformat`` accepts are
    formatted to minute precision; any UTC offset is dropped, not converted.
    Anything else falls back to the first 16 characters of ``str(s)``.
    """
    if not s:
        return None
    try:
        moment = datetime.fromisoformat(s)
        return moment.strftime("%Y-%m-%dT%H:%M")
    except Exception:  # noqa: BLE001 -- best-effort key; never raise while scoring
        return str(s)[:16]


def match_events(gold, pred):
    """Pair gold and predicted events one-to-one on exact start minute.

    Greedy in gold order: each gold event claims the earliest still-unclaimed
    prediction whose ``dt_key(start)`` equals its own. Returns
    ``(matched_pairs, n_fp, n_fn)``:

    * ``matched_pairs`` is a list of ``(gold_event, pred_event)`` tuples;
    * ``n_fp`` counts predictions left unclaimed;
    * ``n_fn`` counts gold events that found no partner.

    The input lists are not modified.
    """
    unmatched = list(pred)
    matched = []
    misses = 0
    for gold_event in gold:
        want = dt_key(gold_event.get("start"))
        for idx, candidate in enumerate(unmatched):
            if dt_key(candidate.get("start")) == want:
                matched.append((gold_event, unmatched.pop(idx)))
                break
        else:
            misses += 1
    return matched, len(unmatched), misses
def sim(a, b):
    """Case-insensitive difflib similarity of two optional strings (2 decimals).

    ``None`` is treated as the empty string; two empty inputs score 1.0.
    """
    left = (a or "").lower()
    right = (b or "").lower()
    ratio = SequenceMatcher(None, left, right).ratio()
    return round(ratio, 2)
def _screenshot_uri(rec_id: str) -> str:
    """Return ``SCREENSHOT_DIR/<rec_id>.png`` as a base64 ``data:`` URI.

    A missing file raises ``FileNotFoundError``. The screenshots are produced
    beforehand by training/render_screenshots.py.
    """
    image_path = SCREENSHOT_DIR / f"{rec_id}.png"
    encoded = base64.b64encode(image_path.read_bytes()).decode()
    return "data:image/png;base64," + encoded
def predict(rec: dict, now: datetime):
    """Run the configured arm on one eval record -> ``(ActionPlan, schema_valid)``.

    * ``PREDICTOR == "stub"``: delegate to ``server.agent.run_agent`` with the
      stub extractor switched on. The result is always reported schema-valid.
      VISION, MINIMAL_PROMPT and TITLE_POLISH are not consulted on this path.
    * Otherwise, build the deployed prompt (thread text, or a placeholder plus
      the screenshot when VISION), call the model and parse the reply. In text
      mode only, optionally polish titles (TITLE_POLISH) and then apply the
      server's deterministic text rules.
    """
    if PREDICTOR == "stub":
        os.environ["USE_STUB_EXTRACTOR"] = "1"  # set before the lazy import below
        from server.agent import run_agent  # noqa: PLC0415 lazy, like the server
        # memory_block="" keeps the eval from writing the server-side memory file
        stub_plan = run_agent(rec["thread"], now=now, memory_block="")
        return stub_plan, True

    if VISION:
        thread, images = VISION_THREAD, [_screenshot_uri(rec["id"])]
    else:
        thread, images = rec["thread"], None
    messages = build_messages(thread, now, [], images)

    # MINIMAL_PROMPT=1: drop the system prompt entirely — measures how much task
    # knowledge is INTERNALIZED (fine-tune) vs prompted-in (stock). User content
    # (datetime + thread + "Return the ActionPlan JSON now.") stays identical.
    if os.environ.get("MINIMAL_PROMPT") == "1":
        messages = [msg for msg in messages if msg["role"] != "system"]

    plan, valid = parse_plan(call_model(messages))
    if VISION:
        # Both post-passes below read the thread text, which vision mode withholds.
        return plan, valid

    # TITLE_POLISH=1: same second pass the server runs — rewrite titles only.
    if TITLE_POLISH and plan.events:
        event_dicts = [event.model_dump() for event in plan.events]
        titles_raw = call_model(
            build_title_messages(rec["thread"], event_dicts),
            schema=TITLES_SCHEMA,
            name="Titles",
            max_tokens=256,
        )
        plan = merge_titles(plan, titles_raw)

    # Same deterministic logistics post-pass the server applies (arrival shift +
    # reminder rules).
    plan = apply_text_rules(rec["thread"], plan)
    return plan, valid
def _rate(ok, total):
    """Return ``ok / total`` rounded to 3 decimals, or ``None`` when ``total`` is 0."""
    return round(ok / total, 3) if total else None


def main():
    """Score the configured arm on every record of EVAL_PATH and print a report.

    Each JSONL record is read for ``id``, ``category``, ``now`` (ISO datetime
    handed to the predictor) and ``gold["events"]``. The report has four parts:
    a per-example table, gold-vs-pred detail for extraction examples with
    mismatches, a summary table, and one ``RESULTS_JSON:`` line.

    An example whose prediction raises is reported as ``ERR`` and scored as a
    total miss. Its gold events count as false negatives, and a no_event /
    clarify example counts against that accuracy. This matches
    ``schema_validity``, which already divides by all examples, so an arm that
    crashes on hard examples cannot score better than one that answers wrongly.
    ``n_errors`` in the summary reports how many examples failed this way.
    """
    records = [
        json.loads(line)
        for line in EVAL_PATH.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    where = "(local regex stub)" if PREDICTOR == "stub" else f"@ {BASE}"
    # predict() returns on the stub path before any screenshot is built, so
    # only advertise vision mode when the model arm actually uses it.
    vision_tag = " [vision: screenshots only]" if VISION and PREDICTOR != "stub" else ""
    print(f"Scoring {LABEL} on {len(records)} held-out examples {where}{vision_tag}\n")

    n_valid = 0
    n_errors = 0  # examples whose prediction raised (reported as ERR)
    tp = fp = fn = 0
    end_ok = end_tot = 0
    loc_ok = loc_tot = 0
    rem_ok = rem_tot = 0
    title_sims = []
    loc_sims = []
    no_event = {"ok": 0, "tot": 0}
    clarify = {"ok": 0, "tot": 0}
    rows = []
    mismatches = []  # gold-vs-pred event detail for failing extraction examples

    for rec in records:
        now = datetime.fromisoformat(rec["now"])
        cat = rec["category"]
        g_ev = rec["gold"]["events"]
        try:
            pred, valid = predict(rec, now)
        except Exception as e:  # noqa: BLE001
            print(f" [{rec['id']}] request failed: {str(e)[:120]}")
            rows.append((rec["id"], cat, "ERR"))
            n_errors += 1
            # Total miss: every gold event is a false negative, and the example
            # still enters its category's denominator.
            fn += len(g_ev)
            if cat == "no_event":
                no_event["tot"] += 1
            elif cat == "clarify":
                clarify["tot"] += 1
            continue

        n_valid += int(valid)
        p_ev = [e.model_dump() for e in pred.events]
        pairs, e_fp, e_fn = match_events(g_ev, p_ev)
        tp += len(pairs)
        fp += e_fp
        fn += e_fn

        # Field-level checks only on start-matched pairs, and only for fields
        # the gold event actually sets.
        for g, p in pairs:
            title_sims.append(sim(g.get("title"), p.get("title")))
            if g.get("end"):
                end_tot += 1
                end_ok += int(dt_key(g["end"]) == dt_key(p.get("end")))
            if g.get("location"):
                loc_tot += 1
                s = sim(g.get("location"), p.get("location"))
                loc_sims.append(s)
                loc_ok += int(bool(p.get("location")) and s >= 0.5)
            if g.get("reminder_minutes") is not None:
                rem_tot += 1
                rem_ok += int(p.get("reminder_minutes") == g["reminder_minutes"])

        if cat == "no_event":
            no_event["tot"] += 1
            good = len(p_ev) == 0
            no_event["ok"] += int(good)
            flag = "no-event OK" if good else f"HALLUCINATED {len(p_ev)}"
        elif cat == "clarify":
            clarify["tot"] += 1
            good = bool(pred.needs_clarification) and len(p_ev) == 0
            clarify["ok"] += int(good)
            flag = "asked OK" if good else "did NOT ask"
        else:
            flag = f"events {len(pairs)}/{len(g_ev)} matched" + (f", +{e_fp} extra" if e_fp else "")
            if e_fp or e_fn:  # capture what the model actually produced, to diagnose
                mismatches.append({
                    "id": rec["id"],
                    "gold": [(e.get("start"), e.get("title")) for e in g_ev],
                    "pred": [(e.get("start"), e.get("title")) for e in p_ev],
                })
        rows.append((rec["id"], cat, ("ok" if valid else "BADJSON") + " | " + flag))

    n = len(records)
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * recall / (prec + recall) if (prec + recall) else 0.0

    print("\nPer-example:")
    for rid, cat, info in rows:
        print(f" {rid:4s} {cat:9s} {info}")
    if mismatches:
        print("\nMismatch detail (start, title) — gold vs what the model produced:")
        for mm in mismatches:
            print(f" [{mm['id']}]")
            print(f" gold: {mm['gold']}")
            print(f" pred: {mm['pred']}")

    summary = {
        "model": LABEL,
        "n_examples": n,
        "n_errors": n_errors,
        # _rate yields None instead of raising ZeroDivisionError on an empty eval file.
        "schema_validity": _rate(n_valid, n),
        "event_precision": round(prec, 3),
        "event_recall_start_exact": round(recall, 3),
        "event_f1": round(f1, 3),
        "end_exact_rate": _rate(end_ok, end_tot),
        "title_similarity_avg": _rate(sum(title_sims), len(title_sims)),
        "location_recall": _rate(loc_ok, loc_tot),
        "location_similarity_avg": _rate(sum(loc_sims), len(loc_sims)),
        "reminder_match_rate": _rate(rem_ok, rem_tot),
        "no_event_accuracy": _rate(no_event["ok"], no_event["tot"]),
        "clarification_recall": _rate(clarify["ok"], clarify["tot"]),
        "events_tp_fp_fn": [tp, fp, fn],
    }
    print("\n==================== SUMMARY ====================")
    for k, v in summary.items():
        print(f" {k:28s} {v}")
    print("RESULTS_JSON:", json.dumps(summary))


if __name__ == "__main__":
    main()
|