Instructions to use AlexWortega/moe100m-physics-tinybpe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AlexWortega/moe100m-physics-tinybpe with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("AlexWortega/moe100m-physics-tinybpe", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 10,917 Bytes
05e1a70
"""Eval the trained tiny-vocab physics MoE on the 30 bench scenarios.
Builds a greedy `gen(prompt, max_tokens, stop)` over the trained model +
tokenizer, then a `rollout` mirroring physics_core.rollout (autoregressive
next-frame prediction, scored as mean position error as a % of the scene
diagonal vs. a fresh Pymunk ground truth). Prompts use the SAME reduced
serialization the model was trained on (reduced header + reduced frames, as a
pure continuation with no instruction text such as "Predict next frame:").
Reports @15-frame and @80-frame mean %-diag, split trained vs held-out.
"""
from __future__ import annotations
import argparse, json, os, sys, time
from pathlib import Path
import torch
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_HERE, "..", "scaffold"))
from model import MoEModel # noqa: E402
from tokenizers import Tokenizer # noqa: E402
import physics_core as pc # noqa: E402
import physics_serialize as psz # noqa: E402
# Scenario types (matched against scenario file stems in main) whose results
# are reported under the "held_out" split rather than "trained".
HELD_OUT = {
    "angry_birds",
    "bowling",
    "hourglass",
    "newtons_cradle",
    "pong",
    "ramp_roll",
}
class TinyLLM:
    """Greedy decoder wrapper exposing .gen() over the trained MoE.

    The checkpoint at ``ckpt_path`` is a dict holding ``"cfg"`` (keyword
    arguments for ``MoEModelConfig``) and ``"model"`` (the state dict). The
    tokenizer is a ``tokenizers`` JSON file; it is expected to define ``<bos>``
    and ``<eos>`` tokens (``token_to_id`` returns None otherwise).
    """

    def __init__(self, ckpt_path, tokenizer_path, device="cuda"):
        ck = torch.load(ckpt_path, map_location="cpu")
        # Rebuild the model config from the dict saved alongside the weights.
        from model import MoEModelConfig
        cfg = MoEModelConfig(**ck["cfg"])
        self.cfg = cfg
        self.model = MoEModel(cfg)
        self.model.load_state_dict(ck["model"])
        self.model.to(device).eval()
        self.tok = Tokenizer.from_file(tokenizer_path)
        self.device = device
        self.eos_id = self.tok.token_to_id("<eos>")
        # Prompts are encoded without auto bos/eos; <bos> is prepended manually.
        self._bos = self.tok.token_to_id("<bos>")

    def _encode_prompt(self, text):
        """Encode ``text`` as ``[<bos>] + ids`` with NO trailing ``<eos>`` (we are continuing)."""
        ids = self.tok.encode(text, add_special_tokens=False).ids
        return [self._bos] + ids

    def n_tokens(self, text):
        """Number of tokens ``text`` occupies as a prompt, including the leading ``<bos>``."""
        return len(self._encode_prompt(text))

    @property
    def max_len(self):
        """Maximum sequence length of the model (its RoPE cache is sized to this)."""
        return self.cfg.max_seq_len

    @staticmethod
    def _fit_prompt(ids, cap):
        """Trim ``ids`` to at most ``max(cap, 1)`` tokens, keeping ``ids[0]`` (<bos>) plus the most recent tail.

        A naive ``ids[-(cap - 1):]`` keeps EVERYTHING when ``cap <= 1`` (``-0``
        slices from the start), so the tail length is clamped at zero instead;
        a tiny or negative cap leaves just ``<bos>``.
        """
        if len(ids) <= cap:
            return ids
        tail = max(cap - 1, 0)
        return ids[:1] + ids[len(ids) - tail:]

    @torch.no_grad()
    def gen(self, prompt, max_tokens=256, stop=None):
        """Greedily continue ``prompt`` and return the decoded continuation only.

        Generation stops at ``<eos>`` (not included in the output), as soon as
        the decoded continuation contains any string in ``stop`` (that text is
        kept in the output), after ``max_tokens`` tokens, or when the sequence
        reaches ``cfg.max_seq_len``. Over-long prompts are left-trimmed to
        ``<bos>`` + the most recent tokens so prompt + generation fit.
        """
        import contextlib

        # Hard safety: the RoPE cache is sized to max_seq_len, so never feed more
        # than (max_seq_len - max_tokens - 1) prompt tokens.
        ids = self._fit_prompt(self._encode_prompt(prompt),
                               self.cfg.max_seq_len - max_tokens - 1)
        stop = stop or []
        out_ids = []
        cur = torch.tensor([ids], dtype=torch.long, device=self.device)
        # fp16 autocast only applies on CUDA; torch.cuda.amp.autocast is
        # deprecated and merely warns-and-disables itself elsewhere.
        if torch.device(self.device).type == "cuda":
            amp = torch.autocast(device_type="cuda", dtype=torch.float16)
        else:
            amp = contextlib.nullcontext()
        with amp:
            for _ in range(max_tokens):
                # forward without labels returns full logits as out[0]
                logits = self.model(cur)[0]
                nxt = int(logits[0, -1].float().argmax().item())
                if nxt == self.eos_id:
                    break
                out_ids.append(nxt)
                nxt_t = torch.tensor([[nxt]], dtype=torch.long, device=self.device)
                cur = torch.cat([cur, nxt_t], dim=1)
                if stop:
                    # Re-decode the whole continuation: byte-level tokens may only
                    # form a stop string once neighbouring bytes are present.
                    text = _safe_decode(self.tok, out_ids)
                    if any(s in text for s in stop):
                        break
                if cur.shape[1] >= self.cfg.max_seq_len:
                    break
        return _safe_decode(self.tok, out_ids)
def _safe_decode(tok, ids):
"""ByteLevel decode that tolerates invalid-UTF8 token sequences (the model
can emit byte tokens that don't form valid UTF-8 mid-generation)."""
try:
return tok.decode(ids)
except Exception:
out = []
for i in ids:
try:
out.append(tok.decode([i]))
except Exception:
pass
return "".join(out)
def build_prompt(header, frames, next_idx=None):
    """Serialize ``header`` plus ``frames`` exactly as the model saw them in training.

    MUST match the training serialization EXACTLY: the model was trained on
    packed ``header + frame1 + frame2 + ...`` with NO instruction text and NO
    priming, so the prompt is a pure autoregressive continuation and the model
    emits the next ``Frame N: ...`` block on its own. (Appending
    "Predict next frame:" -- never seen in training -- derails it into
    emitting garbage/headers.)

    ``next_idx`` is accepted for call-site compatibility and is not used.
    """
    pieces = [psz.fmt_header_reduced(header)]
    pieces.extend(map(psz.fmt_frame_reduced, frames))
    return "".join(pieces)
def _trim_prompt(llm, header, rolled, ctx_cap):
    """Build a prompt from ``header`` + the most recent frames of ``rolled`` that fit in ``ctx_cap`` tokens.

    Drops the oldest frames one at a time but always keeps at least one frame
    (when any exist), even if that frame alone still exceeds ``ctx_cap``.
    """
    keep = len(rolled)
    prompt = build_prompt(header, rolled[-keep:])
    while keep > 1 and llm.n_tokens(prompt) > ctx_cap:
        keep -= 1
        prompt = build_prompt(header, rolled[-keep:])
    return prompt


def _mean_position_error(pred_objs, gt_frame):
    """Mean Euclidean position error of ``pred_objs`` vs ``gt_frame`` over shared object ids; None if none are shared."""
    gt_pos = {o["id"]: o["position"] for o in gt_frame["objects"]}
    errs = []
    for oid, o in pred_objs.items():
        if oid in gt_pos:
            dx = gt_pos[oid]["x"] - o["position"]["x"]
            dy = gt_pos[oid]["y"] - o["position"]["y"]
            errs.append((dx * dx + dy * dy) ** 0.5)
    return sum(errs) / len(errs) if errs else None


def rollout(llm, scenario, n_frames, max_seconds=900.0):
    """Autoregressively roll ``llm`` forward on ``scenario`` and score it against Pymunk.

    Starting from ``scenario["initial_frames"]``, repeatedly prompts the model
    with the header plus as many recent frames as fit, parses the first
    generated frame, and appends it to the rollout. Objects the model failed
    to emit are carried over from the previous frame. Each generated frame is
    compared to a fresh ``pc.pymunk_rollout`` ground truth.

    Args:
        llm: object exposing ``gen``, ``n_tokens`` and ``max_len`` (e.g. TinyLLM).
        scenario: dict with ``"header"`` and a non-empty ``"initial_frames"``.
        n_frames: frames to roll out (capped at 8 for scenes that don't fit).
        max_seconds: elapsed-time budget; the loop stops early once exceeded.

    Returns:
        dict with ``n_obj``, ``diag`` (scene diagonal), ``mean_dist`` (mean
        error over scored frames, or None), ``fittable``, ``frames_done``,
        ``pct_diag`` (``mean_dist`` as % of ``diag``, or None) and
        ``per_frame`` records ``{"frame", "modeled", "mean_dist"}``.

    Raises:
        ValueError: if the scenario has no initial frames to start from.
    """
    import time as _t

    header = scenario["header"]
    initial = scenario.get("initial_frames") or []
    if not initial:
        # The ground-truth rollout and the prompt both need a starting state.
        raise ValueError("scenario has no initial_frames to roll out from")
    n_obj = (header.get("object_count") or len(header.get("objects", []))
             or len(initial[0]["objects"]))
    x0, x1, y0, y1 = pc.scene_bounds(header)
    diag = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
    gt_frames = pc.pymunk_rollout(header, initial[-1], int(n_frames))
    gt_by_frame = {f["frame"]: f for f in gt_frames}

    # Generation budget for one frame: ~48 tok/object (measured) + slack.
    gen_budget = min(900, max(64, n_obj * 50 + 32))
    # Prompt budget: leave room for the generation plus a small margin.
    ctx_cap = llm.max_len - gen_budget - 8
    # Fittability: a single frame (header ~91 + ~48 tok/obj) must leave room.
    one_frame_tok = 91 + n_obj * 50 + 10
    fittable = (one_frame_tok + gen_budget) <= llm.max_len
    # Unfittable scenes can't roll out properly at this ctx; cap them to a few
    # frames (enough to record fit=False) instead of burning the full budget.
    eff_frames = int(n_frames) if fittable else min(int(n_frames), 8)

    rolled = list(initial)
    last_idx = initial[-1]["frame"]
    per_frame = []
    t0 = _t.monotonic()  # monotonic: unaffected by wall-clock adjustments
    for _ in range(eff_frames):
        if _t.monotonic() - t0 > max_seconds:
            break
        next_idx = last_idx + 1
        prompt = _trim_prompt(llm, header, rolled, ctx_cap)
        # Pure continuation: the model emits "Frame {next_idx}: <desc>\n obj...".
        # Stop as soon as it starts any of the frames AFTER the one we want.
        stops = [f"Frame {next_idx + d}:" for d in range(1, 4)]
        text = llm.gen(prompt, max_tokens=gen_budget, stop=stops)
        parsed = pc.parse_frame(pc.split_first_frame(text), n_obj)
        if not parsed:
            parsed = pc.parse_frame(text, n_obj)
        # NOTE(review): guards a falsy (possibly None) parse result; confirm
        # what pc.parse_frame returns on failure.
        modeled = len(parsed) if parsed else 0
        prev_objs = {o["id"]: o for o in rolled[-1]["objects"]}
        new_objs = dict(parsed) if parsed else dict(prev_objs)
        if modeled < n_obj:
            # Carry over objects the model did not emit from the previous frame.
            for oid, o in prev_objs.items():
                new_objs.setdefault(oid, o)
        last_idx = next_idx
        rolled.append({"frame": last_idx, "description": "Frame %d: " % last_idx,
                       "objects": list(new_objs.values())})
        gt = gt_by_frame.get(last_idx)
        per_frame.append({
            "frame": last_idx,
            "modeled": modeled,
            "mean_dist": _mean_position_error(new_objs, gt) if gt else None,
        })

    scored = [p["mean_dist"] for p in per_frame if p["mean_dist"] is not None]
    mean_dist = sum(scored) / len(scored) if scored else None
    # `is not None` rather than truthiness: a perfect 0.0 error must report
    # 0.0 %, not None. A degenerate zero diagonal yields None instead of
    # ZeroDivisionError.
    pct_diag = (mean_dist / diag * 100.0) if (mean_dist is not None and diag) else None
    return {"n_obj": n_obj, "diag": diag, "mean_dist": mean_dist,
            "fittable": fittable, "frames_done": len(per_frame),
            "pct_diag": pct_diag,
            "per_frame": per_frame}
def pct_at(res, k):
    """Mean position error over the first ``k`` rolled-out frames, as % of the scene diagonal.

    Frames whose ``mean_dist`` is None (no ground truth to score against) are
    skipped; returns None when none of the first ``k`` frames were scored.
    """
    window = res["per_frame"][:k]
    scored = [rec["mean_dist"] for rec in window if rec["mean_dist"] is not None]
    if not scored:
        return None
    mean_err = sum(scored) / len(scored)
    return mean_err / res["diag"] * 100.0
def main():
    """Command-line entry point: evaluate every scenario file and dump a JSON summary.

    For each ``*.jsonl`` under ``--scenarios`` (skipping macOS ``._*`` files
    and files that fail to load), rolls the model out for ``--frames`` frames,
    records mean %-of-diagonal error over the first 15 and 80 frames, prints a
    progress line, and finally writes per-scenario rows plus trained /
    held-out aggregates to ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="ckpts/best.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--scenarios", default="scenarios")
    parser.add_argument("--frames", type=int, default=80)
    parser.add_argument("--max-scene-seconds", type=float, default=600.0)
    parser.add_argument("--out", default="EVAL_RESULTS.json")
    opts = parser.parse_args()

    def r3(value):
        # None is printed as-is; numbers are shown rounded to 3 decimals.
        return value if value is None else round(value, 3)

    llm = TinyLLM(opts.ckpt, opts.tokenizer)
    rows = []
    t_start = time.time()
    for path in sorted(Path(opts.scenarios).glob("*.jsonl")):
        if path.name.startswith("._"):
            continue  # macOS AppleDouble resource-fork junk, not a scenario
        name = path.stem
        try:
            scenario = pc.load_scenario(path)
        except (UnicodeDecodeError, ValueError, KeyError) as err:
            print(f"[eval] {name:16s} SKIP (unreadable: {err})", flush=True)
            continue
        res = rollout(llm, scenario, opts.frames, max_seconds=opts.max_scene_seconds)
        p15, p80 = pct_at(res, 15), pct_at(res, 80)
        held = name in HELD_OUT
        rows.append({"type": name, "held_out": held, "p15": p15, "p80": p80,
                     "n_obj": res["n_obj"], "fittable": res["fittable"],
                     "frames_done": res["frames_done"]})
        elapsed = round(time.time() - t_start, 0)
        print(f"[eval] {name:16s} held={held} fit={res['fittable']} "
              f"@15f={r3(p15)}% @80f={r3(p80)}% "
              f"({res['frames_done']}f, {elapsed}s cum)", flush=True)

    def mean_of(key, held, fit_only=False):
        # Mean of `key` over one split, ignoring unscored (None) entries.
        picked = [row[key] for row in rows
                  if row["held_out"] == held
                  and row[key] is not None
                  and (not fit_only or row["fittable"])]
        if not picked:
            return None
        return sum(picked) / len(picked)

    def split_means(held, fit_only=False):
        return {"p15": mean_of("p15", held, fit_only),
                "p80": mean_of("p80", held, fit_only)}

    summary = {
        "trained": split_means(False),
        "held_out": split_means(True),
        "trained_fittable": split_means(False, True),
        "held_out_fittable": split_means(True, True),
        "n_trained": len([r for r in rows if not r["held_out"]]),
        "n_held_out": len([r for r in rows if r["held_out"]]),
        "n_fittable": len([r for r in rows if r["fittable"]]),
        "orbit_p80": next((r["p80"] for r in rows if r["type"] == "orbit"), None),
        "rows": rows,
        "elapsed_s": round(time.time() - t_start, 1),
    }
    with open(opts.out, "w") as fh:
        json.dump(summary, fh, indent=2)
    print("\n=== SUMMARY ===")
    headline = {k: v for k, v in summary.items() if k != "rows"}
    print(json.dumps(headline, indent=2))
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|