""" ============================================================================= SAM-MM Benchmark — reproducible per-family evaluation for the SAM-MM line SparseMind / AMFORGE ============================================================================= Checkpoint-driven and fully self-contained: the held-out eval set is GENERATED INTERNALLY (disjoint seed 99991), so no external data file or generator script is needed. It renders frames/mel from each sample's spec, greedy-decodes the answer, and reports per-family exact match + a CHAT/ACTION breakdown + aggregate. Notebook (Colab/Kaggle): edit the variables at the top of main() and run. Terminal: python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-reasoning-checkpoints:best.pt --families reasoning python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt --families audio python samg_mm_benchmark.py --ckpt ./best.pt --n 3000 --n-per 100 Self-contained: the SAM-MM model, the renderers, the tokenizer resolver are inlined verbatim. External vision/audio LMs are NOT comparable on these SAM-specific synthetic tasks (different input pipelines), so this is an honest internal per-family report; add a baseline column only where truly comparable. ============================================================================= """ from __future__ import annotations import os, sys, json, math, random, argparse from dataclasses import dataclass, asdict from typing import Optional from enum import IntEnum import torch, torch.nn as nn, torch.nn.functional as F try: import sentencepiece as spm except ImportError: os.system(f"{sys.executable} -m pip install -q sentencepiece --break-system-packages"); import sentencepiece as spm device = torch.device("cuda" if torch.cuda.is_available() else "cpu") BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() TOK_REPO, TOK_FILE = "AMFORGE/samg_mm_tok", "samg_mm_tokenizer.model" DEFAULT_CKPT_REPO = "AMFORGE/sam-mm-reasoning-checkpoints" ORGANIZATION, MODEL_NAME = "AMFORGE", "SAM-MM" def _pip(p): os.system(f"{sys.executable} -m pip install -q {p} --break-system-packages") def get_hf_token(): t = os.environ.get("HF_TOKEN") or "" if not t: try: from kaggle_secrets import UserSecretsClient t = UserSecretsClient().get_secret("HF_TOKEN") or "" except Exception: pass if not t: try: from google.colab import userdata; t = userdata.get("HF_TOKEN") or "" except Exception: pass if not t: p = os.path.expanduser("~/.cache/huggingface/token") if os.path.exists(p): t = open(p).read().strip() return t def resolve_tokenizer(token=None): for p in [TOK_FILE, os.path.join("tokenizer", TOK_FILE)]: if os.path.isfile(p): return p try: from huggingface_hub import hf_hub_download except ImportError: _pip("huggingface_hub"); from huggingface_hub import hf_hub_download return hf_hub_download(TOK_REPO, TOK_FILE, token=token) def resolve_ckpt(spec, token=None): """Local path, 'repo:file', or bare 'file' from the default repo.""" if os.path.isfile(spec): return spec try: from huggingface_hub import hf_hub_download except ImportError: _pip("huggingface_hub"); from huggingface_hub import hf_hub_download if ":" in spec and not spec.startswith("/"): repo, fn = spec.split(":", 1) else: repo, fn = DEFAULT_CKPT_REPO, spec return hf_hub_download(repo, fn, token=token) # ============================================================================= # SAM-MM model — INLINED VERBATIM from samg_mm_train.py (state_dict-compatible) # ============================================================================= class NeuronType(IntEnum): STEM=0; EXCITATORY=1; INHIBITORY=2; MEMORY=3; RELAY=4; MODULATORY=5; PATTERN=6 TARGET_DISTRIBUTION = {NeuronType.STEM:.10, NeuronType.EXCITATORY:.35, NeuronType.INHIBITORY:.10, NeuronType.MEMORY:.15, NeuronType.RELAY:.10, NeuronType.MODULATORY:.08, NeuronType.PATTERN:.12} @dataclass class Config: vocab_size:int=32000; dim:int=320; n_layers:int=8; n_heads:int=8 max_seq_len:int=1024; channel_top_k:int=120; token_top_k:int=128; ffn_mult:int=4 dropout:float=0.1; pad_id:int=0; eos_id:int=2; use_diversity:bool=True # multimodal v_dim:int=320; v_layers:int=7; v_patch:int=8; img:int=96 a_dim:int=320; a_layers:int=6; mel:int=64 phys_dim:int=320; phys_slots:int=4 # plasticity (text diversity layers) target_stem_ratio:float=.10; min_stem_ratio:float=.08; stem_plasticity:float=.012 reversion_rate:float=.012; min_age_before_revert:int=8; update_interval:int=10 baseline_revert_ratio:float=.5; inhibition_strength:float=.08 modulation_strength:float=.1; excitation_strength:float=.3 # train batch_size:int=16; grad_accum:int=2; lr:float=5e-4; max_steps:int=40000 warmup:int=1500; eval_every:int=1000; save_every:int=1000; patience:int=12 log_every:int=50; aux_phys_w:float=0.5 class DynamicTypeManager(nn.Module): def __init__(self, dim, cfg): super().__init__() self.dim, self.cfg = dim, cfg t=[] for nt,p in TARGET_DISTRIBUTION.items(): t += [nt.value]*int(dim*p) while len(t)4: acts=self.activation_history[spec] thr=torch.quantile(acts,.30).item() cand=spec[(acts<=max(thr,1e-6)) & (self.age[spec]>self.cfg.min_age_before_revert)] n=min(max(1,int(self.dim*self.cfg.reversion_rate)),len(cand)) if n>0: sel=cand[(-self.activation_history[cand]).topk(n)[1]] self.neuron_types[sel]=0; self.age[sel]=0; self.activation_history[sel]*=.5 stem=(self.neuron_types==0).nonzero().view(-1) floor=max(2,int(self.dim*self.cfg.min_stem_ratio)) if len(stem)>floor: n=min(max(1,int(self.dim*self.cfg.stem_plasticity)),len(stem)-floor) sel=stem[self.age[stem].float().topk(n)[1]] for ni in sel: w=torch.tensor([TARGET_DISTRIBUTION[t] for t in NeuronType if t!=NeuronType.STEM]) self.neuron_types[ni]=list(NeuronType)[1:][torch.multinomial(w/w.sum(),1).item()].value self.age[ni]=0 class GentleInhibition(nn.Module): def __init__(s,d,c): super().__init__(); s.k=c.inhibition_strength; s.noise_detector=nn.Sequential(nn.Linear(d,d//4),nn.ReLU(),nn.Linear(d//4,d),nn.Sigmoid()); s.threshold=nn.Parameter(torch.tensor(.15)) def forward(s,x,m): sup=(x.abs()=thr).float(); soft=torch.sigmoid((sc-thr)*10) return x*(hard-soft.detach()+soft) class SparseAttn(nn.Module): def __init__(s,d,h,tk): super().__init__(); s.h,s.hd,s.tk=h,d//h,tk; s.qkv=nn.Linear(d,3*d); s.out=nn.Linear(d,d) def forward(s,x): B,T,D=x.shape q,k,v=s.qkv(x).reshape(B,T,3,s.h,s.hd).permute(2,0,3,1,4) a=(q@k.transpose(-2,-1))*s.hd**-.5 a=a.masked_fill(torch.triu(torch.ones(T,T,device=x.device),1).bool(),float("-inf")) _,i=a.topk(min(s.tk,T),-1) m=torch.zeros_like(a,dtype=torch.bool).scatter_(-1,i,True) a=torch.nan_to_num(F.softmax(a.masked_fill(~m,float("-inf")),-1),0.) return s.out((a@v).transpose(1,2).reshape(B,T,D)) class SparseFFN(nn.Module): def __init__(s,d,m,ck): super().__init__(); s.up=nn.Linear(d,d*m); s.gate=SparseGate(d*m,ck); s.down=nn.Linear(d*m,d) def forward(s,x): return s.down(s.gate(F.silu(s.up(x)))) class Block(nn.Module): def __init__(s,c,i,dim=None,heads=None,tk=None,ck=None,div=False): super().__init__(); d=dim or c.dim; h=heads or c.n_heads s.n1=nn.LayerNorm(d); s.attn=SparseAttn(d,h,tk or c.token_top_k) s.n2=nn.LayerNorm(d); s.ffn=SparseFFN(d,c.ffn_mult,(ck or c.channel_top_k)*c.ffn_mult) s.drop=nn.Dropout(c.dropout); s.div=div if div: s.diversity=BalancedDiversityLayer(d,c) def forward(s,x): x=x+s.drop(s.attn(s.n1(x))) if s.div: x=s.diversity(x) return x+s.drop(s.ffn(s.n2(x))) # ============================================================================= # Encoders + PhysicsCore # ============================================================================= class VisionEncoder(nn.Module): """64x64x3 -> 64 patch tokens dim v_dim -> proj to dim.""" def __init__(s,c): super().__init__() n=(c.img//c.v_patch)**2 s.patch=nn.Conv2d(3,c.v_dim,c.v_patch,c.v_patch) s.pos=nn.Parameter(torch.randn(1,n,c.v_dim)*.02) s.blocks=nn.ModuleList([Block(c,i,dim=c.v_dim,heads=8,tk=n,ck=int(c.v_dim*.375)) for i in range(c.v_layers)]) s.norm=nn.LayerNorm(c.v_dim); s.proj=nn.Linear(c.v_dim,c.dim) def forward(s,img): x=s.patch(img).flatten(2).transpose(1,2)+s.pos for b in s.blocks: x=b(x) return s.proj(s.norm(x)) # B,64,dim class AudioEncoder(nn.Module): """log-mel B,1,64,T -> ~T/4 tokens dim a_dim -> proj to dim.""" def __init__(s,c): super().__init__() s.stem=nn.Sequential(nn.Conv2d(1,32,3,2,1),nn.GELU(),nn.Conv2d(32,c.a_dim,3,2,1),nn.GELU()) s.blocks=nn.ModuleList([Block(c,i,dim=c.a_dim,heads=8,tk=64,ck=int(c.a_dim*.375)) for i in range(c.a_layers)]) s.norm=nn.LayerNorm(c.a_dim); s.proj=nn.Linear(c.a_dim,c.dim) def forward(s,mel): x=s.stem(mel) # B,a_dim,16,T/4 x=x.mean(2).transpose(1,2) # B,T/4,a_dim for b in s.blocks: x=b(x) return s.proj(s.norm(x)) class PhysicsCore(nn.Module): """Latent physical state engine. GRU over per-frame visual summaries, phys_slots learned state slots, predicts next-frame embedding from (z_t, action). Aux loss = MSE+cos(pred, vis_{t+1}).""" def __init__(s,c): super().__init__() s.slots=nn.Parameter(torch.randn(1,c.phys_slots,c.dim)*.02) s.read=nn.MultiheadAttention(c.dim,4,batch_first=True) s.cell=nn.GRUCell(c.dim,c.phys_dim) s.act=nn.Linear(c.dim,c.phys_dim) s.pred=nn.Sequential(nn.Linear(c.phys_dim,c.dim),nn.GELU(),nn.Linear(c.dim,c.dim)) s.to_seq=nn.Linear(c.phys_dim,c.dim) s.pd=c.phys_dim def forward(s,frames,action): # frames: B,T,dim (per-frame mean vis emb); action: B,dim B,T,_=frames.shape z=frames.new_zeros(B,s.pd); preds=[] a=s.act(action) for t in range(T): z=s.cell(frames[:,t]+0., z)+0.1*a preds.append(s.pred(z)) pred=torch.stack(preds,1) # B,T,dim (predict t+1) aux=0. if T>1: tgt=frames[:,1:].detach(); p=pred[:,:-1] aux=F.mse_loss(p,tgt)+ (1-F.cosine_similarity(p,tgt,-1).mean()) slots,_=s.read(s.slots.expand(B,-1,-1), s.to_seq(z).unsqueeze(1), s.to_seq(z).unsqueeze(1)) return slots, aux # B,slots,dim class SAMMM(nn.Module): def __init__(s,c): super().__init__(); s.cfg=c s.tok_emb=nn.Embedding(c.vocab_size,c.dim); s.pos_emb=nn.Embedding(c.max_seq_len,c.dim) s.drop=nn.Dropout(c.dropout) s.blocks=nn.ModuleList([Block(c,i,div=(i%2==0 and c.use_diversity)) for i in range(c.n_layers)]) s.norm=nn.LayerNorm(c.dim) s.vision=VisionEncoder(c); s.audio=AudioEncoder(c); s.phys=PhysicsCore(c) s.mode=nn.Embedding(3,c.dim) # 0=[VIS] 1=[AUD] 2=[PHYS] separators s.apply(s._init) s.n_params=sum(p.numel() for p in s.parameters()) print(f"\n{MODEL_NAME} by {ORGANIZATION}: {s.n_params:,} params") @staticmethod def _init(mod): if isinstance(mod,(nn.Linear,nn.Conv2d,nn.Conv1d)): nn.init.normal_(mod.weight,std=0.02) if mod.bias is not None: nn.init.zeros_(mod.bias) elif isinstance(mod,nn.Embedding): nn.init.normal_(mod.weight,std=0.02) def fuse(s,ids,frames=None,mel=None): B=ids.shape[0]; parts=[] aux=ids.new_zeros(1,dtype=torch.float32).squeeze() if frames is not None: B_,T,C,H,W=frames.shape vis=s.vision(frames.reshape(B_*T,C,H,W)).reshape(B_,T,-1,s.cfg.dim) per=vis.mean(2) act=s.tok_emb(ids).mean(1) slots,aux=s.phys(per,act) parts += [s.mode.weight[0].expand(B,1,-1), vis[:,0], s.mode.weight[2].expand(B,1,-1), slots] if mel is not None: parts += [s.mode.weight[1].expand(B,1,-1), s.audio(mel)] parts.append(s.tok_emb(ids)) x=torch.cat(parts,1); n_pref=x.shape[1]-ids.shape[1] return x,n_pref,aux def forward(s,ids,targets=None,frames=None,mel=None): x,n_pref,aux=s.fuse(ids,frames,mel) T=x.shape[1] x=s.drop(x+s.pos_emb(torch.arange(T,device=x.device))) for b in s.blocks: x=b(x) logits=F.linear(s.norm(x),s.tok_emb.weight)[:,n_pref:] loss=None if targets is not None: lm=F.cross_entropy(logits.reshape(-1,s.cfg.vocab_size),targets.reshape(-1),ignore_index=s.cfg.pad_id) loss=lm+s.cfg.aux_phys_w*aux return logits,loss,aux IMG, T = 96, 8 def render(x, y, img=96, r=4): f = torch.zeros(3, img, img) xi, yi = int(max(r, min(img - r - 1, x))), int(max(r, min(img - r - 1, y))) f[0, yi - r:yi + r, xi - r:xi + r] = 1.0 return f def render_world(kind, seed): rng = random.Random(seed); fr = [] if kind == "ball": x, y = 10., 10.; vx = rng.uniform(5, 9); vy = rng.uniform(-2, 0); g = 1.1 for _ in range(T): fr.append(render(x, y, IMG)); x += vx; vy += g; y += vy elif kind == "spring": c = IMG // 2; A = rng.uniform(15, 30); w = rng.uniform(.5, 1.0) for t in range(T): fr.append(render(c + A * math.sin(w * t), c, IMG)) elif kind == "bounce": x, y = 12., IMG // 2; vx = rng.uniform(7, 11) for _ in range(T): fr.append(render(x, y, IMG)); x += vx if x > IMG - 12: vx = -vx else: # twobody x1, x2, y = 15., float(IMG - 15), IMG // 2; v = rng.uniform(4, 7) for _ in range(T): f = render(x1, y, IMG); f[1] = render(x2, y, IMG)[0] fr.append(f) if abs(x2 - x1) > 10: x1 += v; x2 -= v return torch.stack(fr) _DIG = {d: torch.tensor(b).reshape(7, 5).float() for d, b in { "0": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1], "1": [0,0,1,0,0,0,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,0], "2": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1], "3": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1], "4": [1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1], "5": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1], "6": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1], "7": [1,1,1,1,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0], "8": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1], "9": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1], }.items()} def render_ocr(num, img=96): f = torch.zeros(3, img, img); x0, y0, s = 6, img // 2 - 14, 4 for i, ch in enumerate(num): g = F.interpolate(_DIG[ch][None, None], scale_factor=s).squeeze() x = x0 + i * (5 * s + 4) f[:, y0:y0 + 7 * s, x:x + 5 * s] = g return f.unsqueeze(0).repeat(T, 1, 1, 1) def render_robot(seed, img=96): rng = random.Random(seed) x = rng.uniform(20, img - 20); y = rng.uniform(20, img - 20) return render(x, y, img).unsqueeze(0).repeat(T, 1, 1, 1) def _trailing_digits(label): d = "" for ch in reversed(label): if ch.isdigit(): d = ch + d elif d: break return d or "0" def render_from_spec(spec): k = spec["kind"] if k in ("ball", "spring", "bounce", "twobody"): return render_world(k, spec["seed"]) if k == "ocr": return render_ocr(_trailing_digits(spec.get("label", "0"))) if k == "robot": return render_robot(spec["seed"]) return torch.zeros(T, 3, IMG, IMG) # --- audio: synthetic mel (stable) + real ESC-50 mel + runtime pool ----------- import hashlib def _stable_hash(s): return int(hashlib.md5(s.encode()).hexdigest(), 16) def synth_audio(sound, n_mels=64, n_t=64): """Deterministic pseudo-mel per sound class (stable across runs).""" base = _stable_hash(sound) % 32 m = torch.zeros(1, n_mels, n_t) m[0, base:base + 8] = torch.linspace(0.2, 1.0, n_t) m += torch.randn(1, n_mels, n_t) * 0.05 return m def wav_to_mel(wav, sr=16000, n_mels=64, n_t=64): """Lightweight log-magnitude spectrogram resized to (1, n_mels, n_t). Dependency-free (torch.stft); the encoder adapts during finetuning.""" if wav.dim() > 1: wav = wav.mean(0) if wav.numel() < 512: wav = F.pad(wav, (0, 512 - wav.numel())) n_fft = 400; hop = 160 spec = torch.stft(wav, n_fft=n_fft, hop_length=hop, window=torch.hann_window(n_fft), return_complex=True).abs() spec = torch.log1p(spec).unsqueeze(0).unsqueeze(0) mel = F.interpolate(spec, size=(n_mels, n_t), mode="bilinear", align_corners=False)[0] return (mel - mel.mean()) / (mel.std() + 1e-5) def load_esc50(n=600): """Real environmental audio (ESC-50) -> (mel, category). Synthetic fallback.""" try: from datasets import load_dataset ds = load_dataset("ashraq/esc50", split="train", streaming=True) pool = [] for i, ex in enumerate(ds): if i >= n: break a = ex["audio"]; wav = torch.tensor(a["array"]).float() cat = ex.get("category") or str(ex.get("target", "sound")) pool.append((wav_to_mel(wav, a.get("sampling_rate", 16000)), str(cat))) if pool: print(f"[esc50] {len(pool)} real clips loaded", flush=True); return pool except Exception as e: print(f"[esc50] unreachable ({type(e).__name__}) — synthetic-only audio", flush=True) return [] def render_av(spec): """Return (frames, mel) for a sample; mel is signal for audio families, else zeros.""" frames = render_from_spec(spec) mel = synth_audio(spec["sound"]) if "sound" in spec else torch.zeros(1, 64, 64) return frames, mel def esc_sample(pool): mel, cat = random.choice(pool) return "[AUD] what is this sound? [CHAT]", f"step 1: classify the sound. Answer: {cat}", mel # ============================================================================= class Tok: def __init__(s, token=None): s.sp = spm.SentencePieceProcessor(); s.sp.Load(resolve_tokenizer(token)) s.vocab = s.sp.GetPieceSize() def enc(s, t): return s.sp.EncodeAsIds(t) def dec(s, ids): return s.sp.DecodeIds(ids) EOS = 2; PAD = 0; L = 80 def make_batch(tok, rows, idx, esc_pool=None, p_esc=0.22): ids_in, tgts, frs, mls = [], [], [], [] for j in idx: if esc_pool and random.random() < p_esc: prompt, answer, mel = esc_sample(esc_pool); frames = torch.zeros(T, 3, IMG, IMG) else: r = rows[j]; prompt, answer = r["prompt"], r["answer"] frames, mel = render_av(r["spec"]) p = tok.enc(prompt); a = tok.enc(" " + answer) + [EOS] full = (p + a)[:L + 1] if len(full) < L + 1: full = full + [PAD] * (L + 1 - len(full)) inp = full[:L]; tgt = full[1:L + 1] cut = len(p) - 1 # supervise only answer tokens tgt = [PAD if k < cut else t for k, t in enumerate(tgt)] ids_in.append(inp); tgts.append(tgt); frs.append(frames); mls.append(mel) ii = torch.tensor(ids_in, device=device); tt = torch.tensor(tgts, device=device) ff = torch.stack(frs).to(device); mm = torch.stack(mls).to(device) return ii, tt, ff, mm # ============================================================================= # Eval — per family; CHAT matches the Answer span, ACTION matches the plan # ============================================================================= def _extract_json(s): i = s.find("{") if i < 0: return None depth = 0 for k in range(i, len(s)): if s[k] == "{": depth += 1 elif s[k] == "}": depth -= 1 if depth == 0: try: return json.loads(s[i:k + 1]) except Exception: return None return None @torch.no_grad() def generate(model, tok, prompt, frames, mel, max_new=48): model.eval() ids = torch.tensor([tok.enc(prompt)], device=device) fb = frames.unsqueeze(0).to(device); mb = mel.unsqueeze(0).to(device) out = [] for _ in range(max_new): logits, _, _ = model(ids, None, fb, mb) nxt = int(logits[0, -1].argmax()) if nxt == EOS: break out.append(nxt); ids = torch.cat([ids, torch.tensor([[nxt]], device=device)], 1) return tok.dec(out) def _chat_match(pred, gold): g = gold.split("Answer:")[-1].strip() p = pred.split("Answer:")[-1].strip() if "Answer:" in pred else pred.strip() return p.startswith(g) or g in p def _action_match(pred, gold): pj, gj = _extract_json(pred), _extract_json(gold) return pj is not None and pj == gj # ============================================================================= # Eval generators — INLINED (no external file / no --data needed) # ============================================================================= def _aj(o): return json.dumps(o, separators=(",", ":")) def _fr(): return random.random() < 0.30 # --------------------------------------------------------------------------- # Deterministic physics simulation — returns ground-truth facts (no torch here; # the finetune renders pixels, the generator only needs the trajectory facts). # Mirrors gen_world() in samg_mm_train.py kind-for-kind so frames match. # --------------------------------------------------------------------------- def simulate_facts(kind, seed): """Replay the trajectory deterministically; return physical facts used to build the supervised answer. Uses an isolated RNG so it cannot perturb the global stream (the finetune reseeds the SAME way before rendering).""" rng = random.Random(seed) if kind == "ball": x, y = 10., 10.; vx = rng.uniform(5, 9); vy = rng.uniform(-2, 0); g = 1.1 xs = [] for _ in range(T): xs.append(x); x += vx; vy += g; y += vy return {"dynamic": "gravity", "direction": "right", "reaches_right": x > IMG - 12, "dx": vx} if kind == "spring": A = rng.uniform(15, 30); w = rng.uniform(.5, 1.0) return {"dynamic": "oscillation", "direction": "oscillating", "amplitude": A, "reaches_right": False} if kind == "bounce": x = 12.; vx = rng.uniform(7, 11); bounced = False for _ in range(T): x += vx if x > IMG - 12: vx = -vx; bounced = True return {"dynamic": "collision", "direction": "right then left", "bounces": bounced, "reaches_right": True} # twobody v = rng.uniform(4, 7) return {"dynamic": "collision", "direction": "converging", "collides": True, "reaches_right": False} PHYS_KINDS = ["ball", "spring", "bounce", "twobody"] OCR_PREFIX = ["speed=", "temp=", "qos=", "zone ", "dock ", "id="] def v_motion(): kind = random.choice(PHYS_KINDS); seed = random.randint(0, 2**31 - 1) f = simulate_facts(kind, seed); fr = _fr() q = ("dans quel sens se déplace l'objet ?" if fr else "which way does the object move?") prompt = f"[VIS] {q} [CHAT]" d = f["direction"] trace = (f"step 1: track the bright object across frames. " f"step 2: its horizontal position evolves -> {d}. Answer: {d}") return dict(family="v_motion", fmt="CHAT", use_v=True, use_a=False, use_p=False, spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace) def v_ocr(): pre = random.choice(OCR_PREFIX) num = "".join(random.choice("0123456789") for _ in range(random.randint(2, 3))) label = pre + num; fr = _fr() q = "quel nombre est affiché ?" if fr else "what number is shown?" prompt = f"[VIS] [OCR] {q} [CHAT]" trace = (f"step 1: read the bitmap label. step 2: digits = {num}. Answer: {num}") return dict(family="v_ocr", fmt="CHAT", use_v=True, use_a=False, use_p=False, spec={"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label}, prompt=prompt, answer=trace) # --------------------------------------------------------------------------- # PHYSICS — [CHAT] # --------------------------------------------------------------------------- def p_identify(): kind = random.choice(PHYS_KINDS); seed = random.randint(0, 2**31 - 1) f = simulate_facts(kind, seed); fr = _fr() q = ("quelle dynamique régit ce mouvement ?" if fr else "what dynamic governs this motion?") prompt = f"[VIS] [PHYS] {q} [CHAT]" dyn = f["dynamic"] cue = {"gravity": "constant downward acceleration", "oscillation": "periodic back-and-forth around a center", "collision": "abrupt velocity reversal on contact"}[dyn] trace = f"step 1: observe {cue}. step 2: that is {dyn}. Answer: {dyn}" return dict(family="p_identify", fmt="CHAT", use_v=True, use_a=False, use_p=True, spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace) def p_predict(): kind = random.choice(PHYS_KINDS); seed = random.randint(0, 2**31 - 1) f = simulate_facts(kind, seed); fr = _fr() if kind in ("ball", "bounce"): ans = "yes" if f.get("reaches_right") else "no" q = ("l'objet atteint-il le bord droit ?" if fr else "does the object reach the right edge?") reason = "its rightward velocity carries it to the wall" if ans == "yes" \ else "it falls or stops before the wall" elif kind == "twobody": ans = "yes"; q = ("les deux corps vont-ils entrer en collision ?" if fr else "will the two bodies collide?") reason = "they approach from both sides and meet in the middle" else: ans = "no"; q = ("l'objet quitte-t-il le centre durablement ?" if fr else "does the object leave the center permanently?") reason = "it oscillates and returns to the center each period" prompt = f"[VIS] [PHYS] {q} [CHAT]" trace = f"step 1: {reason}. Answer: {ans}" return dict(family="p_predict", fmt="CHAT", use_v=True, use_a=False, use_p=True, spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace) # --------------------------------------------------------------------------- # CROSS-MODAL — [ACTION] ({domain,op,params} schema, as MM base pretraining) # --------------------------------------------------------------------------- def x_robot(): seed = random.randint(0, 2**31 - 1); rng = random.Random(seed); fr = _fr() target = random.choice(["dock", "block", "marker", "exit"]) speed = round(rng.uniform(0.2, 0.9), 2); angle = rng.randint(0, 359) q = (f"pousse vers le {target}" if fr else f"push toward the {target}") prompt = f"[VIS] {q} [ACTION]" action = {"domain": "ros", "op": "move", "params": {"speed": speed, "angle": angle, "duration_s": 1}} return dict(family="x_robot", fmt="ACTION", use_v=True, use_a=False, use_p=False, spec={"kind": "robot", "seed": seed}, prompt=prompt, answer=_aj(action)) def x_sensor(): pre = "speed="; val = random.randint(10, 99) label = pre + str(val); limit = 50; fr = _fr() q = (f"si la vitesse dépasse {limit}, ralentis" if fr else f"if speed exceeds {limit}, slow down") prompt = f"[VIS] [OCR] {q} [ACTION]" if val > limit: action = {"domain": "ros", "op": "set_speed", "params": {"value": limit}} else: action = {"domain": "ros", "op": "continue", "params": {}} return dict(family="x_sensor", fmt="ACTION", use_v=True, use_a=False, use_p=False, spec={"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label}, prompt=prompt, answer=_aj(action)) SOUND_CAUSE = { "sharp impact": "collision", "double impact": "collision", "rhythmic creak": "oscillation", "whoosh then thud": "falling object", "servo whir": "motor", } KIND_SOUND = {"ball": "whoosh then thud", "spring": "rhythmic creak", "bounce": "sharp impact", "twobody": "double impact"} SOUND_ACTION = { "alarm": ({"domain": "ros", "op": "stop", "params": {}}, "an alarm"), "servo whir": ({"domain": "ros", "op": "continue", "params": {}}, "a servo whir"), "sharp impact": ({"domain": "ros", "op": "halt", "params": {"reason": "collision"}}, "an impact"), "rhythmic creak": ({"domain": "ros", "op": "slow", "params": {"value": 20}}, "a creak"), } def a_identify(): snd = random.choice(list(SOUND_CAUSE.keys())); cause = SOUND_CAUSE[snd]; fr = _fr() q = "qu'est-ce qui a produit ce son ?" if fr else "what produced this sound?" prompt = f"[AUD] {q} [CHAT]" desc = {"collision": "a sharp broadband transient", "oscillation": "a periodic rhythmic tone", "falling object": "a rising sweep followed by a thud", "motor": "a steady mechanical hum"}[cause] trace = f"step 1: hear {desc}. step 2: that indicates {cause}. Answer: {cause}" return dict(family="a_identify", fmt="CHAT", use_v=False, use_a=True, use_p=False, spec={"kind": "audio", "sound": snd}, prompt=prompt, answer=trace) def a_match(): kind = random.choice(PHYS_KINDS); seed = random.randint(0, 2**31 - 1) true_sound = KIND_SOUND[kind]; fr = _fr() if random.random() < 0.5: snd = true_sound; ans = "yes"; reason = "the sound fits the motion" else: snd = random.choice([s for s in KIND_SOUND.values() if s != true_sound]) ans = "no"; reason = "the sound does not fit the motion" q = ("le son correspond-il au mouvement ?" if fr else "does the sound match the motion?") prompt = f"[VIS] [AUD] [PHYS] {q} [CHAT]" trace = f"step 1: {reason}. Answer: {ans}" return dict(family="a_match", fmt="CHAT", use_v=True, use_a=True, use_p=True, spec={"kind": kind, "seed": seed, "sound": snd}, prompt=prompt, answer=trace) def a_event(): snd = random.choice(list(SOUND_ACTION.keys())); action, desc = SOUND_ACTION[snd]; fr = _fr() instr = {"stop": "arrête le robot" if fr else "stop the robot", "continue": "continue" if fr else "keep going", "halt": "stoppe net" if fr else "halt immediately", "slow": "ralentis" if fr else "slow down"}[action["op"]] q = (f"si tu entends {desc}, {instr}" if fr else f"if you hear {desc}, {instr}") prompt = f"[AUD] {q} [ACTION]" return dict(family="a_event", fmt="ACTION", use_v=False, use_a=True, use_p=False, spec={"kind": "audio", "sound": snd}, prompt=prompt, answer=_aj(action)) # builder: produce a held-out eval set in-memory (disjoint seed, no files) def build_eval(n=1800, seed=99991, families="auto"): """families: 'reasoning' (6 visual/physics/cross-modal) or 'auto'/'audio' (all 9).""" if families == "reasoning": gens = [v_motion, v_ocr, p_identify, p_predict, x_robot, x_sensor] else: gens = [v_motion, v_ocr, p_identify, p_predict, x_robot, x_sensor, a_identify, a_match, a_event] _st = random.getstate(); random.seed(seed) rows = [] for _ in range(n): s = random.choice(gens)() s["text"] = s["prompt"] + " " + s["answer"]; rows.append(s) random.setstate(_st) return rows # ============================================================================= # Benchmark # ============================================================================= def load_ckpt(model, path): ck = torch.load(path, map_location=device) sd = ck["model"] if "model" in ck else ck model.load_state_dict(sd, strict=True) return ck.get("step", "?"), ck.get("best", None) # ============================================================================= # SAM-MM — HuggingFace Space (self-contained; weights pulled from HF) # Architecture inlined above. Set HF_TOKEN as a Space secret for private repos. # ============================================================================= import io try: from PIL import Image except ImportError: os.system(f"{sys.executable} -m pip install -q Pillow --break-system-packages"); from PIL import Image import gradio as gr CHECKPOINTS = { "Reasoning — vision + physics": "AMFORGE/sam-mm-reasoning-checkpoints:best.pt", "Audio-reasoning — + sound": "AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt", } SCENES = { "🪐 Physics — identify the dynamic": "p_identify", "🎯 Physics — predict the outcome": "p_predict", "➡️ Vision — direction of motion": "v_motion", "🔢 Vision — read the number (OCR)": "v_ocr", "🛰️ Cross-modal — sensor → action": "x_sensor", "🔊 Audio — identify the sound": "a_identify", "🎬 Audio — match sight + sound": "a_match", "⚡ Audio — sound → action": "a_event", } FAMFUNC = {"p_identify": p_identify, "p_predict": p_predict, "v_motion": v_motion, "v_ocr": v_ocr, "x_sensor": x_sensor, "a_identify": a_identify, "a_match": a_match, "a_event": a_event} AUDIO_FAMS = {"a_identify", "a_match", "a_event"} _STATE = {"model": None, "tok": None, "ckpt": None} def _load(ckpt): if _STATE["model"] is not None and _STATE["ckpt"] == ckpt: return _STATE["model"], _STATE["tok"] token = get_hf_token() tok = Tok(token=token) model = SAMMM(Config()).to(device) load_ckpt(model, resolve_ckpt(ckpt, token)); model.eval() _STATE.update(model=model, tok=tok, ckpt=ckpt) return model, tok def _montage(frames): n = frames.shape[0] tiles = [(frames[i].clamp(0,1).permute(1,2,0)*255).byte().cpu().numpy() for i in range(n)] w = IMG*n + (n-1)*2 img = Image.new("RGB", (w, IMG), (17,18,26)) for i,t in enumerate(tiles): img.paste(Image.fromarray(t), (i*(IMG+2),0)) return img.resize((w*4, IMG*4), Image.NEAREST) def _infer(ckpt_label, scene_label, max_new): ckpt = CHECKPOINTS[ckpt_label]; fam = SCENES[scene_label] s = FAMFUNC[fam]() frames, mel = render_av(s["spec"]) model, tok = _load(ckpt) pred = generate(model, tok, s["prompt"], frames, mel, max_new=int(max_new)) gold = s["answer"].split("Answer:")[-1].strip() got = pred.split("Answer:")[-1].strip() if "Answer:" in pred else pred.strip() ok = _chat_match(pred, s["answer"]) return s, frames, got, gold, ok, fam def run_one(ckpt_label, scene_label, max_new): s, frames, got, gold, ok, fam = _infer(ckpt_label, scene_label, max_new) warn = "" if fam in AUDIO_FAMS and "audio" not in ckpt_label.lower(): warn = ("\n\n> ⚠️ This is an **audio** scene on the **reasoning** checkpoint — " "it never learned sound, so a correct answer here is chance. " "Switch the checkpoint to *Audio-reasoning* to test it for real.") verdict = "✅ **correct**" if ok else "❌ **mismatch**" md = (f"### {verdict}\n" f"**Prompt → SAM-MM**\n```\n{s['prompt']}\n```\n" f"**Model answer:** `{got}`  •  **Ground truth:** `{gold}`") if "sound" in s["spec"]: md += (f"\n\n*Sound cue `{s['spec']['sound']}` → a deterministic log-mel " f"(no audible file; the encoder reads the spectrogram).*") md += warn return _montage(frames), md def run_batch(ckpt_label, scene_label, max_new, n=20): hits = 0; lines = [] for _ in range(int(n)): s, frames, got, gold, ok, fam = _infer(ckpt_label, scene_label, max_new) hits += int(ok) lines.append(f"{'✅' if ok else '❌'} `{got}` vs `{gold}`") acc = 100*hits/int(n) head = f"### {hits}/{int(n)} correct   →   **{acc:.0f}%** exact-match\n\n" return head + "\n".join(lines) CSS = """ .gradio-container {max-width: 980px !important} #title {text-align:center} #frames img {image-rendering: pixelated; border-radius: 10px} footer {display:none !important} """ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"), css=CSS, title="SAM-MM · multimodal demo") as demo: gr.Markdown("# 🧠 SAM-MM — a 58M multimodal model that *reasons*", elem_id="title") gr.Markdown( "Pick a scene. SAM-MM perceives the **rendered frames** (and, for audio scenes, a " "**log-mel spectrogram**), then answers in `[CHAT]` text or a `[ACTION]` JSON record. " "Frames are synthetic — this is the model's native world. Nothing is hard-coded: each " "scene is freshly generated, the model decodes token-by-token, and the answer is checked " "against ground truth computed independently.") with gr.Row(): with gr.Column(scale=1): ckpt = gr.Dropdown(list(CHECKPOINTS), value=list(CHECKPOINTS)[0], label="Checkpoint") scene = gr.Dropdown(list(SCENES), value=list(SCENES)[0], label="Scene") max_new = gr.Slider(24, 96, value=64, step=8, label="max new tokens") with gr.Row(): b1 = gr.Button("Generate & run", variant="primary") b2 = gr.Button("Run 20 (accuracy)") with gr.Column(scale=2): img = gr.Image(label="What SAM-MM sees", elem_id="frames") md = gr.Markdown() batch = gr.Markdown() b1.click(run_one, [ckpt, scene, max_new], [img, md]) b2.click(run_batch, [ckpt, scene, max_new], [batch]) gr.Markdown( "---\n**Honest notes.** Physics & motion are SAM-MM's strength (its world-model carries " "real dynamics). OCR generalizes to unseen numbers but isn't perfect. The cross-modal " "`[ACTION]` family is weaker. **Audio is the weak modality** — it was trained on synthetic " "pseudo-mel, so strong audio scores here partly reflect that, not true listening. " "Architecture internals are proprietary and not exposed.") if __name__ == "__main__": demo.launch()