moe100m-physics-tinybpe / eval_phys.py
AlexWortega's picture
Upload eval_phys.py with huggingface_hub
05e1a70 verified
"""Eval the trained tiny-vocab physics MoE on the 30 bench scenarios.
Builds a greedy `gen(prompt, max_tokens, stop)` over the trained model +
tokenizer, then a `rollout` mirroring physics_core.rollout (autoregressive
next-frame prediction, scored as mean position error % of scene diagonal vs a
fresh Pymunk ground truth). Prompts use the SAME reduced serialization the
model was trained on (reduced header + reduced frames, with NO instruction
text appended; the model continues with the next frame on its own).
Reports @15-frame and @80-frame mean %-diag, split trained vs held-out.
"""
from __future__ import annotations
import argparse, json, os, sys, time
from pathlib import Path
import torch
# Put ../scaffold (relative to this file) at the front of sys.path before the
# remaining imports run. Presumably that is where the project-local modules
# imported below live -- confirm against the repo layout.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_HERE, "..", "scaffold"))
from model import MoEModel # noqa: E402
from tokenizers import Tokenizer # noqa: E402
import physics_core as pc # noqa: E402
import physics_serialize as psz # noqa: E402
# Scenario types (matched against the scenario file stem in main()) that are
# reported in the held-out split; every other scenario counts as trained.
HELD_OUT = {"pong", "bowling", "ramp_roll", "angry_birds", "hourglass", "newtons_cradle"}
class TinyLLM:
    """Greedy decoder wrapper exposing .gen() over the trained MoE.

    The checkpoint is a dict with a "cfg" entry (keyword arguments for
    ``MoEModelConfig``) and a "model" entry (the state dict). The tokenizer is
    loaded from a ``tokenizers`` JSON file and must define ``<bos>``; ``<eos>``
    is used as the stop token when present.
    """

    def __init__(self, ckpt_path, tokenizer_path, device="cuda"):
        """Load model weights and tokenizer, and move the model to *device*.

        Raises:
            ValueError: if the tokenizer has no ``<bos>`` token (every prompt
                is prefixed with it, so generation could not work).
        """
        ck = torch.load(ckpt_path, map_location="cpu")
        # Unused in this class; import retained from the original.
        # NOTE(review): drop if config_100m has no import-time side effects.
        from config_100m import make_config  # noqa: F401
        cfgd = ck["cfg"]
        # rebuild cfg from saved dict
        from model import MoEModelConfig
        cfg = MoEModelConfig(**cfgd)
        self.cfg = cfg
        self.model = MoEModel(cfg)
        self.model.load_state_dict(ck["model"])
        self.model.to(device).eval()
        self.tok = Tokenizer.from_file(tokenizer_path)
        self.device = device
        self.eos_id = self.tok.token_to_id("<eos>")
        # raw encode without auto bos/eos for prompt continuation
        self._bos = self.tok.token_to_id("<bos>")
        if self._bos is None:
            raise ValueError("tokenizer has no <bos> token")

    def _encode_prompt(self, text):
        """Return token ids for *text* with <bos> prepended and no <eos>."""
        # prepend <bos>, do NOT append <eos> (we are continuing)
        ids = self.tok.encode(text, add_special_tokens=False).ids
        return [self._bos] + ids

    def n_tokens(self, text):
        """Return the prompt length in tokens, counting the prepended <bos>."""
        return 1 + len(self.tok.encode(text, add_special_tokens=False).ids)

    @property
    def max_len(self):
        """Maximum sequence length from the model config."""
        return self.cfg.max_seq_len

    def _fit_prompt(self, ids, max_tokens):
        """Trim *ids* so that prompt + *max_tokens* fits in max_seq_len.

        Keeps the first id (<bos>) plus the most recent tokens. The cap is
        clamped to at least one token: with an unclamped cap of 1 (or less)
        the tail slice ``ids[-(cap - 1):]`` would select the whole list (or
        an arbitrary suffix) instead of trimming.
        """
        cap = max(self.cfg.max_seq_len - max_tokens - 1, 1)
        if len(ids) <= cap:
            return ids
        tail = cap - 1
        return [ids[0]] + (ids[-tail:] if tail > 0 else [])

    @torch.no_grad()
    def gen(self, prompt, max_tokens=256, stop=None):
        """Greedily decode up to *max_tokens* tokens continuing *prompt*.

        Decoding ends at <eos>, when any string in *stop* appears in the
        decoded output (the stop string itself is included in the returned
        text), or when the sequence reaches max_seq_len.

        Returns:
            The decoded generated text (without the prompt).
        """
        # hard safety: never feed more than (max_len - max_tokens) prompt tokens
        # (RoPE cache is sized to max_seq_len). Keep <bos> + most-recent tokens.
        ids = self._fit_prompt(self._encode_prompt(prompt), max_tokens)
        stop = stop or []
        out_ids = []
        cur = torch.tensor([ids], dtype=torch.long, device=self.device)
        # fp16 autocast only makes sense on CUDA; elsewhere run un-casted.
        use_amp = torch.device(self.device).type == "cuda"
        for _ in range(max_tokens):
            # Checked before the forward pass so an over-long sequence is
            # never fed to the model.
            if cur.shape[1] >= self.cfg.max_seq_len:
                break
            # forward without labels returns full logits as out[0]
            with torch.autocast(device_type="cuda", dtype=torch.float16,
                                enabled=use_amp):
                logits = self.model(cur)[0]
            nxt = int(logits[0, -1].float().argmax().item())
            if nxt == self.eos_id:
                break
            out_ids.append(nxt)
            cur = torch.cat([cur, torch.tensor([[nxt]], device=self.device)], dim=1)
            if stop:
                gen_text = _safe_decode(self.tok, out_ids)
                if any(s in gen_text for s in stop):
                    break
        return _safe_decode(self.tok, out_ids)
def _safe_decode(tok, ids):
    """Decode *ids* with *tok*, tolerating sequences the tokenizer rejects.

    The model can emit byte tokens that do not form valid UTF-8
    mid-generation. If decoding the whole sequence raises, fall back to
    decoding one token at a time and concatenating whatever decodes,
    silently dropping the tokens that do not.
    """
    try:
        return tok.decode(ids)
    except Exception:
        # Best-effort fallback below; the failure itself is not interesting.
        pass

    pieces = []
    for token_id in ids:
        try:
            piece = tok.decode([token_id])
        except Exception:
            continue
        pieces.append(piece)
    return "".join(pieces)
def build_prompt(header, frames, next_idx=None):
    """Serialize *header* followed by *frames* into a continuation prompt.

    The output MUST match the training serialization exactly: the model was
    trained on packed ``header + frame1 + frame2 + ...`` with no instruction
    text and no priming, so the model emits the next ``Frame N: ...`` block on
    its own. Appending "Predict next frame:" (never seen in training) derails
    it into emitting garbage/headers.

    *next_idx* is accepted but not used.
    """
    parts = [psz.fmt_header_reduced(header)]
    parts.extend(psz.fmt_frame_reduced(frame) for frame in frames)
    return "".join(parts)
def rollout(llm, scenario, n_frames, max_seconds=900.0):
    """Roll the model forward frame by frame and score it against Pymunk.

    Each step builds a prompt from the header plus as many trailing frames as
    fit in the context, generates one frame, parses it, and appends it to the
    rolled history. Objects the model failed to emit are carried over from
    the previous frame. The predicted positions are compared with the
    ground-truth frame of the same index.

    Args:
        llm: decoder exposing ``max_len``, ``n_tokens(text)`` and
            ``gen(prompt, max_tokens=..., stop=...)``.
        scenario: dict with a "header" and an "initial_frames" list; each
            frame carries "frame" (index) and "objects".
        n_frames: number of frames to predict.
        max_seconds: time budget for this scenario; the loop stops once it is
            exceeded.

    Returns:
        dict with "n_obj", "diag" (scene diagonal), "mean_dist" (mean position
        error over scored frames, or None), "pct_diag" (mean_dist as a
        percentage of diag, or None), "fittable", "frames_done" and the
        per-frame records in "per_frame".
    """
    from time import monotonic as _now  # monotonic: immune to clock changes

    header = scenario["header"]
    initial = scenario.get("initial_frames") or []
    n_obj = (header.get("object_count") or len(header.get("objects", []))
             or (len(initial[0]["objects"]) if initial else 0))
    x0, x1, y0, y1 = pc.scene_bounds(header)
    diag = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
    # NOTE(review): an empty initial_frames list raises IndexError here even
    # though the rest of the function tolerates it -- confirm scenarios always
    # ship at least one initial frame.
    gt_frames = pc.pymunk_rollout(header, initial[-1], int(n_frames))
    gt_by_frame = {f["frame"]: f for f in gt_frames}
    rolled = list(initial)
    last_idx = initial[-1]["frame"] if initial else 0
    per_frame = []

    # generation budget for one frame: 48 tok/object (measured) + slack.
    gen_budget = min(900, max(64, n_obj * 50 + 32))
    ctx_cap = llm.max_len - gen_budget - 8
    # Fittability: a single frame (header ~91 + 48 tok/obj) must leave room.
    one_frame_tok = 91 + n_obj * 50 + 10
    fittable = (one_frame_tok + gen_budget) <= llm.max_len
    # Unfittable scenes can't roll out properly at this ctx; cap them to a few
    # frames (enough to record fit=False) instead of burning the full budget.
    eff_frames = int(n_frames) if fittable else min(int(n_frames), 8)

    deadline = _now() + max_seconds
    for _ in range(eff_frames):
        if _now() > deadline:
            break
        # Trim leading (oldest) frames so <bos>+header+frames fits in
        # max_seq_len with room for the generation budget (RoPE cache cap).
        next_idx = last_idx + 1
        keep = len(rolled)
        prompt = build_prompt(header, rolled[-keep:])
        while keep > 1 and llm.n_tokens(prompt) > ctx_cap:
            keep -= 1
            prompt = build_prompt(header, rolled[-keep:])
        # Pure continuation: the model emits "Frame {next_idx}: <desc>\n obj...".
        # Stop at the frame AFTER the one we want.
        stops = [f"Frame {next_idx + d}:" for d in range(1, 4)]
        text = llm.gen(prompt, max_tokens=gen_budget, stop=stops)
        parsed = pc.parse_frame(pc.split_first_frame(text), n_obj)
        if not parsed:
            # Fall back to parsing the raw generation.
            parsed = pc.parse_frame(text, n_obj)
        modeled = len(parsed)

        # Presumably parse_frame returns a mapping of object id -> object
        # (it is passed to dict() below) -- confirm against physics_core.
        prev_objs = {o["id"]: o for o in rolled[-1]["objects"]} if rolled else {}
        new_objs = dict(parsed) if parsed else dict(prev_objs)
        if modeled < n_obj:
            # Carry forward any object the model did not emit.
            for oid, o in prev_objs.items():
                new_objs.setdefault(oid, o)

        last_idx = next_idx
        rolled.append({"frame": last_idx, "description": "Frame %d: " % last_idx,
                       "objects": list(new_objs.values())})

        gt = gt_by_frame.get(last_idx)
        rec = {"frame": last_idx, "modeled": modeled, "mean_dist": None}
        if gt:
            gt_pos = {o["id"]: o["position"] for o in gt["objects"]}
            errs = []
            for oid, o in new_objs.items():
                if oid in gt_pos:
                    dx = gt_pos[oid]["x"] - o["position"]["x"]
                    dy = gt_pos[oid]["y"] - o["position"]["y"]
                    errs.append((dx * dx + dy * dy) ** 0.5)
            if errs:
                rec["mean_dist"] = sum(errs) / len(errs)
        per_frame.append(rec)

    valid = [p for p in per_frame if p["mean_dist"] is not None]
    mean_dist = (sum(p["mean_dist"] for p in valid) / len(valid)) if valid else None
    # `is not None`, not truthiness: a perfect rollout has mean_dist == 0.0 and
    # must report 0.0 %, not None.
    pct_diag = (mean_dist / diag * 100.0) if mean_dist is not None else None
    return {"n_obj": n_obj, "diag": diag, "mean_dist": mean_dist,
            "fittable": fittable, "frames_done": len(per_frame),
            "pct_diag": pct_diag,
            "per_frame": per_frame}
def pct_at(res, k):
    """Return the mean error over the first *k* frames as % of scene diagonal.

    Only frames whose "mean_dist" is not None are averaged; returns None when
    none of the first *k* per-frame records has a score.
    """
    window = res["per_frame"][:k]
    dists = [rec["mean_dist"] for rec in window if rec["mean_dist"] is not None]
    if dists:
        return sum(dists) / len(dists) / res["diag"] * 100.0
    return None
def main():
    """Command-line entry point: evaluate every scenario and write a summary.

    Loads the model, rolls it out on each ``*.jsonl`` scenario in the
    scenarios directory, prints one line per scenario, then writes the
    aggregated results (trained vs held-out, all vs fittable-only) as JSON
    to ``--out`` and prints the summary without the per-scenario rows.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="ckpts/best.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--scenarios", default="scenarios")
    parser.add_argument("--frames", type=int, default=80)
    parser.add_argument("--max-scene-seconds", type=float, default=600.0)
    parser.add_argument("--out", default="EVAL_RESULTS.json")
    args = parser.parse_args()

    llm = TinyLLM(args.ckpt, args.tokenizer)
    scen_files = sorted(Path(args.scenarios).glob("*.jsonl"))
    rows = []
    started = time.time()

    for path in scen_files:
        if path.name.startswith("._"):
            continue  # skip macOS AppleDouble resource-fork junk
        stype = path.stem
        try:
            scenario = pc.load_scenario(path)
        except (UnicodeDecodeError, ValueError, KeyError) as e:
            print(f"[eval] {stype:16s} SKIP (unreadable: {e})", flush=True)
            continue

        res = rollout(llm, scenario, args.frames, max_seconds=args.max_scene_seconds)
        p15 = pct_at(res, 15)
        p80 = pct_at(res, 80)
        held = stype in HELD_OUT
        rows.append({"type": stype, "held_out": held, "p15": p15, "p80": p80,
                     "n_obj": res["n_obj"], "fittable": res["fittable"],
                     "frames_done": res["frames_done"]})

        shown15 = None if p15 is None else round(p15, 3)
        shown80 = None if p80 is None else round(p80, 3)
        print(f"[eval] {stype:16s} held={held} fit={res['fittable']} "
              f"@15f={shown15}% "
              f"@80f={shown80}% "
              f"({res['frames_done']}f, {round(time.time()-started,0)}s cum)", flush=True)

    def agg(key, held, fit_only=False):
        """Mean of rows[key] over one split, optionally fittable rows only."""
        picked = []
        for row in rows:
            if row["held_out"] != held:
                continue
            if row[key] is None:
                continue
            if fit_only and not row["fittable"]:
                continue
            picked.append(row[key])
        return (sum(picked) / len(picked)) if picked else None

    # Insertion order below is the key order of the emitted JSON.
    summary = {}
    splits = (
        ("trained", False, False),
        ("held_out", True, False),
        ("trained_fittable", False, True),
        ("held_out_fittable", True, True),
    )
    for label, held, fit_only in splits:
        summary[label] = {"p15": agg("p15", held, fit_only),
                          "p80": agg("p80", held, fit_only)}
    summary["n_trained"] = sum(1 for r in rows if not r["held_out"])
    summary["n_held_out"] = sum(1 for r in rows if r["held_out"])
    summary["n_fittable"] = sum(1 for r in rows if r["fittable"])
    summary["orbit_p80"] = next((r["p80"] for r in rows if r["type"] == "orbit"), None)
    summary["rows"] = rows
    summary["elapsed_s"] = round(time.time() - started, 1)

    with open(args.out, "w") as out_file:
        json.dump(summary, out_file, indent=2)

    print("\n=== SUMMARY ===")
    brief = {name: value for name, value in summary.items() if name != "rows"}
    print(json.dumps(brief, indent=2))
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()