File size: 11,640 Bytes
0366d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""Score the scheduling fine-tune on the held-out set (training/data/eval.jsonl).

Reuses the deployed contract exactly: server.agent.SYSTEM + build_messages for the
prompt, server.schema.ActionPlan for parsing, and the same OpenAI-compatible
`{INFERENCE_BASE_URL}/chat/completions` call with a json_schema response_format.

Metrics (task-specific — generic LLM benchmarks don't apply here):
  - schema validity        : model output parses + validates as ActionPlan
  - event precision/recall/F1 (matched on exact start datetime, minute precision)
  - start-exact recall     : did it nail the date AND time (incl. relative dates)
  - end-exact rate         : on matched events where gold has an end
  - title similarity       : difflib ratio on matched events (soft)
  - location recall/sim    : on matched events where gold has a location
  - reminder match rate    : on matched events where gold sets reminder_minutes.
      NOTE: only eval_unstructured.jsonl golds encode the type rules (medical 60 /
      party 30 / carpool-school 45); the structured eval.jsonl golds predate them,
      so its reminder number is informational only.
  - no-event accuracy      : on chitchat, did it correctly return zero events
  - clarification recall   : on ambiguous threads, did it ask instead of inventing

Usage (needs a model serving an OpenAI-compatible endpoint):
  INFERENCE_BASE_URL=http://127.0.0.1:8080/v1 python training/eval.py

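A/B arms (same metrics, same gold):
  PREDICTOR=stub python training/eval.py     # arm A: the regex stub, no model/HTTP
  VISION=1 python training/eval.py           # arm C: the thread is fed ONLY as a
      rendered screenshot (training/data/screenshots/<eval-stem>/<id>.png — make
      them with training/render_screenshots.py); needs a vision-enabled server.
  EVAL_PATH=training/data/eval_unstructured.jsonl ...   # the non-formulaic set

Other knobs: INFERENCE_MODEL (model name sent to the server), MODEL_LABEL (name
shown in the report), SCREENSHOT_DIR (screenshot root; the eval file's stem is
appended), TITLE_POLISH=1 (run the server's second title-rewrite pass; text mode
only), MINIMAL_PROMPT=1 (drop the system prompt to measure internalized knowledge).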
"""
from __future__ import annotations

import base64
import json
import os
import re
import sys
from datetime import datetime
from difflib import SequenceMatcher
from pathlib import Path

import requests

# Eval threads contain emoji; printing them to a console whose encoding cannot
# represent them (e.g. Windows cp1252) would fail, so switch stdout to UTF-8 and
# replace anything still unencodable. The hasattr guard covers a replaced stdout
# object that has no reconfigure().
if hasattr(sys.stdout, "reconfigure"):  # emoji in threads vs Windows cp1252 console
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")

# Repo root = two directories above this file. It is put first on sys.path so the
# `server` package imported next resolves when the script is run directly.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from server.agent import (  # noqa: E402  (light import; model is lazy)
    TITLES_SCHEMA,
    apply_text_rules,
    build_messages,
    build_title_messages,
    merge_titles,
)
from server.schema import ActionPlan  # noqa: E402

# OpenAI-compatible server endpoint and the model name sent in each request.
BASE = os.environ.get("INFERENCE_BASE_URL", "http://127.0.0.1:8080/v1").rstrip("/")
MODEL = os.environ.get("INFERENCE_MODEL", "local")

# Arm selection flags.
PREDICTOR = os.environ.get("PREDICTOR", "model")  # "model" (HTTP) | "stub" (regex)
VISION = os.environ.get("VISION") == "1"  # feed the thread as a screenshot only
TITLE_POLISH = os.environ.get("TITLE_POLISH") == "1"  # second pass rewriting titles

# Report label: an explicit MODEL_LABEL wins; otherwise it is derived from the arm.
if PREDICTOR == "stub":
    _DEFAULT_LABEL = "stub-regex"
elif VISION:
    _DEFAULT_LABEL = f"{MODEL}+vision"
else:
    _DEFAULT_LABEL = MODEL
LABEL = os.environ.get("MODEL_LABEL", _DEFAULT_LABEL)

# Eval data. Screenshots live in a subdirectory named after the eval file's stem.
_DATA_DIR = ROOT / "training" / "data"
EVAL_PATH = Path(os.environ.get("EVAL_PATH", _DATA_DIR / "eval.jsonl"))
SCREENSHOT_DIR = Path(os.environ.get("SCREENSHOT_DIR", _DATA_DIR / "screenshots")) / EVAL_PATH.stem

# Placeholder thread text used when the conversation is only in the image.
VISION_THREAD = "(the conversation is in the attached screenshot)"
# JSON schema passed as response_format when call_model gets no explicit schema.
SCHEMA = ActionPlan.model_json_schema()


def call_model(messages: list[dict], schema: dict | None = None, name: str = "ActionPlan",
               max_tokens: int = 1024) -> str:
    """POST one schema-constrained chat completion and return the message content.

    Args:
        messages: OpenAI-format chat messages (``role``/``content`` dicts).
        schema: JSON schema sent as ``response_format.json_schema.schema``.
            ``None`` (the default) means the ActionPlan schema (``SCHEMA``).
        name: schema name reported to the server.
        max_tokens: generation cap for this call.

    Returns:
        ``choices[0].message.content`` of the response. This is expected to be
        a JSON string, but a server may return ``null`` content. Callers
        should tolerate ``None``.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status.
        requests.RequestException: connection failure or the 180 s timeout.
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.0,  # greedy -> reproducible eval
        "max_tokens": max_tokens,
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": name,
                # `is None` rather than `or`: an explicitly passed schema is
                # always honoured, even a falsy one such as {}.
                "schema": SCHEMA if schema is None else schema,
                "strict": True,
            },
        },
        "stream": False,
    }
    r = requests.post(f"{BASE}/chat/completions", json=payload, timeout=180)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]


def parse_plan(raw: str | None) -> tuple[ActionPlan, bool]:
    """Parse model output into an ``ActionPlan``; return ``(plan, schema_valid)``.

    The whole text is first decoded as JSON and validated. If that fails, the
    span from the first ``{`` to the last ``}`` is tried instead, which
    tolerates prose or code fences around the object. If both fail, the
    result is an empty ``ActionPlan()`` with ``schema_valid=False``.

    ``raw`` may be ``None`` when the server returns ``"content": null``. That
    is scored as invalid output instead of raising TypeError from the
    regex fallback.
    """
    try:
        return ActionPlan(**json.loads(raw)), True
    except Exception:  # noqa: BLE001
        pass
    if not isinstance(raw, str):  # nothing textual to salvage (e.g. None content)
        return ActionPlan(), False
    m = re.search(r"\{.*\}", raw, re.DOTALL)
    if m:
        try:
            return ActionPlan(**json.loads(m.group(0))), True
        except Exception:  # noqa: BLE001
            pass
    return ActionPlan(), False


def dt_key(s):
    """Normalise a start/end value to a minute-precision ``YYYY-MM-DDTHH:MM`` key.

    Accepts ISO-8601 strings (as stored in the gold JSONL) and ``datetime``
    objects (which a pydantic ``model_dump()`` may return for datetime
    fields). Both therefore compare on the same textual form. Seconds are
    dropped, and the wall-clock time is kept as written, with any UTC offset
    ignored.

    Returns ``None`` for falsy input. A string that does not parse falls back
    to its first 16 characters.
    """
    if not s:
        return None
    if isinstance(s, datetime):
        # fromisoformat() rejects datetime objects. The old str()[:16] fallback
        # produced "YYYY-MM-DD HH:MM" (space), which never matched a gold key.
        return s.strftime("%Y-%m-%dT%H:%M")
    try:
        return datetime.fromisoformat(s).strftime("%Y-%m-%dT%H:%M")
    except (TypeError, ValueError):
        return str(s)[:16]


def match_events(gold, pred):
    """Pair each gold event with the first unused prediction sharing its start minute.

    Matching is greedy and one-to-one, keyed on ``dt_key(start)``. Returns
    ``(matched_pairs, n_fp, n_fn)``, where ``n_fp`` is the number of
    predictions left unmatched and ``n_fn`` the number of gold events that
    found no partner.
    """
    remaining = list(pred)
    matched = []
    misses = 0
    for gold_ev in gold:
        wanted = dt_key(gold_ev.get("start"))
        for idx, cand in enumerate(remaining):
            if dt_key(cand.get("start")) == wanted:
                matched.append((gold_ev, remaining.pop(idx)))
                break
        else:  # no prediction left with this start minute
            misses += 1
    return matched, len(remaining), misses


def sim(a, b):
    """Return the case-insensitive difflib ratio of two optional strings, rounded to 2 dp.

    ``None`` counts as the empty string, so two missing values score 1.0.
    """
    left = (a or "").lower()
    right = (b or "").lower()
    ratio = SequenceMatcher(None, left, right).ratio()
    return round(ratio, 2)


def _screenshot_uri(rec_id: str) -> str:
    """Return ``SCREENSHOT_DIR/<rec_id>.png`` as a base64 ``data:`` URI.

    A missing screenshot surfaces as ``FileNotFoundError`` from ``read_bytes``.
    """
    png_bytes = (SCREENSHOT_DIR / f"{rec_id}.png").read_bytes()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded


def predict(rec: dict, now: datetime):
    """Produce the configured arm's prediction for one eval record.

    Returns ``(plan, schema_valid)``.

    * ``PREDICTOR=stub``: the server agent with the regex stub extractor,
      always reported as schema-valid.
    * Otherwise: one schema-constrained completion over the text thread, or,
      when ``VISION``, over the screenshot alone. In text mode this is followed
      by the optional title-polish pass and the server's deterministic text
      rules.
    """
    if PREDICTOR == "stub":
        os.environ["USE_STUB_EXTRACTOR"] = "1"
        from server.agent import run_agent  # noqa: PLC0415  lazy, like the server

        # memory_block="" keeps the eval from writing the server-side memory file
        stub_plan = run_agent(rec["thread"], now=now, memory_block="")
        return stub_plan, True

    if VISION:
        images = [_screenshot_uri(rec["id"])]
        thread = VISION_THREAD
    else:
        images = None
        thread = rec["thread"]
    messages = build_messages(thread, now, [], images)

    # MINIMAL_PROMPT=1 removes the system prompt. This measures how much task
    # knowledge is INTERNALIZED (fine-tune) versus prompted in (stock). The user
    # content (datetime + thread + "Return the ActionPlan JSON now.") is unchanged.
    if os.environ.get("MINIMAL_PROMPT") == "1":
        messages = [msg for msg in messages if msg["role"] != "system"]

    plan, valid = parse_plan(call_model(messages))

    # Both remaining passes read the thread text, which VISION withholds.
    if VISION:
        return plan, valid

    # TITLE_POLISH=1: the same second pass the server runs, rewriting titles only.
    if TITLE_POLISH and plan.events:
        current_events = [ev.model_dump() for ev in plan.events]
        titles_raw = call_model(
            build_title_messages(thread, current_events),
            schema=TITLES_SCHEMA, name="Titles", max_tokens=256,
        )
        plan = merge_titles(plan, titles_raw)

    # Same deterministic logistics post-pass the server applies (arrival shift +
    # reminder rules).
    return apply_text_rules(thread, plan), valid


def _ratio(num: float, den: float, ndigits: int = 3) -> float | None:
    """Return ``num / den`` rounded to *ndigits*, or ``None`` when *den* is zero."""
    return round(num / den, ndigits) if den else None


def _load_records(path: Path) -> list[dict]:
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    lines = path.read_text(encoding="utf-8").splitlines()
    return [json.loads(line) for line in lines if line.strip()]


def _print_details(rows: list[tuple], mismatches: list[dict]) -> None:
    """Print the per-example table and, if any, the gold-vs-pred mismatch detail."""
    print("\nPer-example:")
    for rid, cat, info in rows:
        # str(): ids double as file names and may not be strings in every eval file.
        print(f"  {str(rid):4s} {cat:9s} {info}")

    if mismatches:
        print("\nMismatch detail (start, title) — gold vs what the model produced:")
        for mm in mismatches:
            print(f"  [{mm['id']}]")
            print(f"     gold: {mm['gold']}")
            print(f"     pred: {mm['pred']}")


def main():
    """Score the configured arm on every record in ``EVAL_PATH`` and print a report.

    For each record, runs ``predict``, matches predicted events to gold on the
    exact start minute, and accumulates the metrics listed in the module
    docstring. The report has four parts:

    * a per-example table;
    * gold-vs-pred detail for extraction examples with missed or extra events;
    * a summary table;
    * a final ``RESULTS_JSON:`` line holding the summary as JSON.

    A record whose prediction raises (HTTP error, timeout, missing screenshot,
    ...) is shown as ``ERR`` and scored as an empty, schema-invalid plan.
    Its gold events therefore still count as false negatives, and it stays in
    the no-event/clarify denominators. ``n_request_errors`` reports how many
    records failed this way.
    """
    records = _load_records(EVAL_PATH)
    where = "(local regex stub)" if PREDICTOR == "stub" else f"@ {BASE}"
    print(f"Scoring {LABEL} on {len(records)} held-out examples {where}"
          f"{' [vision: screenshots only]' if VISION else ''}\n")

    n_valid = n_errors = 0
    tp = fp = fn = 0
    end_ok = end_tot = 0
    loc_ok = loc_tot = 0
    rem_ok = rem_tot = 0
    title_sims = []
    loc_sims = []
    no_event = {"ok": 0, "tot": 0}
    clarify = {"ok": 0, "tot": 0}
    rows = []
    mismatches = []  # gold-vs-pred event detail for failing extraction examples

    for rec in records:
        now = datetime.fromisoformat(rec["now"])
        try:
            pred, valid = predict(rec, now)
            status = "ok" if valid else "BADJSON"
        except Exception as e:  # noqa: BLE001
            print(f"  [{rec['id']}] request failed: {str(e)[:120]}")
            # Score the failure instead of skipping it. Skipping would drop this
            # record's gold events from fn and its category from the no-event /
            # clarify totals, inflating recall and accuracy for arms that fail
            # on hard examples. schema_validity already counts it as invalid.
            pred, valid, status = ActionPlan(), False, "ERR"
            n_errors += 1

        n_valid += int(valid)
        gold = rec["gold"]
        g_ev, p_ev = gold["events"], [e.model_dump() for e in pred.events]

        pairs, e_fp, e_fn = match_events(g_ev, p_ev)
        tp += len(pairs)
        fp += e_fp
        fn += e_fn
        for g, p in pairs:
            title_sims.append(sim(g.get("title"), p.get("title")))
            if g.get("end"):
                end_tot += 1
                end_ok += int(dt_key(g["end"]) == dt_key(p.get("end")))
            if g.get("location"):
                loc_tot += 1
                s = sim(g.get("location"), p.get("location"))
                loc_sims.append(s)
                loc_ok += int(bool(p.get("location")) and s >= 0.5)
            if g.get("reminder_minutes") is not None:
                rem_tot += 1
                rem_ok += int(p.get("reminder_minutes") == g["reminder_minutes"])

        cat = rec["category"]
        if cat == "no_event":
            no_event["tot"] += 1
            good = len(p_ev) == 0
            no_event["ok"] += int(good)
            flag = "no-event OK" if good else f"HALLUCINATED {len(p_ev)}"
        elif cat == "clarify":
            clarify["tot"] += 1
            good = bool(pred.needs_clarification) and len(p_ev) == 0
            clarify["ok"] += int(good)
            flag = "asked OK" if good else "did NOT ask"
        else:
            flag = f"events {len(pairs)}/{len(g_ev)} matched" + (f", +{e_fp} extra" if e_fp else "")
            if e_fp or e_fn:  # capture what the model actually produced, to diagnose
                mismatches.append({
                    "id": rec["id"],
                    "gold": [(e.get("start"), e.get("title")) for e in g_ev],
                    "pred": [(e.get("start"), e.get("title")) for e in p_ev],
                })
        rows.append((rec["id"], cat, f"{status} | {flag}"))

    n = len(records)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    _print_details(rows, mismatches)

    summary = {
        "model": LABEL,
        "n_examples": n,
        "n_request_errors": n_errors,
        "schema_validity": _ratio(n_valid, n),  # None (not a crash) on an empty eval file
        "event_precision": round(precision, 3),
        "event_recall_start_exact": round(recall, 3),
        "event_f1": round(f1, 3),
        "end_exact_rate": _ratio(end_ok, end_tot),
        "title_similarity_avg": _ratio(sum(title_sims), len(title_sims)),
        "location_recall": _ratio(loc_ok, loc_tot),
        "location_similarity_avg": _ratio(sum(loc_sims), len(loc_sims)),
        "reminder_match_rate": _ratio(rem_ok, rem_tot),
        "no_event_accuracy": _ratio(no_event["ok"], no_event["tot"]),
        "clarification_recall": _ratio(clarify["ok"], clarify["tot"]),
        "events_tp_fp_fn": [tp, fp, fn],
    }
    print("\n==================== SUMMARY ====================")
    for k, v in summary.items():
        print(f"  {k:28s} {v}")
    print("RESULTS_JSON:", json.dumps(summary))


if __name__ == "__main__":
    main()