File size: 10,917 Bytes
05e1a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""Eval the trained tiny-vocab physics MoE on the 30 bench scenarios.

Builds a greedy `gen(prompt, max_tokens, stop)` over the trained model +
tokenizer, then a `rollout` mirroring physics_core.rollout (autoregressive
next-frame prediction, scored as mean position error % of scene diagonal vs a
fresh Pymunk ground truth). Prompts use the SAME reduced serialization the
model was trained on (reduced header + reduced frames, with no trailing
instruction text); the model continues by emitting the next `Frame N:` block.

Reports @15-frame and @80-frame mean %-diag, split trained vs held-out.
"""
from __future__ import annotations
import argparse, json, os, sys, time
from pathlib import Path

import torch

# Put the sibling "../scaffold" directory (relative to this file) first on
# sys.path so the project-local imports below can resolve; presumably model,
# physics_core and physics_serialize live there — confirm.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_HERE, "..", "scaffold"))

from model import MoEModel  # noqa: E402
from tokenizers import Tokenizer  # noqa: E402

import physics_core as pc  # noqa: E402
import physics_serialize as psz  # noqa: E402

# Scenario types, matched against each scenario file's stem in main(), that are
# aggregated under "held_out"; every other scenario is aggregated as "trained".
# NOTE(review): presumably these scenes were excluded from training — confirm.
HELD_OUT = {"pong", "bowling", "ramp_roll", "angry_birds", "hourglass", "newtons_cradle"}


class TinyLLM:
    """Greedy decoder wrapper exposing .gen() over the trained MoE.

    Loads a checkpoint dict holding ``"cfg"`` (keyword arguments for
    ``MoEModelConfig``) and ``"model"`` (a state dict), plus a ``tokenizers``
    tokenizer file in which the ``<bos>`` and ``<eos>`` tokens are looked up.
    """

    def __init__(self, ckpt_path, tokenizer_path, device="cuda"):
        ck = torch.load(ckpt_path, map_location="cpu")
        # Rebuild the model config from the dict saved alongside the weights.
        from model import MoEModelConfig
        cfg = MoEModelConfig(**ck["cfg"])
        self.cfg = cfg
        self.model = MoEModel(cfg)
        self.model.load_state_dict(ck["model"])
        self.model.to(device).eval()
        self.tok = Tokenizer.from_file(tokenizer_path)
        self.device = device
        self.eos_id = self.tok.token_to_id("<eos>")
        # Prompts are encoded without automatic specials (we are continuing
        # text), so <bos> is prepended by hand in _encode_prompt.
        self._bos = self.tok.token_to_id("<bos>")

    def _encode_prompt(self, text):
        """Return ``[<bos>] + ids(text)``; no <eos> is appended (we continue)."""
        ids = self.tok.encode(text, add_special_tokens=False).ids
        return [self._bos] + ids

    def n_tokens(self, text):
        """Number of prompt tokens ``gen`` would feed for ``text`` (incl. <bos>)."""
        return len(self._encode_prompt(text))

    @property
    def max_len(self):
        """Model context length (``cfg.max_seq_len``)."""
        return self.cfg.max_seq_len

    @torch.no_grad()
    def gen(self, prompt, max_tokens=256, stop=None):
        """Greedily decode up to ``max_tokens`` tokens continuing ``prompt``.

        Decoding stops early on ``<eos>``, as soon as any string in ``stop``
        appears in the decoded output (the stop string is kept in the returned
        text), or when the sequence reaches ``cfg.max_seq_len``.

        Returns the decoded generated text (the prompt is not included).
        """
        ids = self._encode_prompt(prompt)
        # Hard safety: never feed more than (max_seq_len - max_tokens - 1)
        # prompt tokens (the RoPE cache is sized to max_seq_len). Keep <bos>
        # plus the most-recent tokens, and always keep at least <bos>. The tail
        # is sliced with an explicit start index because `ids[-0:]` would
        # return the whole list when only <bos> fits.
        cap = self.cfg.max_seq_len - max_tokens - 1
        if len(ids) > cap:
            n_tail = max(cap - 1, 0)
            ids = [ids[0]] + ids[len(ids) - n_tail:]
        stop = stop or []
        # fp16 autocast only when running on CUDA; on other devices the old
        # torch.cuda.amp.autocast had no effect either.
        use_amp = torch.device(self.device).type == "cuda"
        out_ids = []
        gen_text = ""
        cur = torch.tensor([ids], dtype=torch.long, device=self.device)
        for _ in range(max_tokens):
            # Forward without labels returns full logits as out[0].
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                logits = self.model(cur)[0]
            nxt = int(logits[0, -1].float().argmax().item())
            if nxt == self.eos_id:
                break
            out_ids.append(nxt)
            step = torch.tensor([[nxt]], dtype=torch.long, device=self.device)
            cur = torch.cat([cur, step], dim=1)
            gen_text = _safe_decode(self.tok, out_ids)
            if any(s in gen_text for s in stop):
                break
            if cur.shape[1] >= self.cfg.max_seq_len:
                break
        # gen_text is always the decode of the current out_ids.
        return gen_text


def _safe_decode(tok, ids):
    """ByteLevel decode that tolerates invalid-UTF8 token sequences (the model
    can emit byte tokens that don't form valid UTF-8 mid-generation)."""
    try:
        return tok.decode(ids)
    except Exception:
        out = []
        for i in ids:
            try:
                out.append(tok.decode([i]))
            except Exception:
                pass
        return "".join(out)


def build_prompt(header, frames, next_idx=None):
    """Serialize ``header`` followed by ``frames`` exactly as in training.

    The model was trained on packed ``header + frame1 + frame2 + ...`` text
    with no instruction text and no priming, so the prompt is pure
    autoregressive context: the model emits the next ``Frame N: ...`` block on
    its own. Appending an instruction such as "Predict next frame:" (never
    seen in training) derails it into emitting garbage/headers.

    ``next_idx`` is accepted for call-site compatibility and is not used.
    """
    pieces = [psz.fmt_header_reduced(header)]
    pieces.extend(psz.fmt_frame_reduced(frame) for frame in frames)
    return "".join(pieces)


def _mean_position_error(pred_objs, gt_frame):
    """Mean Euclidean position error of ``pred_objs`` vs one ground-truth frame.

    ``pred_objs`` maps object id -> object dict with ``position["x"/"y"]``.
    Only ids present in ``gt_frame["objects"]`` are scored; returns None when
    no id overlaps.
    """
    gt_pos = {o["id"]: o["position"] for o in gt_frame["objects"]}
    errs = []
    for oid, o in pred_objs.items():
        if oid in gt_pos:
            dx = gt_pos[oid]["x"] - o["position"]["x"]
            dy = gt_pos[oid]["y"] - o["position"]["y"]
            errs.append((dx * dx + dy * dy) ** 0.5)
    return (sum(errs) / len(errs)) if errs else None


def rollout(llm, scenario, n_frames, max_seconds=900.0):
    """Roll ``llm`` forward ``n_frames`` frames and score it against Pymunk.

    Starting from ``scenario["initial_frames"]``, each step prompts ``llm``
    with the header plus as many trailing frames as fit in the context, parses
    the generated frame, and compares object positions to the ground truth
    produced by ``pc.pymunk_rollout`` from the last initial frame. Objects the
    model fails to emit are carried over from the previous frame.

    Scenes whose single-frame prompt plus generation budget exceed the model's
    context are marked ``fittable=False`` and rolled out for at most 8 frames.
    The loop also stops once ``max_seconds`` of wall time have elapsed.

    Returns a dict with ``n_obj``, ``diag`` (scene-diagonal length),
    ``mean_dist`` (mean of scored per-frame errors, or None), ``fittable``,
    ``frames_done``, ``pct_diag`` (``mean_dist`` as % of ``diag``, or None)
    and ``per_frame`` records ``{"frame", "modeled", "mean_dist"}``.
    """
    header = scenario["header"]
    initial = scenario.get("initial_frames") or []
    n_obj = (header.get("object_count") or len(header.get("objects", []))
             or (len(initial[0]["objects"]) if initial else 0))
    x0, x1, y0, y1 = pc.scene_bounds(header)
    diag = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5

    # Ground truth is simulated from the last initial frame; without initial
    # frames there is nothing to simulate from, so every frame stays unscored
    # (previously this crashed on initial[-1]).
    gt_frames = pc.pymunk_rollout(header, initial[-1], int(n_frames)) if initial else []
    gt_by_frame = {f["frame"]: f for f in gt_frames}

    rolled = list(initial)
    last_idx = initial[-1]["frame"] if initial else 0
    per_frame = []

    # Generation budget for one frame: ~48 tok/object (measured, rounded up
    # to 50) + slack, clamped to [64, 900].
    gen_budget = min(900, max(64, n_obj * 50 + 32))
    ctx_cap = llm.max_len - gen_budget - 8

    # Fittability: a single frame (header ~91 tok + ~50 tok/obj) plus the
    # generation budget must fit in the model context.
    one_frame_tok = 91 + n_obj * 50 + 10
    fittable = (one_frame_tok + gen_budget) <= llm.max_len
    # Unfittable scenes can't roll out properly at this ctx; cap them to a few
    # frames (enough to record fit=False) instead of burning the full budget.
    eff_frames = int(n_frames) if fittable else min(int(n_frames), 8)
    t0 = time.monotonic()  # monotonic: immune to wall-clock adjustments

    for _ in range(eff_frames):
        if time.monotonic() - t0 > max_seconds:
            break
        next_idx = last_idx + 1
        # Drop the oldest frames until <bos>+header+frames leaves room for the
        # generation budget (RoPE cache cap); always keep the newest frame.
        keep = len(rolled)
        prompt = build_prompt(header, rolled[-keep:])
        while keep > 1 and llm.n_tokens(prompt) > ctx_cap:
            keep -= 1
            prompt = build_prompt(header, rolled[-keep:])
        # Pure continuation: the model emits "Frame {next_idx}: <desc>\n obj...".
        # Stop at any of the next few frame headers AFTER the one we want.
        stops = [f"Frame {next_idx + d}:" for d in range(1, 4)]
        text = llm.gen(prompt, max_tokens=gen_budget, stop=stops)
        parsed = pc.parse_frame(pc.split_first_frame(text), n_obj)
        if not parsed:
            parsed = pc.parse_frame(text, n_obj)
        # NOTE(review): parsed is assumed to be an id -> object mapping (it is
        # passed to dict()); guard against a None/empty result.
        modeled = len(parsed) if parsed else 0
        prev_objs = {o["id"]: o for o in rolled[-1]["objects"]} if rolled else {}
        new_objs = dict(parsed) if parsed else dict(prev_objs)
        if modeled < n_obj:
            # Carry over objects the model did not emit this frame.
            for oid, o in prev_objs.items():
                new_objs.setdefault(oid, o)
        last_idx = next_idx
        rolled.append({"frame": last_idx, "description": "Frame %d: " % last_idx,
                       "objects": list(new_objs.values())})

        gt = gt_by_frame.get(last_idx)
        per_frame.append({"frame": last_idx, "modeled": modeled,
                          "mean_dist": _mean_position_error(new_objs, gt) if gt else None})

    valid = [p["mean_dist"] for p in per_frame if p["mean_dist"] is not None]
    mean_dist = (sum(valid) / len(valid)) if valid else None
    # `is not None` so a perfect 0.0 error reports 0.0%, not None; a
    # degenerate (zero) diagonal yields None instead of ZeroDivisionError.
    pct_diag = (mean_dist / diag * 100.0) if (mean_dist is not None and diag) else None
    return {"n_obj": n_obj, "diag": diag, "mean_dist": mean_dist,
            "fittable": fittable, "frames_done": len(per_frame),
            "pct_diag": pct_diag,
            "per_frame": per_frame}


def pct_at(res, k):
    """Mean position error over the first ``k`` frames, as % of scene diagonal.

    ``res`` is a ``rollout`` result. Frames whose ``mean_dist`` is None are
    skipped. Returns None when no frame in the window was scored or when the
    scene diagonal is zero; the zero case would otherwise raise
    ZeroDivisionError.
    """
    diag = res["diag"]
    vals = [p["mean_dist"] for p in res["per_frame"][:k] if p["mean_dist"] is not None]
    if not vals or not diag:
        return None
    return sum(vals) / len(vals) / diag * 100.0


def main():
    """CLI entry point: evaluate every scenario file and write a JSON summary.

    Rolls out each ``*.jsonl`` file in ``--scenarios`` with ``rollout``,
    prints per-scene @15/@80-frame %-diag, and writes trained/held-out
    aggregates (all scenes and fittable-only) plus per-scene rows to ``--out``.
    ``--device`` selects the torch device the model runs on (default "cuda").
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="ckpts/best.pt")
    ap.add_argument("--tokenizer", default="tokenizer.json")
    ap.add_argument("--scenarios", default="scenarios")
    ap.add_argument("--frames", type=int, default=80)
    ap.add_argument("--max-scene-seconds", type=float, default=600.0)
    ap.add_argument("--out", default="EVAL_RESULTS.json")
    ap.add_argument("--device", default="cuda",
                    help="torch device for the model (e.g. cuda, cuda:1, cpu)")
    args = ap.parse_args()

    llm = TinyLLM(args.ckpt, args.tokenizer, device=args.device)
    scen_files = sorted(Path(args.scenarios).glob("*.jsonl"))
    rows = []
    t0 = time.time()
    for sf in scen_files:
        if sf.name.startswith("._"):
            continue  # skip macOS AppleDouble resource-fork junk
        stype = sf.stem
        try:
            scenario = pc.load_scenario(sf)
        except (UnicodeDecodeError, ValueError, KeyError) as e:
            print(f"[eval] {stype:16s} SKIP (unreadable: {e})", flush=True)
            continue
        res = rollout(llm, scenario, args.frames, max_seconds=args.max_scene_seconds)
        p15 = pct_at(res, 15)
        p80 = pct_at(res, 80)
        held = stype in HELD_OUT
        rows.append({"type": stype, "held_out": held, "p15": p15, "p80": p80,
                     "n_obj": res["n_obj"], "fittable": res["fittable"],
                     "frames_done": res["frames_done"]})
        print(f"[eval] {stype:16s} held={held} fit={res['fittable']} "
              f"@15f={p15 if p15 is None else round(p15,3)}% "
              f"@80f={p80 if p80 is None else round(p80,3)}% "
              f"({res['frames_done']}f, {round(time.time()-t0,0)}s cum)", flush=True)

    def agg(key, held, fit_only=False):
        # Mean of rows[key] over one split, skipping None (unscored) values;
        # with fit_only, restrict to scenes whose rollout was fittable.
        vals = [r[key] for r in rows if r["held_out"] == held
                and r[key] is not None and (not fit_only or r["fittable"])]
        return (sum(vals) / len(vals)) if vals else None

    summary = {
        "trained": {"p15": agg("p15", False), "p80": agg("p80", False)},
        "held_out": {"p15": agg("p15", True), "p80": agg("p80", True)},
        "trained_fittable": {"p15": agg("p15", False, True), "p80": agg("p80", False, True)},
        "held_out_fittable": {"p15": agg("p15", True, True), "p80": agg("p80", True, True)},
        "n_trained": sum(1 for r in rows if not r["held_out"]),
        "n_held_out": sum(1 for r in rows if r["held_out"]),
        "n_fittable": sum(1 for r in rows if r["fittable"]),
        "orbit_p80": next((r["p80"] for r in rows if r["type"] == "orbit"), None),
        "rows": rows,
        "elapsed_s": round(time.time() - t0, 1),
    }
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print("\n=== SUMMARY ===")
    print(json.dumps({k: v for k, v in summary.items() if k != "rows"}, indent=2))


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()