Instructions to use AlexWortega/moe100m-physics-tinybpe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AlexWortega/moe100m-physics-tinybpe with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("AlexWortega/moe100m-physics-tinybpe", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Eval the trained tiny-vocab physics MoE on the 30 bench scenarios. | |
| Builds a greedy `gen(prompt, max_tokens, stop)` over the trained model + | |
| tokenizer, then a `rollout` mirroring physics_core.rollout (autoregressive | |
| next-frame prediction, scored as mean position error % of scene diagonal vs a | |
| fresh Pymunk ground truth). Prompts use the SAME reduced serialization the | |
| model was trained on (reduced header + reduced frames, with NO instruction text: | |
| the model continues with the next "Frame N:" block on its own). | |
| Reports @15-frame and @80-frame mean %-diag, split trained vs held-out. | |
| """ | |
| from __future__ import annotations | |
| import argparse, json, os, sys, time | |
| from pathlib import Path | |
| import torch | |
| _HERE = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.insert(0, os.path.join(_HERE, "..", "scaffold")) | |
| from model import MoEModel # noqa: E402 | |
| from tokenizers import Tokenizer # noqa: E402 | |
| import physics_core as pc # noqa: E402 | |
| import physics_serialize as psz # noqa: E402 | |
| HELD_OUT = {"pong", "bowling", "ramp_roll", "angry_birds", "hourglass", "newtons_cradle"} | |
class TinyLLM:
    """Greedy decoder wrapper exposing .gen() over the trained MoE.

    Holds the model, its config and the tokenizer, and offers helpers to
    count prompt tokens (`n_tokens`), query the context length (`max_len()`)
    and generate a greedy continuation (`gen`).
    """

    def __init__(self, ckpt_path, tokenizer_path, device="cuda"):
        """Load the checkpoint and tokenizer and move the model to `device`.

        Args:
            ckpt_path: torch checkpoint containing ``"cfg"`` (keyword
                arguments for ``MoEModelConfig``) and ``"model"`` (state dict).
            tokenizer_path: path to a `tokenizers` JSON file that defines the
                ``<bos>`` and ``<eos>`` tokens.
            device: torch device string used for the model and its inputs.
        """
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        ck = torch.load(ckpt_path, map_location="cpu")
        # Rebuild the config object from the dict saved in the checkpoint.
        from model import MoEModelConfig
        cfg = MoEModelConfig(**ck["cfg"])
        self.cfg = cfg
        self.model = MoEModel(cfg)
        self.model.load_state_dict(ck["model"])
        self.model.to(device).eval()
        self.tok = Tokenizer.from_file(tokenizer_path)
        self.device = device
        self.eos_id = self.tok.token_to_id("<eos>")
        # Prompts are encoded raw (no automatic bos/eos) and <bos> is
        # prepended by hand, because generation continues the prompt.
        self._bos = self.tok.token_to_id("<bos>")

    def _encode_prompt(self, text):
        """Return ``[<bos>] + ids(text)``; no ``<eos>`` since we continue."""
        ids = self.tok.encode(text, add_special_tokens=False).ids
        return [self._bos] + ids

    def n_tokens(self, text):
        """Return the prompt length in tokens, counting the leading <bos>."""
        return 1 + len(self.tok.encode(text, add_special_tokens=False).ids)

    def max_len(self):
        """Return the model's maximum sequence length (``cfg.max_seq_len``)."""
        return self.cfg.max_seq_len

    def _fit_prompt(self, ids, max_tokens):
        """Clip `ids` so prompt + `max_tokens` new tokens fit the context.

        Keeps the first id (<bos>) plus the most recent ids. The cap is
        clamped to at least 1 so that a generation budget close to (or above)
        ``max_seq_len`` shortens the prompt to just <bos>; an unclamped cap of
        1 or less would turn the tail slice into ``ids[-0:]`` or a
        positive-start slice and return a prompt that is not shortened.
        """
        cap = max(1, self.cfg.max_seq_len - max_tokens - 1)
        if len(ids) <= cap:
            return ids
        tail = ids[-(cap - 1):] if cap > 1 else []
        return [ids[0]] + tail

    def gen(self, prompt, max_tokens=256, stop=None):
        """Greedily generate a continuation of `prompt`.

        Args:
            prompt: text to continue.
            max_tokens: maximum number of new tokens to generate.
            stop: optional list of strings; generation ends as soon as the
                decoded continuation contains any of them. The check runs
                after the token is appended, so the matched text is part of
                the returned string.

        Returns:
            The decoded continuation (prompt excluded). Generation also ends
            on ``<eos>`` (not included) or when the sequence reaches
            ``cfg.max_seq_len``.
        """
        # Hard safety: never feed more than (max_len - max_tokens) prompt
        # tokens (per the original note, the RoPE cache is sized to
        # max_seq_len).
        ids = self._fit_prompt(self._encode_prompt(prompt), max_tokens)
        stop = stop or []
        out_ids = []
        cur = torch.tensor([ids], dtype=torch.long, device=self.device)
        # fp16 autocast only applies on CUDA; elsewhere run in native precision.
        use_amp = torch.device(self.device).type == "cuda"
        # Inference only: do not record autograd state for every step.
        with torch.no_grad():
            for _ in range(max_tokens):
                # Forward without labels returns the full logits as out[0].
                with torch.autocast(device_type="cuda", dtype=torch.float16,
                                    enabled=use_amp):
                    logits = self.model(cur)[0]
                nxt = int(logits[0, -1].float().argmax().item())
                if nxt == self.eos_id:
                    break
                out_ids.append(nxt)
                cur = torch.cat(
                    [cur, torch.tensor([[nxt]], device=self.device)], dim=1)
                gen_text = _safe_decode(self.tok, out_ids)
                if any(s in gen_text for s in stop):
                    break
                if cur.shape[1] >= self.cfg.max_seq_len:
                    break
        return _safe_decode(self.tok, out_ids)
def _safe_decode(tok, ids):
    """Decode `ids` with `tok`, tolerating sequences that fail to decode.

    The whole sequence is decoded in one call when possible. If that raises
    (per the original note, the model can emit byte tokens that do not form
    valid UTF-8 mid-generation), each id is decoded on its own, ids that
    still fail are dropped, and the surviving pieces are concatenated.
    """
    try:
        text = tok.decode(ids)
    except Exception:
        # Best-effort fallback: salvage whatever decodes token by token.
        pieces = []
        for token_id in ids:
            try:
                piece = tok.decode([token_id])
            except Exception:
                continue
            pieces.append(piece)
        text = "".join(pieces)
    return text
def build_prompt(header, frames, next_idx=None):
    """Serialize `header` and `frames` into the continuation prompt.

    The text is the reduced header followed by every reduced frame, and
    nothing else. Per the original note this MUST match the training
    serialization exactly: packed `header + frame1 + frame2 + ...` with no
    instruction text and no priming, so the model emits the next
    `Frame N: ...` block on its own (an appended "Predict next frame:" was
    never seen in training and derails generation). `next_idx` is accepted
    but not used.
    """
    pieces = [psz.fmt_header_reduced(header)]
    for frame in frames:
        pieces.append(psz.fmt_frame_reduced(frame))
    return "".join(pieces)
def rollout(llm, scenario, n_frames, max_seconds=900.0):
    """Roll a scenario forward frame by frame with `llm` and score it.

    Each step builds a prompt from the header plus as many trailing frames as
    fit the context, asks `llm.gen` for the next frame, parses it, and
    compares object positions with a fresh `pc.pymunk_rollout` ground truth.

    Args:
        llm: object exposing `gen(prompt, max_tokens, stop)`, `n_tokens(text)`
            and `max_len` (either a method or a plain integer attribute).
        scenario: dict with ``"header"`` and a non-empty ``"initial_frames"``.
        n_frames: number of frames to predict.
        max_seconds: wall-clock budget; the loop stops once it is exceeded.

    Returns:
        dict with ``n_obj``, ``diag``, ``mean_dist``, ``fittable``,
        ``frames_done``, ``pct_diag`` and the ``per_frame`` records.

    Raises:
        ValueError: if the scenario has no initial frames to seed from.
    """
    import time as _t

    header = scenario["header"]
    initial = scenario.get("initial_frames") or []
    if not initial:
        # The ground truth and the first prompt both need a seed frame.
        raise ValueError("scenario has no initial_frames; cannot seed the rollout")
    n_obj = (header.get("object_count") or len(header.get("objects", []))
             or len(initial[0]["objects"]))
    x0, x1, y0, y1 = pc.scene_bounds(header)
    diag = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
    gt_frames = pc.pymunk_rollout(header, initial[-1], int(n_frames))
    gt_by_frame = {f["frame"]: f for f in gt_frames}
    rolled = list(initial)
    last_idx = initial[-1]["frame"]
    per_frame = []

    # `max_len` is a method on TinyLLM; accept a plain attribute as well so
    # other wrappers keep working.
    max_len = llm.max_len() if callable(llm.max_len) else llm.max_len

    # Generation budget for one frame: 48 tok/object (measured) + slack.
    gen_budget = min(900, max(64, n_obj * 50 + 32))
    ctx_cap = max_len - gen_budget - 8
    # Fittability: a single frame (header ~91 + 48 tok/obj) must leave room.
    one_frame_tok = 91 + n_obj * 50 + 10
    fittable = (one_frame_tok + gen_budget) <= max_len
    # Unfittable scenes can't roll out properly at this ctx; cap them to a few
    # frames (enough to record fit=False) instead of burning the full budget.
    eff_frames = int(n_frames) if fittable else min(int(n_frames), 8)

    t0 = _t.monotonic()
    for _ in range(eff_frames):
        if _t.monotonic() - t0 > max_seconds:
            break
        # Drop the oldest frames until <bos>+header+frames fits in the context
        # with room left for the generation budget.
        next_idx = last_idx + 1
        keep = len(rolled)
        prompt = build_prompt(header, rolled[-keep:])
        while keep > 1 and llm.n_tokens(prompt) > ctx_cap:
            keep -= 1
            prompt = build_prompt(header, rolled[-keep:])
        # Pure continuation: the model emits "Frame {next_idx}: <desc>\n obj...".
        # Stop at the frame AFTER the one we want.
        stops = [f"Frame {next_idx+d}:" for d in range(1, 4)]
        text = llm.gen(prompt, max_tokens=gen_budget, stop=stops)
        parsed = pc.parse_frame(pc.split_first_frame(text), n_obj)
        if not parsed:
            parsed = pc.parse_frame(text, n_obj)
        # Tolerate a parser that returns None as well as an empty result.
        modeled = len(parsed) if parsed else 0
        prev_objs = {o["id"]: o for o in rolled[-1]["objects"]}
        # Nothing parsed: carry the previous frame forward unchanged.
        new_objs = dict(parsed) if parsed else dict(prev_objs)
        if modeled < n_obj:
            # Fill objects the model skipped with their previous state.
            for oid, o in prev_objs.items():
                new_objs.setdefault(oid, o)
        last_idx = next_idx
        rolled.append({"frame": last_idx, "description": "Frame %d: " % last_idx,
                       "objects": list(new_objs.values())})
        gt = gt_by_frame.get(last_idx)
        rec = {"frame": last_idx, "modeled": modeled, "mean_dist": None}
        if gt:
            gt_pos = {o["id"]: o["position"] for o in gt["objects"]}
            errs = []
            for oid, o in new_objs.items():
                if oid in gt_pos:
                    dx = gt_pos[oid]["x"] - o["position"]["x"]
                    dy = gt_pos[oid]["y"] - o["position"]["y"]
                    errs.append((dx * dx + dy * dy) ** 0.5)
            if errs:
                rec["mean_dist"] = sum(errs) / len(errs)
        per_frame.append(rec)

    valid = [p for p in per_frame if p["mean_dist"] is not None]
    mean_dist = (sum(p["mean_dist"] for p in valid) / len(valid)) if valid else None
    # `is not None` so a perfect 0.0 error is reported as 0.0, not None; a
    # degenerate zero diagonal yields None instead of dividing by zero.
    pct_diag = (mean_dist / diag * 100.0) if (mean_dist is not None and diag) else None
    return {"n_obj": n_obj, "diag": diag, "mean_dist": mean_dist,
            "fittable": fittable, "frames_done": len(per_frame),
            "pct_diag": pct_diag,
            "per_frame": per_frame}
def pct_at(res, k):
    """Return the mean position error over the first `k` frames, as % of diag.

    Only records whose ``mean_dist`` is not None count towards the mean.
    Returns None when none of the first `k` records has a distance.
    """
    window = res["per_frame"][:k]
    dists = [d for d in (rec["mean_dist"] for rec in window) if d is not None]
    if dists:
        return sum(dists) / len(dists) / res["diag"] * 100.0
    return None
def main():
    """CLI entry point: evaluate every scenario file and write a JSON summary.

    Loads the model, runs `rollout` on each ``*.jsonl`` file in the scenarios
    directory, prints one line per scenario, then writes the aggregated
    summary (split trained vs held-out, and fittable-only) to ``--out`` and
    prints it without the per-scenario rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="ckpts/best.pt")
    ap.add_argument("--tokenizer", default="tokenizer.json")
    ap.add_argument("--scenarios", default="scenarios")
    ap.add_argument("--frames", type=int, default=80)
    ap.add_argument("--max-scene-seconds", type=float, default=600.0)
    ap.add_argument("--out", default="EVAL_RESULTS.json")
    # The default matches the previously hard-coded device; pass e.g.
    # `--device cpu` to evaluate without a GPU.
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    llm = TinyLLM(args.ckpt, args.tokenizer, device=args.device)
    scen_files = sorted(Path(args.scenarios).glob("*.jsonl"))
    rows = []
    t0 = time.time()
    for sf in scen_files:
        if sf.name.startswith("._"):
            continue  # skip macOS AppleDouble resource-fork junk
        stype = sf.stem
        try:
            scenario = pc.load_scenario(sf)
        except (UnicodeDecodeError, ValueError, KeyError) as e:
            print(f"[eval] {stype:16s} SKIP (unreadable: {e})", flush=True)
            continue
        res = rollout(llm, scenario, args.frames, max_seconds=args.max_scene_seconds)
        p15 = pct_at(res, 15)
        p80 = pct_at(res, 80)
        held = stype in HELD_OUT
        rows.append({"type": stype, "held_out": held, "p15": p15, "p80": p80,
                     "n_obj": res["n_obj"], "fittable": res["fittable"],
                     "frames_done": res["frames_done"]})
        print(f"[eval] {stype:16s} held={held} fit={res['fittable']} "
              f"@15f={p15 if p15 is None else round(p15,3)}% "
              f"@80f={p80 if p80 is None else round(p80,3)}% "
              f"({res['frames_done']}f, {round(time.time()-t0,0)}s cum)", flush=True)

    def agg(key, held, fit_only=False):
        """Mean of `key` over rows in the given split, skipping None values."""
        vals = [r[key] for r in rows if r["held_out"] == held
                and r[key] is not None and (not fit_only or r["fittable"])]
        return (sum(vals) / len(vals)) if vals else None

    summary = {
        "trained": {"p15": agg("p15", False), "p80": agg("p80", False)},
        "held_out": {"p15": agg("p15", True), "p80": agg("p80", True)},
        "trained_fittable": {"p15": agg("p15", False, True), "p80": agg("p80", False, True)},
        "held_out_fittable": {"p15": agg("p15", True, True), "p80": agg("p80", True, True)},
        "n_trained": sum(1 for r in rows if not r["held_out"]),
        "n_held_out": sum(1 for r in rows if r["held_out"]),
        "n_fittable": sum(1 for r in rows if r["fittable"]),
        "orbit_p80": next((r["p80"] for r in rows if r["type"] == "orbit"), None),
        "rows": rows,
        "elapsed_s": round(time.time() - t0, 1),
    }
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print("\n=== SUMMARY ===")
    print(json.dumps({k: v for k, v in summary.items() if k != "rows"}, indent=2))


if __name__ == "__main__":
    main()