Spaces:
Running
Running
| """ | |
| ============================================================================= | |
| SAM-MM Benchmark — reproducible per-family evaluation for the SAM-MM line | |
| SparseMind / AMFORGE | |
| ============================================================================= | |
| Checkpoint-driven and fully self-contained: the held-out eval set is GENERATED | |
| INTERNALLY (disjoint seed 99991), so no external data file or generator script | |
| is needed. It renders frames/mel from each sample's spec, greedy-decodes the | |
| answer, and reports per-family exact match + a CHAT/ACTION breakdown + aggregate. | |
| Notebook (Colab/Kaggle): edit the variables at the top of main() and run. | |
| Terminal: | |
| python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-reasoning-checkpoints:best.pt --families reasoning | |
| python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt --families audio | |
| python samg_mm_benchmark.py --ckpt ./best.pt --n 3000 --n-per 100 | |
| Self-contained: the SAM-MM model, the renderers, the tokenizer resolver are | |
| inlined verbatim. External vision/audio LMs are NOT comparable on these | |
| SAM-specific synthetic tasks (different input pipelines), so this is an honest | |
| internal per-family report; add a baseline column only where truly comparable. | |
| ============================================================================= | |
| """ | |
| from __future__ import annotations | |
| import os, sys, json, math, random, argparse | |
| from dataclasses import dataclass, asdict | |
| from typing import Optional | |
| from enum import IntEnum | |
| import torch, torch.nn as nn, torch.nn.functional as F | |
| try: import sentencepiece as spm | |
| except ImportError: | |
| os.system(f"{sys.executable} -m pip install -q sentencepiece --break-system-packages"); import sentencepiece as spm | |
# Compute device and bf16 capability, probed once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
# Hub repo/file of the SentencePiece model and the repo used for bare checkpoint names.
TOK_REPO, TOK_FILE = "AMFORGE/samg_mm_tok", "samg_mm_tokenizer.model"
DEFAULT_CKPT_REPO = "AMFORGE/sam-mm-reasoning-checkpoints"
# Names shown in the parameter-count banner printed by SAMMM.__init__.
ORGANIZATION, MODEL_NAME = "AMFORGE", "SAM-MM"
def _pip(p):
    """Best-effort ``pip install`` of package spec(s) ``p`` into this interpreter.

    ``p`` may hold several whitespace-separated specs. The command is passed
    as an argument list (no shell), so the spec text is never interpreted by
    a shell. Failures are not raised; the caller's follow-up import surfaces
    them.
    """
    import subprocess  # local import keeps the file's top-level import block untouched

    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *p.split(), "--break-system-packages"],
        check=False,
    )
def get_hf_token():
    """Return a Hugging Face token, or ``""`` when none is available.

    Sources are tried in order and the first non-empty value wins:
    the ``HF_TOKEN`` environment variable, Kaggle secrets, Colab user data,
    then the token file at ``~/.cache/huggingface/token``.
    """
    t = os.environ.get("HF_TOKEN") or ""
    if not t:
        try:
            from kaggle_secrets import UserSecretsClient
            t = UserSecretsClient().get_secret("HF_TOKEN") or ""
        except Exception:
            # Deliberate best-effort: not on Kaggle, or the secret is missing.
            pass
    if not t:
        try:
            from google.colab import userdata
            t = userdata.get("HF_TOKEN") or ""
        except Exception:
            # Deliberate best-effort: not on Colab, or the secret is missing.
            pass
    if not t:
        p = os.path.expanduser("~/.cache/huggingface/token")
        if os.path.exists(p):
            # Context manager so the file handle is closed deterministically.
            with open(p) as fh:
                t = fh.read().strip()
    return t
def resolve_tokenizer(token=None):
    """Return a path to the SentencePiece model.

    A local copy (working directory, then ``tokenizer/``) is preferred;
    otherwise the file is downloaded from the Hub, installing
    ``huggingface_hub`` first if it is missing.
    """
    candidates = (TOK_FILE, os.path.join("tokenizer", TOK_FILE))
    local = next((c for c in candidates if os.path.isfile(c)), None)
    if local is not None:
        return local
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        _pip("huggingface_hub")
        from huggingface_hub import hf_hub_download
    return hf_hub_download(TOK_REPO, TOK_FILE, token=token)
def resolve_ckpt(spec, token=None):
    """Resolve a checkpoint spec to a local file path.

    ``spec`` is an existing local path (returned unchanged), a
    ``'repo:file'`` pair, or a bare file name looked up in
    ``DEFAULT_CKPT_REPO``. Anything that is not a local file is fetched
    from the Hub.
    """
    if os.path.isfile(spec):
        return spec
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        _pip("huggingface_hub")
        from huggingface_hub import hf_hub_download
    # A colon separates repo from file unless the spec is an absolute path.
    has_repo_prefix = ":" in spec and not spec.startswith("/")
    repo, fn = spec.split(":", 1) if has_repo_prefix else (DEFAULT_CKPT_REPO, spec)
    return hf_hub_download(repo, fn, token=token)
| # ============================================================================= | |
| # SAM-MM model — INLINED VERBATIM from samg_mm_train.py (state_dict-compatible) | |
| # ============================================================================= | |
class NeuronType(IntEnum):
    """Integer codes for the per-channel neuron types used by the diversity layers."""

    STEM = 0
    EXCITATORY = 1
    INHIBITORY = 2
    MEMORY = 3
    RELAY = 4
    MODULATORY = 5
    PATTERN = 6
# Target fraction of channels per neuron type (the fractions sum to 1.0).
# DynamicTypeManager uses it for the initial type assignment and as sampling
# weights when a STEM channel is re-specialised.
TARGET_DISTRIBUTION = {NeuronType.STEM:.10, NeuronType.EXCITATORY:.35, NeuronType.INHIBITORY:.10,
                       NeuronType.MEMORY:.15, NeuronType.RELAY:.10, NeuronType.MODULATORY:.08, NeuronType.PATTERN:.12}
@dataclass
class Config:
    """Hyper-parameters of the SAM-MM model and its training loop.

    Declared as a dataclass so it can be built from keyword overrides
    (``Config(dim=64)``) and serialised with ``dataclasses.asdict``; the file
    imports both helpers. ``Config()`` with no arguments and class-level
    attribute access keep working as before.
    """
    # text backbone
    vocab_size:int=32000; dim:int=320; n_layers:int=8; n_heads:int=8
    max_seq_len:int=1024; channel_top_k:int=120; token_top_k:int=128; ffn_mult:int=4
    dropout:float=0.1; pad_id:int=0; eos_id:int=2; use_diversity:bool=True
    # multimodal
    v_dim:int=320; v_layers:int=7; v_patch:int=8; img:int=96
    a_dim:int=320; a_layers:int=6; mel:int=64
    phys_dim:int=320; phys_slots:int=4
    # plasticity (text diversity layers)
    target_stem_ratio:float=.10; min_stem_ratio:float=.08; stem_plasticity:float=.012
    reversion_rate:float=.012; min_age_before_revert:int=8; update_interval:int=10
    baseline_revert_ratio:float=.5; inhibition_strength:float=.08
    modulation_strength:float=.1; excitation_strength:float=.3
    # train
    batch_size:int=16; grad_accum:int=2; lr:float=5e-4; max_steps:int=40000
    warmup:int=1500; eval_every:int=1000; save_every:int=1000; patience:int=12
    log_every:int=50; aux_phys_w:float=0.5
class DynamicTypeManager(nn.Module):
    """Keeps a neuron type per channel and re-assigns types periodically while training.

    State lives in buffers (so it is saved with the state_dict):
    ``neuron_types`` (NeuronType value per channel), ``activation_history``
    (running mean of |activation|), ``age`` (steps since last type change) and
    ``cycle_counter`` (number of training-mode ``step`` calls).
    """
    def __init__(self, dim, cfg):
        super().__init__()
        self.dim, self.cfg = dim, cfg
        # Initial assignment: int(dim*p) channels per type, padded with STEM (0)
        # up to dim, then shuffled with the global `random` module.
        t=[]
        for nt,p in TARGET_DISTRIBUTION.items(): t += [nt.value]*int(dim*p)
        while len(t)<dim: t.append(0)
        random.shuffle(t)
        self.register_buffer("neuron_types", torch.tensor(t,dtype=torch.long))
        self.register_buffer("activation_history", torch.zeros(dim))
        # Ages start at random values below min_age_before_revert.
        self.register_buffer("age", torch.randint(0,cfg.min_age_before_revert,(dim,),dtype=torch.long))
        self.register_buffer("cycle_counter", torch.tensor(0,dtype=torch.long))
    # Float 0/1 mask over channels whose current type equals NeuronType `t`.
    def get_type_mask(self,t): return (self.neuron_types==t.value).float()
    def step(self,x):
        """Update activation statistics and, every ``update_interval`` calls, re-type channels.

        Does nothing in eval mode. ``x`` is averaged over dims 0 and 1, so the
        last axis is the channel axis (assumes x is (batch, seq, dim) — TODO confirm).
        """
        if not self.training: return
        # Exponential moving average of mean |x| per channel (decay .95).
        self.activation_history.mul_(.95).add_(x.abs().mean((0,1)).float(),alpha=.05)
        self.age += 1; self.cycle_counter += 1
        if self.cycle_counter.item()%self.cfg.update_interval: return
        # --- Reversion: low-activity, old-enough specialised channels go back to STEM.
        spec=(self.neuron_types!=0).nonzero().view(-1)
        if len(spec)>4:
            acts=self.activation_history[spec]
            thr=torch.quantile(acts,.30).item()
            cand=spec[(acts<=max(thr,1e-6)) & (self.age[spec]>self.cfg.min_age_before_revert)]
            n=min(max(1,int(self.dim*self.cfg.reversion_rate)),len(cand))
            if n>0:
                # Pick the n candidates with the lowest activation history.
                sel=cand[(-self.activation_history[cand]).topk(n)[1]]
                self.neuron_types[sel]=0; self.age[sel]=0; self.activation_history[sel]*=.5
        # --- Specialisation: the oldest STEM channels above the floor get a new type.
        stem=(self.neuron_types==0).nonzero().view(-1)
        floor=max(2,int(self.dim*self.cfg.min_stem_ratio))
        if len(stem)>floor:
            n=min(max(1,int(self.dim*self.cfg.stem_plasticity)),len(stem)-floor)
            sel=stem[self.age[stem].float().topk(n)[1]]
            for ni in sel:
                # Sample a non-STEM type with probability proportional to TARGET_DISTRIBUTION.
                w=torch.tensor([TARGET_DISTRIBUTION[t] for t in NeuronType if t!=NeuronType.STEM])
                self.neuron_types[ni]=list(NeuronType)[1:][torch.multinomial(w/w.sum(),1).item()].value
                self.age[ni]=0
class GentleInhibition(nn.Module):
    """Slightly attenuates small activations on the masked (inhibitory) channels."""
    # d: channel width; c: Config (reads inhibition_strength).
    def __init__(s,d,c): super().__init__(); s.k=c.inhibition_strength; s.noise_detector=nn.Sequential(nn.Linear(d,d//4),nn.ReLU(),nn.Linear(d//4,d),nn.Sigmoid()); s.threshold=nn.Parameter(torch.tensor(.15))
    # x: activations; m: per-channel 0/1 mask broadcast as (1,1,D).
    def forward(s,x,m):
        # Suppress where |x| is below the threshold AND the detector score is below .3.
        sup=(x.abs()<s.threshold)&(s.noise_detector(x)<.3)
        # Suppressed entries on masked channels are scaled by (1 - k); others pass through.
        return x*(1-sup.float()*m.view(1,1,-1)*s.k)
class StrongExcitation(nn.Module):
    """Adds a gated MLP update to the masked (excitatory) channels, scaled by excitation_strength."""
    # d: channel width; c: Config (reads excitation_strength).
    def __init__(s,d,c): super().__init__(); s.k=c.excitation_strength; s.integrator=nn.Sequential(nn.Linear(d,d),nn.GELU(),nn.Linear(d,d)); s.importance=nn.Sequential(nn.Linear(d,d),nn.Sigmoid())
    # Residual: x + integrator(x) * sigmoid-gate(x) * mask * k.
    def forward(s,x,m): return x+s.integrator(x)*s.importance(x)*m.view(1,1,-1)*s.k
class GentleModulation(nn.Module):
    """Adds a context signal computed from the running mean over the sequence axis."""
    # d: channel width; c: Config (reads modulation_strength).
    def __init__(s,d,c): super().__init__(); s.k=c.modulation_strength; s.context=nn.Sequential(nn.Linear(d,d//4),nn.GELU(),nn.Linear(d//4,d),nn.Tanh())
    def forward(s,x,m):
        # den[t] = t+1, so cumsum/den is the mean of positions 0..t (prefix mean).
        B,T,D=x.shape; den=torch.arange(1,T+1,device=x.device,dtype=x.dtype).view(1,T,1)
        return x+s.context(x.cumsum(1)/den)*m.view(1,1,-1)*s.k
class PatternDetection(nn.Module):
    """Grouped 1-D convolutions (kernel 3 and 5) over the sequence, added to the masked channels."""
    # Each conv maps d -> d//2 channels; concatenating both restores width d.
    def __init__(s,d): super().__init__(); s.conv3=nn.Conv1d(d,d//2,3,groups=d//2); s.conv5=nn.Conv1d(d,d//2,5,groups=d//2); s.combine=nn.Linear(d,d)
    def forward(s,x,m):
        xt=x.transpose(1,2)
        # Left-only padding (kernel-1) keeps the output length equal to the input
        # length and makes each position depend only on itself and earlier ones.
        p=torch.cat([s.conv3(F.pad(xt,(2,0))).transpose(1,2),s.conv5(F.pad(xt,(4,0))).transpose(1,2)],-1)
        return x+s.combine(p)*m.view(1,1,-1)*.2
class RelayNetwork(nn.Module):
    """Sigmoid-gated linear update added to the masked (relay) channels with a fixed .2 scale."""
    def __init__(s,d): super().__init__(); s.gate=nn.Sequential(nn.Linear(d,d),nn.Sigmoid()); s.transform=nn.Linear(d,d)
    # Residual: x + transform(x) * gate(x) * mask * .2.
    def forward(s,x,m): return x+s.transform(x)*s.gate(x)*m.view(1,1,-1)*.2
class BalancedDiversityLayer(nn.Module):
    """Applies the type-specific sub-modules in a fixed order, each restricted by its type mask.

    Only EXCITATORY, PATTERN, RELAY, MODULATORY and INHIBITORY masks are used
    here; STEM and MEMORY channels have no dedicated sub-module in this layer.
    """
    def __init__(s,d,c):
        super().__init__(); s.type_manager=DynamicTypeManager(d,c)
        s.inhibition=GentleInhibition(d,c); s.excitation=StrongExcitation(d,c)
        s.modulation=GentleModulation(d,c); s.pattern=PatternDetection(d); s.relay=RelayNetwork(d)
        s.norm=nn.LayerNorm(d); s.output=nn.Linear(d,d)
    def forward(s,x):
        # tm.step(x) updates the type statistics (training mode only) before the masks are read.
        r=x; x=s.norm(x); tm=s.type_manager; tm.step(x)
        x=s.excitation(x,tm.get_type_mask(NeuronType.EXCITATORY))
        x=s.pattern(x,tm.get_type_mask(NeuronType.PATTERN))
        x=s.relay(x,tm.get_type_mask(NeuronType.RELAY))
        x=s.modulation(x,tm.get_type_mask(NeuronType.MODULATORY))
        x=s.inhibition(x,tm.get_type_mask(NeuronType.INHIBITORY))
        # Residual connection around the whole layer, with the update halved.
        return r+s.output(x)*.5
class SparseGate(nn.Module):
    """Keeps the top-k scored channels per position, using a straight-through gate."""
    def __init__(s,d,k):
        super().__init__(); s.k=k; s.scorer=nn.Sequential(nn.Linear(d,d//4),nn.SiLU(),nn.Linear(d//4,d))
        # Zero-initialised last layer: every score starts at sigmoid(0)=0.5.
        nn.init.zeros_(s.scorer[-1].weight); nn.init.zeros_(s.scorer[-1].bias)
    def forward(s,x):
        sc=torch.sigmoid(s.scorer(x)); k=min(s.k,x.shape[-1])
        # thr is the k-th largest score; `hard` is the 0/1 keep mask, `soft` a sigmoid surrogate.
        thr=sc.topk(k,-1)[0][...,-1:]; hard=(sc>=thr).float(); soft=torch.sigmoid((sc-thr)*10)
        # Forward value equals `hard`; the gradient flows through `soft`.
        return x*(hard-soft.detach()+soft)
class SparseAttn(nn.Module):
    """Causal multi-head self-attention that keeps only the top-``tk`` keys per query."""
    # d: model width; h: number of heads (head dim = d//h); tk: keys kept per query.
    def __init__(s,d,h,tk): super().__init__(); s.h,s.hd,s.tk=h,d//h,tk; s.qkv=nn.Linear(d,3*d); s.out=nn.Linear(d,d)
    def forward(s,x):
        B,T,D=x.shape
        # q, k, v each have shape (B, heads, T, head_dim).
        q,k,v=s.qkv(x).reshape(B,T,3,s.h,s.hd).permute(2,0,3,1,4)
        a=(q@k.transpose(-2,-1))*s.hd**-.5
        # Causal mask: positions above the diagonal are set to -inf.
        a=a.masked_fill(torch.triu(torch.ones(T,T,device=x.device),1).bool(),float("-inf"))
        _,i=a.topk(min(s.tk,T),-1)
        m=torch.zeros_like(a,dtype=torch.bool).scatter_(-1,i,True)
        # Softmax over the kept keys only; any NaN produced is replaced by 0.
        a=torch.nan_to_num(F.softmax(a.masked_fill(~m,float("-inf")),-1),0.)
        return s.out((a@v).transpose(1,2).reshape(B,T,D))
class SparseFFN(nn.Module):
    """Feed-forward block: up-projection, SiLU, SparseGate over the hidden units, down-projection."""
    # d: model width; m: hidden multiplier; ck: number of hidden units the gate keeps.
    def __init__(s,d,m,ck): super().__init__(); s.up=nn.Linear(d,d*m); s.gate=SparseGate(d*m,ck); s.down=nn.Linear(d*m,d)
    def forward(s,x): return s.down(s.gate(F.silu(s.up(x))))
class Block(nn.Module):
    """Pre-norm block: sparse attention, optional diversity layer, sparse FFN.

    ``dim``/``heads``/``tk``/``ck`` override the Config defaults (``c.dim``,
    ``c.n_heads``, ``c.token_top_k``, ``c.channel_top_k``). ``i`` (layer index)
    is accepted but not used in this class.
    """
    def __init__(s,c,i,dim=None,heads=None,tk=None,ck=None,div=False):
        super().__init__(); d=dim or c.dim; h=heads or c.n_heads
        s.n1=nn.LayerNorm(d); s.attn=SparseAttn(d,h,tk or c.token_top_k)
        # The FFN gate keeps channel_top_k * ffn_mult of the d * ffn_mult hidden units.
        s.n2=nn.LayerNorm(d); s.ffn=SparseFFN(d,c.ffn_mult,(ck or c.channel_top_k)*c.ffn_mult)
        s.drop=nn.Dropout(c.dropout); s.div=div
        if div: s.diversity=BalancedDiversityLayer(d,c)
    def forward(s,x):
        x=x+s.drop(s.attn(s.n1(x)))
        if s.div: x=s.diversity(x)
        return x+s.drop(s.ffn(s.n2(x)))
| # ============================================================================= | |
| # Encoders + PhysicsCore | |
| # ============================================================================= | |
class VisionEncoder(nn.Module):
    """Patch encoder: 3-channel image -> n=(img//v_patch)**2 patch tokens of width v_dim -> proj to dim.

    With the Config defaults (img=96, v_patch=8) that is 144 patch tokens.
    """
    def __init__(s,c):
        super().__init__()
        n=(c.img//c.v_patch)**2
        # Non-overlapping patches: kernel size and stride are both v_patch.
        s.patch=nn.Conv2d(3,c.v_dim,c.v_patch,c.v_patch)
        s.pos=nn.Parameter(torch.randn(1,n,c.v_dim)*.02)
        # tk=n: the top-k selection never drops a key, but SparseAttn's causal mask still applies.
        s.blocks=nn.ModuleList([Block(c,i,dim=c.v_dim,heads=8,tk=n,ck=int(c.v_dim*.375)) for i in range(c.v_layers)])
        s.norm=nn.LayerNorm(c.v_dim); s.proj=nn.Linear(c.v_dim,c.dim)
    def forward(s,img):
        x=s.patch(img).flatten(2).transpose(1,2)+s.pos
        for b in s.blocks: x=b(x)
        return s.proj(s.norm(x)) # B,n,dim (n=144 with the default Config)
class AudioEncoder(nn.Module):
    """log-mel B,1,64,T -> ~T/4 tokens dim a_dim -> proj to dim.

    The stem is two stride-2 3x3 convolutions, so both the mel axis and the
    time axis are reduced by about 4x; the mel axis is then averaged away.
    """
    def __init__(s,c):
        super().__init__()
        s.stem=nn.Sequential(nn.Conv2d(1,32,3,2,1),nn.GELU(),nn.Conv2d(32,c.a_dim,3,2,1),nn.GELU())
        s.blocks=nn.ModuleList([Block(c,i,dim=c.a_dim,heads=8,tk=64,ck=int(c.a_dim*.375)) for i in range(c.a_layers)])
        s.norm=nn.LayerNorm(c.a_dim); s.proj=nn.Linear(c.a_dim,c.dim)
    def forward(s,mel):
        x=s.stem(mel) # B,a_dim,16,T/4 (for 64 mel bins)
        x=x.mean(2).transpose(1,2) # B,T/4,a_dim — mean over the reduced mel axis
        for b in s.blocks: x=b(x)
        return s.proj(s.norm(x))
class PhysicsCore(nn.Module):
    """Latent physical state engine. GRU over per-frame visual summaries,
    phys_slots learned state slots, predicts next-frame embedding from
    (z_t, action). Aux loss = MSE+cos(pred, vis_{t+1}).

    ``forward`` returns ``(slots, aux)``: the learned slots after attending
    to the final recurrent state, and the auxiliary prediction loss (the
    float ``0.`` when there is only one frame).
    """
    def __init__(s,c):
        super().__init__()
        s.slots=nn.Parameter(torch.randn(1,c.phys_slots,c.dim)*.02)
        s.read=nn.MultiheadAttention(c.dim,4,batch_first=True)
        s.cell=nn.GRUCell(c.dim,c.phys_dim)
        s.act=nn.Linear(c.dim,c.phys_dim)
        s.pred=nn.Sequential(nn.Linear(c.phys_dim,c.dim),nn.GELU(),nn.Linear(c.dim,c.dim))
        s.to_seq=nn.Linear(c.phys_dim,c.dim)
        s.pd=c.phys_dim
    def forward(s,frames,action):
        # frames: B,T,dim (per-frame mean vis emb); action: B,dim
        B,T,_=frames.shape
        z=frames.new_zeros(B,s.pd); preds=[]
        a=s.act(action)
        for t in range(T):
            # GRU update on frame t, plus a small additive action term each step.
            z=s.cell(frames[:,t]+0., z)+0.1*a
            preds.append(s.pred(z))
        pred=torch.stack(preds,1) # B,T,dim (predict t+1)
        aux=0.
        if T>1:
            # Prediction at step t is compared with the (detached) frame at t+1.
            tgt=frames[:,1:].detach(); p=pred[:,:-1]
            aux=F.mse_loss(p,tgt)+ (1-F.cosine_similarity(p,tgt,-1).mean())
        # The slots are the queries; the final state z (one token) is both key and value.
        slots,_=s.read(s.slots.expand(B,-1,-1), s.to_seq(z).unsqueeze(1), s.to_seq(z).unsqueeze(1))
        return slots, aux # B,slots,dim
class SAMMM(nn.Module):
    """SAM-MM multimodal language model: text backbone with vision, audio and physics prefixes.

    The fused input sequence is laid out as
    ``[VIS sep][patch tokens of frame 0][PHYS sep][physics slots]`` (when
    frames are given), ``[AUD sep][audio tokens]`` (when a mel is given),
    followed by the text token embeddings. Logits are returned for the text
    positions only. Attribute names and construction order are kept as in the
    original so existing checkpoints still load.
    """
    def __init__(s,c):
        super().__init__(); s.cfg=c
        s.tok_emb=nn.Embedding(c.vocab_size,c.dim); s.pos_emb=nn.Embedding(c.max_seq_len,c.dim)
        s.drop=nn.Dropout(c.dropout)
        # Diversity layers are enabled on even-indexed blocks when use_diversity is set.
        s.blocks=nn.ModuleList([Block(c,i,div=(i%2==0 and c.use_diversity)) for i in range(c.n_layers)])
        s.norm=nn.LayerNorm(c.dim)
        s.vision=VisionEncoder(c); s.audio=AudioEncoder(c); s.phys=PhysicsCore(c)
        s.mode=nn.Embedding(3,c.dim) # 0=[VIS] 1=[AUD] 2=[PHYS] separators
        s.apply(s._init)
        s.n_params=sum(p.numel() for p in s.parameters())
        print(f"\n{MODEL_NAME} by {ORGANIZATION}: {s.n_params:,} params")
    # Must be a staticmethod: nn.Module.apply calls fn(module) with a single
    # argument, and a plain function here would be bound to `s` and receive two.
    @staticmethod
    def _init(mod):
        """Normal(0, 0.02) weights and zero biases for Linear/Conv/Embedding modules."""
        if isinstance(mod,(nn.Linear,nn.Conv2d,nn.Conv1d)):
            nn.init.normal_(mod.weight,std=0.02)
            if mod.bias is not None: nn.init.zeros_(mod.bias)
        elif isinstance(mod,nn.Embedding): nn.init.normal_(mod.weight,std=0.02)
    def fuse(s,ids,frames=None,mel=None):
        """Build the input embedding sequence.

        Returns ``(x, n_pref, aux)``: the fused embeddings, the number of
        prefix (non-text) positions, and the physics auxiliary loss (a zero
        scalar when no frames are given).
        """
        B=ids.shape[0]; parts=[]
        aux=ids.new_zeros(1,dtype=torch.float32).squeeze()
        if frames is not None:
            B_,T,C,H,W=frames.shape
            # Encode every frame, then restore the (batch, time) axes.
            vis=s.vision(frames.reshape(B_*T,C,H,W)).reshape(B_,T,-1,s.cfg.dim)
            per=vis.mean(2)              # one summary vector per frame
            act=s.tok_emb(ids).mean(1)   # mean prompt embedding, used as the "action"
            slots,aux=s.phys(per,act)
            # Only the patch tokens of the first frame enter the sequence.
            parts += [s.mode.weight[0].expand(B,1,-1), vis[:,0],
                      s.mode.weight[2].expand(B,1,-1), slots]
        if mel is not None:
            parts += [s.mode.weight[1].expand(B,1,-1), s.audio(mel)]
        parts.append(s.tok_emb(ids))
        x=torch.cat(parts,1); n_pref=x.shape[1]-ids.shape[1]
        return x,n_pref,aux
    def forward(s,ids,targets=None,frames=None,mel=None):
        """Return ``(logits, loss, aux)``; ``loss`` is None when ``targets`` is None."""
        x,n_pref,aux=s.fuse(ids,frames,mel)
        T=x.shape[1]
        x=s.drop(x+s.pos_emb(torch.arange(T,device=x.device)))
        for b in s.blocks: x=b(x)
        # Output projection is tied to the token embedding; prefix positions are dropped.
        logits=F.linear(s.norm(x),s.tok_emb.weight)[:,n_pref:]
        loss=None
        if targets is not None:
            lm=F.cross_entropy(logits.reshape(-1,s.cfg.vocab_size),targets.reshape(-1),ignore_index=s.cfg.pad_id)
            loss=lm+s.cfg.aux_phys_w*aux
        return logits,loss,aux
# Side length (pixels) of rendered frames and number of frames per rendered clip.
IMG, T = 96, 8
def render(x, y, img=96, r=4):
    """Return a (3, img, img) frame with a 2r x 2r square drawn in channel 0.

    The centre (x, y) is truncated to int and clamped so the square stays
    inside the frame.
    """
    frame = torch.zeros(3, img, img)
    upper = img - r - 1
    cx = int(max(r, min(upper, x)))
    cy = int(max(r, min(upper, y)))
    rows = slice(cy - r, cy + r)
    cols = slice(cx - r, cx + r)
    frame[0, rows, cols] = 1.0
    return frame
def render_world(kind, seed):
    """Render T frames of a seeded synthetic scene and return them stacked.

    ``kind`` selects the dynamics: "ball" (projectile under gravity),
    "spring" (horizontal sinusoid), "bounce" (reverses at the right wall);
    any other value renders the two-body scene (second body in channel 1).
    """
    rng = random.Random(seed)
    frames = []
    if kind == "ball":
        x, y = 10., 10.
        vx = rng.uniform(5, 9)
        vy = rng.uniform(-2, 0)
        g = 1.1
        for _ in range(T):
            frames.append(render(x, y, IMG))
            x += vx
            vy += g
            y += vy
    elif kind == "spring":
        c = IMG // 2
        A = rng.uniform(15, 30)
        w = rng.uniform(.5, 1.0)
        frames.extend(render(c + A * math.sin(w * t), c, IMG) for t in range(T))
    elif kind == "bounce":
        x, y = 12., IMG // 2
        vx = rng.uniform(7, 11)
        for _ in range(T):
            frames.append(render(x, y, IMG))
            x += vx
            if x > IMG - 12:
                vx = -vx
    else:  # twobody
        x1, x2, y = 15., float(IMG - 15), IMG // 2
        v = rng.uniform(4, 7)
        for _ in range(T):
            frame = render(x1, y, IMG)
            frame[1] = render(x2, y, IMG)[0]
            frames.append(frame)
            # The bodies stop approaching once they are within 10 px.
            if abs(x2 - x1) > 10:
                x1 += v
                x2 -= v
    return torch.stack(frames)
# 7-row x 5-column bitmap glyphs for the digits 0-9 (row-major, 1 = lit pixel),
# stored as float tensors of shape (7, 5); render_ocr upscales and pastes them.
_DIG = {d: torch.tensor(b).reshape(7, 5).float() for d, b in {
    "0": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
    "1": [0,0,1,0,0,0,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,0],
    "2": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1],
    "3": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
    "4": [1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1],
    "5": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
    "6": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
    "7": [1,1,1,1,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0],
    "8": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
    "9": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
}.items()}
def render_ocr(num, img=96):
    """Draw the digit string ``num`` with the bitmap font and repeat the frame T times.

    Glyphs are upscaled 4x, written to all three channels, and laid out left
    to right starting at x=6 with a 4-pixel gap.
    """
    canvas = torch.zeros(3, img, img)
    left, top, scale = 6, img // 2 - 14, 4
    glyph_h, glyph_w = 7 * scale, 5 * scale
    for pos, ch in enumerate(num):
        glyph = F.interpolate(_DIG[ch][None, None], scale_factor=scale).squeeze()
        x = left + pos * (glyph_w + 4)
        canvas[:, top:top + glyph_h, x:x + glyph_w] = glyph
    return canvas.unsqueeze(0).repeat(T, 1, 1, 1)
def render_robot(seed, img=96):
    """Render one square at a seeded random position and repeat the frame T times."""
    rng = random.Random(seed)
    # x is drawn before y; both stay at least 20 px from the border.
    px = rng.uniform(20, img - 20)
    py = rng.uniform(20, img - 20)
    still = render(px, py, img)
    return still.unsqueeze(0).repeat(T, 1, 1, 1)
def _trailing_digits(label):
    """Return the last run of digits in ``label`` (``"0"`` if it has none).

    Non-digit characters after the last digit run are ignored.
    """
    collected = []
    for ch in label[::-1]:
        if ch.isdigit():
            collected.append(ch)
        elif collected:
            # First non-digit to the left of the run ends the scan.
            break
    return "".join(reversed(collected)) or "0"
def render_from_spec(spec):
    """Render the frame stack described by ``spec`` (dispatch on ``spec["kind"]``).

    Unknown kinds yield an all-zero (T, 3, IMG, IMG) tensor.
    """
    kind = spec["kind"]
    if kind == "ocr":
        return render_ocr(_trailing_digits(spec.get("label", "0")))
    if kind == "robot":
        return render_robot(spec["seed"])
    if kind in ("ball", "spring", "bounce", "twobody"):
        return render_world(kind, spec["seed"])
    return torch.zeros(T, 3, IMG, IMG)
| # --- audio: synthetic mel (stable) + real ESC-50 mel + runtime pool ----------- | |
| import hashlib | |
def _stable_hash(s):
    """Return the MD5 digest of ``s`` as an integer (identical across runs and processes)."""
    digest = hashlib.md5(s.encode()).digest()
    return int.from_bytes(digest, "big")
def synth_audio(sound, n_mels=64, n_t=64):
    """Deterministic pseudo-mel per sound class (stable across runs).

    An 8-bin band whose position depends on a stable hash of ``sound`` carries
    a linear ramp over time; low-amplitude Gaussian noise is added on top.
    The noise is drawn from a private generator seeded by the same hash, so
    the same ``sound`` always produces the same tensor and the global torch
    RNG is left untouched.

    Returns a float tensor of shape (1, n_mels, n_t).
    """
    digest = _stable_hash(sound)
    base = digest % 32
    m = torch.zeros(1, n_mels, n_t)
    m[0, base:base + 8] = torch.linspace(0.2, 1.0, n_t)
    # manual_seed needs a value that fits in a signed 64-bit integer.
    noise_rng = torch.Generator().manual_seed(digest % (2 ** 63))
    m += torch.randn(1, n_mels, n_t, generator=noise_rng) * 0.05
    return m
def wav_to_mel(wav, sr=16000, n_mels=64, n_t=64):
    """Lightweight log-magnitude spectrogram resized to (1, n_mels, n_t).
    Dependency-free (torch.stft); the encoder adapts during finetuning.

    Multi-dimensional input is averaged over its first axis, and input
    shorter than 512 samples is zero-padded. The result is standardised
    (zero mean, unit std). ``sr`` is accepted for interface compatibility
    but is not used: the STFT parameters are fixed.
    """
    if wav.dim() > 1:
        wav = wav.mean(0)
    if wav.numel() < 512:
        wav = F.pad(wav, (0, 512 - wav.numel()))
    n_fft = 400
    hop = 160
    # Build the window on the waveform's device; torch.stft requires both on the same one.
    window = torch.hann_window(n_fft, device=wav.device)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop,
                      window=window, return_complex=True).abs()
    spec = torch.log1p(spec).unsqueeze(0).unsqueeze(0)
    mel = F.interpolate(spec, size=(n_mels, n_t), mode="bilinear", align_corners=False)[0]
    return (mel - mel.mean()) / (mel.std() + 1e-5)
def load_esc50(n=600):
    """Real environmental audio (ESC-50) -> (mel, category). Synthetic fallback.

    Streams up to ``n`` clips from the ``ashraq/esc50`` dataset and returns a
    list of ``(mel, category)`` pairs. Any failure (missing ``datasets``
    package, no network, unexpected schema) is reported on stdout and an empty
    list is returned, which callers treat as "synthetic audio only".
    """
    try:
        from datasets import load_dataset
        ds = load_dataset("ashraq/esc50", split="train", streaming=True)
        pool = []
        for i, ex in enumerate(ds):
            if i >= n:
                break
            a = ex["audio"]
            wav = torch.tensor(a["array"]).float()
            # Prefer the category name; fall back to the numeric target as text.
            cat = ex.get("category") or str(ex.get("target", "sound"))
            pool.append((wav_to_mel(wav, a.get("sampling_rate", 16000)), str(cat)))
        if pool:
            print(f"[esc50] {len(pool)} real clips loaded", flush=True)
            return pool
    except Exception as e:
        # Deliberate best-effort: the benchmark still runs without real audio.
        print(f"[esc50] unreachable ({type(e).__name__}) — synthetic-only audio", flush=True)
    return []
def render_av(spec):
    """Return ``(frames, mel)`` for a sample spec.

    The mel is synthesised from ``spec["sound"]`` when that key is present;
    otherwise it is an all-zero (1, 64, 64) tensor.
    """
    frames = render_from_spec(spec)
    if "sound" in spec:
        mel = synth_audio(spec["sound"])
    else:
        mel = torch.zeros(1, 64, 64)
    return frames, mel
def esc_sample(pool):
    """Pick a random ``(mel, category)`` pair and return ``(prompt, answer, mel)`` for it."""
    mel, category = random.choice(pool)
    prompt = "[AUD] what is this sound? [CHAT]"
    answer = f"step 1: classify the sound. Answer: {category}"
    return prompt, answer, mel
| # ============================================================================= | |
class Tok:
    """Thin wrapper around the SentencePiece processor used by the benchmark."""

    def __init__(self, token=None):
        # Attribute is assigned before Load(), as in the original construction order.
        self.sp = spm.SentencePieceProcessor()
        model_path = resolve_tokenizer(token)
        self.sp.Load(model_path)
        self.vocab = self.sp.GetPieceSize()

    def enc(self, t):
        """Encode text ``t`` to a list of token ids."""
        return self.sp.EncodeAsIds(t)

    def dec(self, ids):
        """Decode a list of token ids back to text."""
        return self.sp.DecodeIds(ids)
# End-of-sequence and padding token ids (same values as Config.eos_id / Config.pad_id),
# and the fixed text length (in tokens) that make_batch pads or truncates to.
EOS = 2; PAD = 0; L = 80
def make_batch(tok, rows, idx, esc_pool=None, p_esc=0.22):
    """Assemble one batch of (input ids, targets, frames, mels) on ``device``.

    For each index in ``idx`` the sample is either ``rows[j]`` rendered from
    its spec or, with probability ``p_esc`` when ``esc_pool`` is non-empty, a
    real-audio sample with blank frames. Text is padded/truncated to L tokens
    and prompt positions in the targets are replaced by PAD.
    """
    inputs, targets, frame_list, mel_list = [], [], [], []
    for j in idx:
        use_esc = bool(esc_pool) and random.random() < p_esc
        if use_esc:
            prompt, answer, mel = esc_sample(esc_pool)
            frames = torch.zeros(T, 3, IMG, IMG)
        else:
            row = rows[j]
            prompt, answer = row["prompt"], row["answer"]
            frames, mel = render_av(row["spec"])
        p = tok.enc(prompt)
        a = tok.enc(" " + answer) + [EOS]
        seq = (p + a)[:L + 1]
        seq = seq + [PAD] * (L + 1 - len(seq))  # no-op when already L + 1 long
        inp = seq[:L]
        shifted = seq[1:L + 1]
        first_answer = len(p) - 1  # supervise only answer tokens
        tgt = [tok_id if k >= first_answer else PAD for k, tok_id in enumerate(shifted)]
        inputs.append(inp)
        targets.append(tgt)
        frame_list.append(frames)
        mel_list.append(mel)
    ii = torch.tensor(inputs, device=device)
    tt = torch.tensor(targets, device=device)
    ff = torch.stack(frame_list).to(device)
    mm = torch.stack(mel_list).to(device)
    return ii, tt, ff, mm
| # ============================================================================= | |
| # Eval — per family; CHAT matches the Answer span, ACTION matches the plan | |
| # ============================================================================= | |
def _extract_json(s):
    """Parse the JSON object that starts at the first ``{`` in ``s``.

    Returns the decoded object, or ``None`` when ``s`` contains no ``{`` or
    the text starting there is not valid JSON. Text after the object is
    ignored. The JSON decoder itself finds the end of the object, so braces
    inside string values (e.g. ``{"a": "}"}``) are handled correctly, which
    plain brace counting cannot do.
    """
    start = s.find("{")
    if start < 0:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(s, start)
    except (ValueError, RecursionError):
        # JSONDecodeError is a ValueError; pathological nesting is also "no match".
        return None
    return obj
@torch.no_grad()
def generate(model, tok, prompt, frames, mel, max_new=48):
    """Greedy-decode up to ``max_new`` tokens for one sample and return the decoded text.

    The model is switched to eval mode. Decoding runs under ``torch.no_grad``
    so no autograd graph is built or retained across steps. Each step re-runs
    the full model on the prompt plus everything generated so far, takes the
    arg-max of the last position, and stops at EOS (which is not included in
    the output).
    """
    model.eval()
    ids = torch.tensor([tok.enc(prompt)], device=device)
    fb = frames.unsqueeze(0).to(device)  # add the batch axis
    mb = mel.unsqueeze(0).to(device)
    out = []
    for _ in range(max_new):
        logits, _, _ = model(ids, None, fb, mb)
        nxt = int(logits[0, -1].argmax())
        if nxt == EOS:
            break
        out.append(nxt)
        ids = torch.cat([ids, torch.tensor([[nxt]], device=device)], 1)
    return tok.dec(out)
def _chat_match(pred, gold):
    """Return True when the gold answer span occurs in the predicted answer span.

    The span is the text after the last ``"Answer:"`` (the whole string when
    the marker is absent), stripped of surrounding whitespace.
    """
    gold_span = gold.rpartition("Answer:")[2].strip()
    pred_span = pred.rpartition("Answer:")[2].strip()
    # NOTE(review): this is substring containment, not equality, so a gold of
    # "right" also matches a prediction of "right then left".
    return gold_span in pred_span
def _action_match(pred, gold):
    """Return True when ``pred`` contains a JSON object equal to the one in ``gold``."""
    gold_obj = _extract_json(gold)
    pred_obj = _extract_json(pred)
    if pred_obj is None:
        # No parsable object in the prediction never counts as a match.
        return False
    return pred_obj == gold_obj
| # ============================================================================= | |
| # Eval generators — INLINED (no external file / no --data needed) | |
| # ============================================================================= | |
def _aj(o):
    """Serialise ``o`` to compact JSON (no spaces after ``,`` or ``:``)."""
    compact = (",", ":")
    return json.dumps(o, separators=compact)
def _fr():
    """Return True with probability 0.30 (one draw from the global ``random`` stream)."""
    draw = random.random()
    return draw < 0.30
| # --------------------------------------------------------------------------- | |
| # Deterministic physics simulation β returns ground-truth facts (no torch here; | |
| # the finetune renders pixels, the generator only needs the trajectory facts). | |
| # Mirrors gen_world() in samg_mm_train.py kind-for-kind so frames match. | |
| # --------------------------------------------------------------------------- | |
def simulate_facts(kind, seed):
    """Replay a scene's trajectory from ``seed`` and return its ground-truth facts.

    A private ``random.Random(seed)`` is used, so the global random stream is
    not touched. Values are drawn in the same order as in ``render_world`` so
    the facts describe the frames rendered from the same (kind, seed). Any
    ``kind`` other than "ball", "spring" or "bounce" is treated as twobody.
    """
    rng = random.Random(seed)
    right_wall = IMG - 12
    if kind == "ball":
        x = 10.
        vx = rng.uniform(5, 9)
        rng.uniform(-2, 0)  # vertical velocity draw; not needed for the facts
        for _ in range(T):
            x += vx
        return {"dynamic": "gravity", "direction": "right",
                "reaches_right": x > right_wall, "dx": vx}
    if kind == "spring":
        amplitude = rng.uniform(15, 30)
        rng.uniform(.5, 1.0)  # angular frequency draw; not needed for the facts
        return {"dynamic": "oscillation", "direction": "oscillating",
                "amplitude": amplitude, "reaches_right": False}
    if kind == "bounce":
        x = 12.
        vx = rng.uniform(7, 11)
        bounced = False
        for _ in range(T):
            x += vx
            if x > right_wall:
                vx = -vx
                bounced = True
        return {"dynamic": "collision", "direction": "right then left",
                "bounces": bounced, "reaches_right": True}
    # twobody
    rng.uniform(4, 7)  # approach speed draw; not needed for the facts
    return {"dynamic": "collision", "direction": "converging",
            "collides": True, "reaches_right": False}
# Scene kinds the physics/vision generators sample from (the kinds render_world implements).
PHYS_KINDS = ["ball", "spring", "bounce", "twobody"]
# Text prefixes placed before the digits in OCR labels; only the trailing digits are rendered.
OCR_PREFIX = ["speed=", "temp=", "qos=", "zone ", "dock ", "id="]
def v_motion():
    """Generate one "which way does the object move?" sample (vision, CHAT format).

    The prompt is French with probability 0.30. The gold answer is the
    ``direction`` fact of the simulated scene, preceded by a short trace.
    """
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    f = simulate_facts(kind, seed)
    fr = _fr()
    # French text restored to proper UTF-8 (it was mis-decoded in the file).
    q = ("dans quel sens se déplace l'objet ?" if fr
         else "which way does the object move?")
    prompt = f"[VIS] {q} [CHAT]"
    d = f["direction"]
    trace = (f"step 1: track the bright object across frames. "
             f"step 2: its horizontal position evolves -> {d}. Answer: {d}")
    return dict(family="v_motion", fmt="CHAT", use_v=True, use_a=False, use_p=False,
                spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace)
def v_ocr():
    """Generate one "what number is shown?" sample (vision OCR, CHAT format).

    A random prefix plus 2-3 random digits forms the label; the gold answer
    is the digit string. The prompt is French with probability 0.30.
    """
    pre = random.choice(OCR_PREFIX)
    num = "".join(random.choice("0123456789") for _ in range(random.randint(2, 3)))
    label = pre + num
    fr = _fr()
    # French text restored to proper UTF-8 (it was mis-decoded in the file).
    q = "quel nombre est affiché ?" if fr else "what number is shown?"
    prompt = f"[VIS] [OCR] {q} [CHAT]"
    trace = (f"step 1: read the bitmap label. step 2: digits = {num}. Answer: {num}")
    return dict(family="v_ocr", fmt="CHAT", use_v=True, use_a=False, use_p=False,
                spec={"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label},
                prompt=prompt, answer=trace)
| # --------------------------------------------------------------------------- | |
| # PHYSICS — [CHAT] | |
| # --------------------------------------------------------------------------- | |
def p_identify():
    """Generate one "what dynamic governs this motion?" sample (physics, CHAT format).

    The gold answer is the ``dynamic`` fact of the simulated scene (gravity,
    oscillation or collision). The prompt is French with probability 0.30.
    """
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    f = simulate_facts(kind, seed)
    fr = _fr()
    # French text restored to proper UTF-8 (it was mis-decoded in the file).
    q = ("quelle dynamique régit ce mouvement ?" if fr
         else "what dynamic governs this motion?")
    prompt = f"[VIS] [PHYS] {q} [CHAT]"
    dyn = f["dynamic"]
    cue = {"gravity": "constant downward acceleration",
           "oscillation": "periodic back-and-forth around a center",
           "collision": "abrupt velocity reversal on contact"}[dyn]
    trace = f"step 1: observe {cue}. step 2: that is {dyn}. Answer: {dyn}"
    return dict(family="p_identify", fmt="CHAT", use_v=True, use_a=False, use_p=True,
                spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace)
def p_predict():
    """Generate one yes/no physics prediction sample (CHAT format).

    The question depends on the scene kind: reaching the right edge
    (ball, bounce), colliding (twobody), or leaving the centre for good
    (anything else, i.e. spring). The prompt is French with probability 0.30.
    """
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    facts = simulate_facts(kind, seed)
    french = _fr()
    if kind == "twobody":
        ans = "yes"
        q = ("les deux corps vont-ils entrer en collision ?" if french
             else "will the two bodies collide?")
        reason = "they approach from both sides and meet in the middle"
    elif kind in ("ball", "bounce"):
        ans = "yes" if facts.get("reaches_right") else "no"
        q = ("l'objet atteint-il le bord droit ?" if french
             else "does the object reach the right edge?")
        if ans == "yes":
            reason = "its rightward velocity carries it to the wall"
        else:
            reason = "it falls or stops before the wall"
    else:
        ans = "no"
        q = ("l'objet quitte-t-il le centre durablement ?" if french
             else "does the object leave the center permanently?")
        reason = "it oscillates and returns to the center each period"
    prompt = f"[VIS] [PHYS] {q} [CHAT]"
    trace = f"step 1: {reason}. Answer: {ans}"
    return dict(family="p_predict", fmt="CHAT", use_v=True, use_a=False, use_p=True,
                spec={"kind": kind, "seed": seed}, prompt=prompt, answer=trace)
| # --------------------------------------------------------------------------- | |
| # CROSS-MODAL — [ACTION] ({domain,op,params} schema, as MM base pretraining) | |
| # --------------------------------------------------------------------------- | |
def x_robot():
    """Generate one "push toward the <target>" sample (vision, ACTION format).

    Speed and angle of the gold move action come from a private RNG seeded
    with the sample seed; the target word comes from the global stream. The
    prompt is French with probability 0.30.
    """
    seed = random.randint(0, 2**31 - 1)
    rng = random.Random(seed)
    french = _fr()
    target = random.choice(["dock", "block", "marker", "exit"])
    speed = round(rng.uniform(0.2, 0.9), 2)
    angle = rng.randint(0, 359)
    if french:
        q = f"pousse vers le {target}"
    else:
        q = f"push toward the {target}"
    prompt = f"[VIS] {q} [ACTION]"
    params = {"speed": speed, "angle": angle, "duration_s": 1}
    action = {"domain": "ros", "op": "move", "params": params}
    return dict(family="x_robot", fmt="ACTION", use_v=True, use_a=False, use_p=False,
                spec={"kind": "robot", "seed": seed}, prompt=prompt, answer=_aj(action))
def x_sensor():
    """Generate one "if speed exceeds 50, slow down" sample (vision OCR, ACTION format).

    A two-digit speed is rendered as the label. The gold action is
    ``set_speed`` to the limit when the value exceeds it, otherwise
    ``continue``. The prompt is French with probability 0.30.
    """
    pre = "speed="
    val = random.randint(10, 99)
    label = pre + str(val)
    limit = 50
    fr = _fr()
    # French text restored to proper UTF-8 (it was mis-decoded in the file).
    q = (f"si la vitesse dépasse {limit}, ralentis" if fr
         else f"if speed exceeds {limit}, slow down")
    prompt = f"[VIS] [OCR] {q} [ACTION]"
    if val > limit:
        action = {"domain": "ros", "op": "set_speed", "params": {"value": limit}}
    else:
        action = {"domain": "ros", "op": "continue", "params": {}}
    return dict(family="x_sensor", fmt="ACTION", use_v=True, use_a=False, use_p=False,
                spec={"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label},
                prompt=prompt, answer=_aj(action))
# Sound label -> physical cause; a_identify uses the cause as the gold answer.
SOUND_CAUSE = {
    "sharp impact": "collision",
    "double impact": "collision",
    "rhythmic creak": "oscillation",
    "whoosh then thud": "falling object",
    "servo whir": "motor",
}
# Scene kind -> the sound that fits that motion (used by a_match).
KIND_SOUND = {"ball": "whoosh then thud", "spring": "rhythmic creak",
              "bounce": "sharp impact", "twobody": "double impact"}
# Sound label -> (gold action object, short description of the sound).
SOUND_ACTION = {
    "alarm": ({"domain": "ros", "op": "stop", "params": {}}, "an alarm"),
    "servo whir": ({"domain": "ros", "op": "continue", "params": {}}, "a servo whir"),
    "sharp impact": ({"domain": "ros", "op": "halt", "params": {"reason": "collision"}}, "an impact"),
    "rhythmic creak": ({"domain": "ros", "op": "slow", "params": {"value": 20}}, "a creak"),
}
def a_identify():
    """Build one ``a_identify`` sample: name the cause of a sound cue.

    A cue is sampled from ``SOUND_CAUSE``; the answer is a two-step trace that
    ends with ``Answer: <cause>``. The French question is used when ``_fr()``
    is truthy (the trace itself is always in English).

    Returns:
        dict with keys ``family``, ``fmt``, ``use_v``, ``use_a``, ``use_p``,
        ``spec``, ``prompt`` and ``answer``.
    """
    sound = random.choice(list(SOUND_CAUSE))
    cause = SOUND_CAUSE[sound]
    french = _fr()

    if french:
        question = "qu'est-ce qui a produit ce son ?"
    else:
        question = "what produced this sound?"

    # Acoustic description quoted in step 1 of the trace, keyed by cause.
    heard_as = {
        "collision": "a sharp broadband transient",
        "oscillation": "a periodic rhythmic tone",
        "falling object": "a rising sweep followed by a thud",
        "motor": "a steady mechanical hum",
    }
    heard = heard_as[cause]

    reasoning = f"step 1: hear {heard}. step 2: that indicates {cause}. Answer: {cause}"
    return {
        "family": "a_identify",
        "fmt": "CHAT",
        "use_v": False,
        "use_a": True,
        "use_p": False,
        "spec": {"kind": "audio", "sound": sound},
        "prompt": f"[AUD] {question} [CHAT]",
        "answer": reasoning,
    }
def a_match():
    """Build one ``a_match`` sample: does the sound cue fit the physics scene?

    A scene kind is sampled from ``PHYS_KINDS`` and paired, with probability
    0.5, with its own cue from ``KIND_SOUND`` (answer ``yes``); otherwise with
    a different cue from ``KIND_SOUND`` (answer ``no``). The French question is
    used when ``_fr()`` is truthy (the trace itself is always in English).

    Returns:
        dict with keys ``family``, ``fmt``, ``use_v``, ``use_a``, ``use_p``,
        ``spec``, ``prompt`` and ``answer``.
    """
    # Global RNG draw order: kind -> seed -> _fr() -> coin flip -> (distractor).
    kind = random.choice(PHYS_KINDS)
    scene_seed = random.randint(0, 2**31 - 1)
    expected = KIND_SOUND[kind]
    french = _fr()

    is_match = random.random() < 0.5
    if is_match:
        sound = expected
        verdict = "yes"
        why = "the sound fits the motion"
    else:
        distractors = [cue for cue in KIND_SOUND.values() if cue != expected]
        sound = random.choice(distractors)
        verdict = "no"
        why = "the sound does not fit the motion"

    if french:
        question = "le son correspond-il au mouvement ?"
    else:
        question = "does the sound match the motion?"

    return {
        "family": "a_match",
        "fmt": "CHAT",
        "use_v": True,
        "use_a": True,
        "use_p": True,
        "spec": {"kind": kind, "seed": scene_seed, "sound": sound},
        "prompt": f"[VIS] [AUD] [PHYS] {question} [CHAT]",
        "answer": f"step 1: {why}. Answer: {verdict}",
    }
def a_event():
    """Build one ``a_event`` sample: a "if you hear X, do Y" rule answered by an action.

    A cue is sampled from ``SOUND_ACTION``, which supplies both the gold action
    and the English phrase naming the cue. The instruction verb is chosen from
    the action's ``op``. When ``_fr()`` is truthy the sentence frame and the
    verb are French; the cue phrase stays as stored in ``SOUND_ACTION``.

    Returns:
        dict with keys ``family``, ``fmt``, ``use_v``, ``use_a``, ``use_p``,
        ``spec``, ``prompt`` and ``answer`` (the action passed through ``_aj``).
    """
    sound = random.choice(list(SOUND_ACTION))
    command, cue_phrase = SOUND_ACTION[sound]
    french = _fr()

    # op -> (French wording, English wording).
    # Fixed: "arrête" was stored as mojibake ("arrΓͺte"), i.e. UTF-8 bytes
    # rendered through a Greek single-byte code page.
    wording = {
        "stop": ("arrête le robot", "stop the robot"),
        "continue": ("continue", "keep going"),
        "halt": ("stoppe net", "halt immediately"),
        "slow": ("ralentis", "slow down"),
    }
    fr_verb, en_verb = wording[command["op"]]

    if french:
        question = f"si tu entends {cue_phrase}, {fr_verb}"
    else:
        question = f"if you hear {cue_phrase}, {en_verb}"

    return {
        "family": "a_event",
        "fmt": "ACTION",
        "use_v": False,
        "use_a": True,
        "use_p": False,
        "spec": {"kind": "audio", "sound": sound},
        "prompt": f"[AUD] {question} [ACTION]",
        "answer": _aj(command),
    }
# builder: produce a held-out eval set in-memory (disjoint seed, no files)
def build_eval(n=1800, seed=99991, families="auto"):
    """Generate ``n`` evaluation samples deterministically from ``seed``.

    Args:
        n: number of samples to generate.
        seed: value used to seed the module-level ``random`` state while the
            samples are drawn.
        families: ``"reasoning"`` restricts sampling to the 6
            visual/physics/cross-modal generators; any other value
            (e.g. ``"auto"`` or ``"audio"``) samples from all 9.

    Returns:
        A list of sample dicts as produced by the generators, each with an
        added ``"text"`` key equal to ``prompt + " " + answer``.

    The caller's global ``random`` state is saved before seeding and restored
    afterwards, including when a generator raises.
    """
    if families == "reasoning":
        gens = [v_motion, v_ocr, p_identify, p_predict, x_robot, x_sensor]
    else:
        gens = [v_motion, v_ocr, p_identify, p_predict, x_robot, x_sensor,
                a_identify, a_match, a_event]

    saved_state = random.getstate()
    random.seed(seed)
    rows = []
    try:
        for _ in range(n):
            sample = random.choice(gens)()
            sample["text"] = sample["prompt"] + " " + sample["answer"]
            rows.append(sample)
    finally:
        # Always hand the global RNG back untouched; previously an exception in
        # a generator left the process seeded with the eval seed.
        random.setstate(saved_state)
    return rows
| # ============================================================================= | |
| # Benchmark | |
| # ============================================================================= | |
def load_ckpt(model, path):
    """Load weights from ``path`` into ``model`` and return checkpoint metadata.

    The file may hold either a dict with a ``"model"`` entry (the state dict)
    or the state dict itself. Loading is strict, so any missing or unexpected
    key makes ``load_state_dict`` raise.

    Returns:
        ``(step, best)`` read from the checkpoint dict; ``"?"`` and ``None``
        respectively when those keys are absent.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from sources you trust.
    checkpoint = torch.load(path, map_location=device)

    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict, strict=True)

    step = checkpoint.get("step", "?")
    best = checkpoint.get("best", None)
    return step, best
| # ============================================================================= | |
# SAM-MM — HuggingFace Space (self-contained; weights pulled from HF)
| # Architecture inlined above. Set HF_TOKEN as a Space secret for private repos. | |
| # ============================================================================= | |
| import io | |
| try: | |
| from PIL import Image | |
| except ImportError: | |
| os.system(f"{sys.executable} -m pip install -q Pillow --break-system-packages"); from PIL import Image | |
| import gradio as gr | |
# Dropdown label -> checkpoint reference handed to resolve_ckpt() by _load().
CHECKPOINTS = {
    "Reasoning β vision + physics":
        "AMFORGE/sam-mm-reasoning-checkpoints:best.pt",
    "Audio-reasoning β + sound":
        "AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt",
}

# Dropdown label -> family name (key into FAMFUNC).
SCENES = {
    "πͺ Physics β identify the dynamic": "p_identify",
    "π― Physics β predict the outcome": "p_predict",
    "β‘οΈ Vision β direction of motion": "v_motion",
    "π’ Vision β read the number (OCR)": "v_ocr",
    "π°οΈ Cross-modal β sensor β action": "x_sensor",
    "π Audio β identify the sound": "a_identify",
    "π¬ Audio β match sight + sound": "a_match",
    "β‘ Audio β sound β action": "a_event",
}

# Family name -> zero-argument generator that returns a fresh sample dict.
FAMFUNC = {
    "p_identify": p_identify,
    "p_predict": p_predict,
    "v_motion": v_motion,
    "v_ocr": v_ocr,
    "x_sensor": x_sensor,
    "a_identify": a_identify,
    "a_match": a_match,
    "a_event": a_event,
}

# Families for which run_one() warns when the checkpoint label lacks "audio".
AUDIO_FAMS = {"a_identify", "a_match", "a_event"}

# Single-slot cache used by _load(): last model, its tokenizer, and the
# checkpoint reference they were loaded from.
_STATE = dict(model=None, tok=None, ckpt=None)
def _load(ckpt):
    """Return ``(model, tok)`` for ``ckpt``, reusing the cached pair when possible.

    ``_STATE`` holds the most recently loaded model; it is reused only when it
    was loaded from the same ``ckpt`` reference. Otherwise a new tokenizer and
    model are built, the weights are loaded, the model is put in eval mode and
    the cache is overwritten.
    """
    cached_model = _STATE["model"]
    if cached_model is None or _STATE["ckpt"] != ckpt:
        hf_token = get_hf_token()
        tokenizer = Tok(token=hf_token)
        net = SAMMM(Config()).to(device)
        weights_path = resolve_ckpt(ckpt, hf_token)
        load_ckpt(net, weights_path)
        net.eval()
        _STATE.update(model=net, tok=tokenizer, ckpt=ckpt)
        return net, tokenizer
    return cached_model, _STATE["tok"]
def _montage(frames):
    """Lay the frames out left-to-right in one RGB image, upscaled 4x.

    Each ``frames[i]`` is clamped to [0, 1], permuted with ``(1, 2, 0)``,
    scaled to 0..255 bytes and pasted ``IMG + 2`` pixels to the right of the
    previous one on a dark background. The strip is enlarged with
    nearest-neighbour resampling.
    """
    # assumes frames is (n, 3, IMG, IMG) float -- TODO confirm against render_av
    gap = 2
    scale = 4
    count = frames.shape[0]
    strip_w = IMG * count + (count - 1) * gap

    canvas = Image.new("RGB", (strip_w, IMG), (17, 18, 26))
    for idx in range(count):
        pixels = frames[idx].clamp(0, 1).permute(1, 2, 0) * 255
        tile = Image.fromarray(pixels.byte().cpu().numpy())
        canvas.paste(tile, (idx * (IMG + gap), 0))

    return canvas.resize((strip_w * scale, IMG * scale), Image.NEAREST)
def _infer(ckpt_label, scene_label, max_new):
    """Generate one fresh sample for the scene, run the model, and score it.

    Args:
        ckpt_label: key into ``CHECKPOINTS``.
        scene_label: key into ``SCENES``.
        max_new: generation budget; cast to ``int`` before use.

    Returns:
        ``(sample, frames, got, gold, ok, fam)`` where ``got``/``gold`` are the
        text after the last ``"Answer:"`` marker (or the whole stripped string
        when there is no marker) and ``ok`` is the result of ``_chat_match`` on
        the full prediction and the full gold answer.
    """
    ckpt_ref = CHECKPOINTS[ckpt_label]
    fam = SCENES[scene_label]

    sample = FAMFUNC[fam]()
    frames, mel = render_av(sample["spec"])
    model, tok = _load(ckpt_ref)
    pred = generate(model, tok, sample["prompt"], frames, mel, max_new=int(max_new))

    # rpartition yields the full string as the tail when the marker is absent.
    marker = "Answer:"
    gold = sample["answer"].rpartition(marker)[2].strip()
    got = pred.rpartition(marker)[2].strip()

    ok = _chat_match(pred, sample["answer"])
    return sample, frames, got, gold, ok, fam
def run_one(ckpt_label, scene_label, max_new):
    """Run a single scene and return ``(montage image, markdown report)``.

    The report shows the verdict, the prompt, the model answer next to the
    ground truth, a note on the sound cue when the sample spec has one, and a
    warning when an audio family is run on a checkpoint whose label does not
    contain "audio".
    """
    sample, frames, got, gold, ok, fam = _infer(ckpt_label, scene_label, max_new)

    if ok:
        verdict = "β **correct**"
    else:
        verdict = "β **mismatch**"

    sections = [
        f"### {verdict}\n",
        f"**Prompt β SAM-MM**\n```\n{sample['prompt']}\n```\n",
        f"**Model answer:** `{got}` β’ **Ground truth:** `{gold}`",
    ]

    spec = sample["spec"]
    if "sound" in spec:
        sections.append(
            f"\n\n*Sound cue `{spec['sound']}` β a deterministic log-mel "
            f"(no audible file; the encoder reads the spectrogram).*"
        )

    # The warning goes last so it sits below the result.
    audio_on_other_ckpt = fam in AUDIO_FAMS and "audio" not in ckpt_label.lower()
    if audio_on_other_ckpt:
        sections.append(
            "\n\n> β οΈ This is an **audio** scene on the **reasoning** checkpoint β "
            "it never learned sound, so a correct answer here is chance. "
            "Switch the checkpoint to *Audio-reasoning* to test it for real."
        )

    report = "".join(sections)
    return _montage(frames), report
def run_batch(ckpt_label, scene_label, max_new, n=20):
    """Run ``n`` fresh samples of one scene and return a markdown accuracy report.

    Args:
        ckpt_label: key into ``CHECKPOINTS`` (forwarded to ``_infer``).
        scene_label: key into ``SCENES`` (forwarded to ``_infer``).
        max_new: generation budget (forwarded to ``_infer``).
        n: number of samples; cast to ``int``. Zero or a negative value runs
            nothing and reports ``0/0`` at 0% instead of dividing by zero.

    Returns:
        A markdown string: a header with hits, total and percentage, followed
        by one line per sample showing the model answer against the gold one.
    """
    total = max(int(n), 0)
    hits = 0
    lines = []
    for _ in range(total):
        _sample, _frames, got, gold, ok, _fam = _infer(ckpt_label, scene_label, max_new)
        hits += int(ok)
        lines.append(f"{'β ' if ok else 'β'} `{got}` vs `{gold}`")

    # Guard the empty batch: the original raised ZeroDivisionError for n == 0.
    acc = 100 * hits / total if total else 0.0
    head = f"### {hits}/{total} correct β **{acc:.0f}%** exact-match\n\n"
    return head + "\n".join(lines)
# Page-level CSS: caps the page width, centres the title, keeps the upscaled
# frame strip pixel-sharp, and hides the footer.
CSS = """
.gradio-container {max-width: 980px !important}
#title {text-align:center}
#frames img {image-rendering: pixelated; border-radius: 10px}
footer {display:none !important}
"""

# Introductory paragraph shown under the title.
_INTRO_MD = (
    "Pick a scene. SAM-MM perceives the **rendered frames** (and, for audio scenes, a "
    "**log-mel spectrogram**), then answers in `[CHAT]` text or a `[ACTION]` JSON record. "
    "Frames are synthetic β this is the model's native world. Nothing is hard-coded: each "
    "scene is freshly generated, the model decodes token-by-token, and the answer is checked "
    "against ground truth computed independently."
)

# Closing caveats shown at the bottom of the page.
_NOTES_MD = (
    "---\n**Honest notes.** Physics & motion are SAM-MM's strength (its world-model carries "
    "real dynamics). OCR generalizes to unseen numbers but isn't perfect. The cross-modal "
    "`[ACTION]` family is weaker. **Audio is the weak modality** β it was trained on synthetic "
    "pseudo-mel, so strong audio scores here partly reflect that, not true listening. "
    "Architecture internals are proprietary and not exposed."
)

_THEME = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")

with gr.Blocks(theme=_THEME, css=CSS, title="SAM-MM Β· multimodal demo") as demo:
    gr.Markdown("# π§ SAM-MM β a 58M multimodal model that *reasons*", elem_id="title")
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: controls.
        with gr.Column(scale=1):
            ckpt = gr.Dropdown(
                choices=list(CHECKPOINTS),
                value=next(iter(CHECKPOINTS)),
                label="Checkpoint",
            )
            scene = gr.Dropdown(
                choices=list(SCENES),
                value=next(iter(SCENES)),
                label="Scene",
            )
            max_new = gr.Slider(minimum=24, maximum=96, value=64, step=8,
                                label="max new tokens")
            with gr.Row():
                b1 = gr.Button("Generate & run", variant="primary")
                b2 = gr.Button("Run 20 (accuracy)")
        # Right column: what the model saw and how it answered.
        with gr.Column(scale=2):
            img = gr.Image(label="What SAM-MM sees", elem_id="frames")
            md = gr.Markdown()
            batch = gr.Markdown()

    # run_batch is wired with three inputs, so its n keeps the default of 20.
    b1.click(fn=run_one, inputs=[ckpt, scene, max_new], outputs=[img, md])
    b2.click(fn=run_batch, inputs=[ckpt, scene, max_new], outputs=[batch])

    gr.Markdown(_NOTES_MD)

if __name__ == "__main__":
    demo.launch()