sam-mm-space / app.py
ameforge's picture
Update app.py
0252509 verified
Raw
History Blame Contribute Delete
41.8 kB
"""
=============================================================================
SAM-MM Benchmark — reproducible per-family evaluation for the SAM-MM line
SparseMind / AMFORGE
=============================================================================
Checkpoint-driven and fully self-contained: the held-out eval set is GENERATED
INTERNALLY (disjoint seed 99991), so no external data file or generator script
is needed. It renders frames/mel from each sample's spec, greedy-decodes the
answer, and reports per-family exact match + a CHAT/ACTION breakdown + aggregate.
Notebook (Colab/Kaggle): edit the variables at the top of main() and run.
Terminal:
python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-reasoning-checkpoints:best.pt --families reasoning
python samg_mm_benchmark.py --ckpt AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt --families audio
python samg_mm_benchmark.py --ckpt ./best.pt --n 3000 --n-per 100
Self-contained: the SAM-MM model, the renderers, the tokenizer resolver are
inlined verbatim. External vision/audio LMs are NOT comparable on these
SAM-specific synthetic tasks (different input pipelines), so this is an honest
internal per-family report; add a baseline column only where truly comparable.
=============================================================================
"""
from __future__ import annotations
import os, sys, json, math, random, argparse
from dataclasses import dataclass, asdict
from typing import Optional
from enum import IntEnum
import torch, torch.nn as nn, torch.nn.functional as F
try: import sentencepiece as spm
except ImportError:
os.system(f"{sys.executable} -m pip install -q sentencepiece --break-system-packages"); import sentencepiece as spm
# Compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# True only when CUDA is available and reports bfloat16 support.
BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
# Hub repo and file name of the SentencePiece tokenizer model.
TOK_REPO = "AMFORGE/samg_mm_tok"
TOK_FILE = "samg_mm_tokenizer.model"
# Repo used when a checkpoint spec names only a file.
DEFAULT_CKPT_REPO = "AMFORGE/sam-mm-reasoning-checkpoints"
# Names printed in the model banner.
ORGANIZATION = "AMFORGE"
MODEL_NAME = "SAM-MM"
def _pip(p):
    """Best-effort ``pip install -q`` of package spec(s) ``p`` for this interpreter.

    ``p`` may hold several whitespace-separated specs. A failed install is not
    raised here; the caller's follow-up import is what surfaces the problem.
    Returns None.
    """
    import subprocess  # local import: only needed on this fallback path

    # An argv list (no shell) keeps an interpreter path with spaces intact and
    # stops characters such as ">" in a version spec from being shell-parsed.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *p.split(),
         "--break-system-packages"],
        check=False,
    )
def get_hf_token():
    """Return a Hugging Face token, or "" when none is found.

    Sources are tried in order and the first non-empty value wins:
    the ``HF_TOKEN`` environment variable, Kaggle secrets, Colab userdata,
    then the file ``~/.cache/huggingface/token``.
    """
    t = os.environ.get("HF_TOKEN") or ""
    if not t:
        try:
            from kaggle_secrets import UserSecretsClient
            t = UserSecretsClient().get_secret("HF_TOKEN") or ""
        except Exception:
            # Deliberate best-effort: not on Kaggle, or the secret is unset.
            pass
    if not t:
        try:
            from google.colab import userdata
            t = userdata.get("HF_TOKEN") or ""
        except Exception:
            # Deliberate best-effort: not on Colab, or the secret is unset.
            pass
    if not t:
        p = os.path.expanduser("~/.cache/huggingface/token")
        # isfile (not exists) so a directory at that path is skipped, and a
        # context manager so the handle is closed promptly.
        if os.path.isfile(p):
            with open(p, encoding="utf-8") as fh:
                t = fh.read().strip()
    return t
def resolve_tokenizer(token=None):
    """Return a path to the tokenizer model file.

    A copy in the working directory or in ``tokenizer/`` is used when present;
    otherwise the file is fetched from ``TOK_REPO`` via ``hf_hub_download``
    (installing ``huggingface_hub`` first if the import fails).
    """
    candidates = (TOK_FILE, os.path.join("tokenizer", TOK_FILE))
    local = next((path for path in candidates if os.path.isfile(path)), None)
    if local is not None:
        return local
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        _pip("huggingface_hub")
        from huggingface_hub import hf_hub_download
    return hf_hub_download(TOK_REPO, TOK_FILE, token=token)
def resolve_ckpt(spec, token=None):
    """Resolve a checkpoint spec to a local file path.

    ``spec`` is an existing local path (returned unchanged), a ``repo:file``
    pair, or a bare file name looked up in ``DEFAULT_CKPT_REPO``. Remote files
    are fetched with ``hf_hub_download``.
    """
    if os.path.isfile(spec):
        return spec
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        _pip("huggingface_hub")
        from huggingface_hub import hf_hub_download
    # A colon selects "repo:file" unless the spec starts with "/".
    names_repo = ":" in spec and not spec.startswith("/")
    repo, fn = spec.split(":", 1) if names_repo else (DEFAULT_CKPT_REPO, spec)
    return hf_hub_download(repo, fn, token=token)
# =============================================================================
# SAM-MM model — INLINED VERBATIM from samg_mm_train.py (state_dict-compatible)
# =============================================================================
class NeuronType(IntEnum):
    """Integer ids for the per-channel neuron types; STEM is 0."""

    STEM = 0
    EXCITATORY = 1
    INHIBITORY = 2
    MEMORY = 3
    RELAY = 4
    MODULATORY = 5
    PATTERN = 6
# Share of channels assigned to each neuron type when a type manager is built;
# the same weights (STEM excluded) drive re-specialisation of stem channels.
TARGET_DISTRIBUTION = {
    NeuronType.STEM: .10,
    NeuronType.EXCITATORY: .35,
    NeuronType.INHIBITORY: .10,
    NeuronType.MEMORY: .15,
    NeuronType.RELAY: .10,
    NeuronType.MODULATORY: .08,
    NeuronType.PATTERN: .12,
}
@dataclass
class Config:
    """Hyper-parameters for the SAM-MM model and its training loop."""

    # text backbone
    vocab_size: int = 32000
    dim: int = 320
    n_layers: int = 8
    n_heads: int = 8
    max_seq_len: int = 1024
    channel_top_k: int = 120
    token_top_k: int = 128
    ffn_mult: int = 4
    dropout: float = 0.1
    pad_id: int = 0
    eos_id: int = 2
    use_diversity: bool = True
    # multimodal
    v_dim: int = 320
    v_layers: int = 7
    v_patch: int = 8
    img: int = 96
    a_dim: int = 320
    a_layers: int = 6
    mel: int = 64
    phys_dim: int = 320
    phys_slots: int = 4
    # plasticity (text diversity layers)
    target_stem_ratio: float = .10
    min_stem_ratio: float = .08
    stem_plasticity: float = .012
    reversion_rate: float = .012
    min_age_before_revert: int = 8
    update_interval: int = 10
    baseline_revert_ratio: float = .5
    inhibition_strength: float = .08
    modulation_strength: float = .1
    excitation_strength: float = .3
    # train
    batch_size: int = 16
    grad_accum: int = 2
    lr: float = 5e-4
    max_steps: int = 40000
    warmup: int = 1500
    eval_every: int = 1000
    save_every: int = 1000
    patience: int = 12
    log_every: int = 50
    aux_phys_w: float = 0.5
class DynamicTypeManager(nn.Module):
    """Tracks a neuron type per channel and re-assigns types during training.

    State lives in buffers (``neuron_types``, ``activation_history``, ``age``,
    ``cycle_counter``), so it is part of the module's state_dict.
    """

    def __init__(self, dim, cfg):
        super().__init__()
        self.dim = dim
        self.cfg = cfg
        assignment = []
        for kind, share in TARGET_DISTRIBUTION.items():
            assignment.extend([kind.value] * int(dim * share))
        # int() truncation can leave a shortfall; fill it with STEM (0).
        assignment.extend([0] * (dim - len(assignment)))
        random.shuffle(assignment)
        self.register_buffer("neuron_types", torch.tensor(assignment, dtype=torch.long))
        self.register_buffer("activation_history", torch.zeros(dim))
        self.register_buffer(
            "age", torch.randint(0, cfg.min_age_before_revert, (dim,), dtype=torch.long))
        self.register_buffer("cycle_counter", torch.tensor(0, dtype=torch.long))

    def get_type_mask(self, t):
        """Return a float 0/1 mask over channels whose type equals ``t``."""
        return (self.neuron_types == t.value).float()

    @torch.no_grad()
    def step(self, x):
        """Update activity statistics; every ``update_interval`` calls, re-type channels.

        Does nothing in eval mode. ``x`` is reduced over its first two axes to
        get one mean absolute activation per channel.
        """
        if not self.training:
            return
        self.activation_history.mul_(.95).add_(x.abs().mean((0, 1)).float(), alpha=.05)
        self.age += 1
        self.cycle_counter += 1
        if self.cycle_counter.item() % self.cfg.update_interval != 0:
            return
        self._revert_inactive()
        self._specialise_stems()

    def _revert_inactive(self):
        """Turn the least active, old-enough specialised channels back into STEM."""
        specialised = (self.neuron_types != 0).nonzero().view(-1)
        if len(specialised) <= 4:
            return
        acts = self.activation_history[specialised]
        thr = torch.quantile(acts, .30).item()
        quiet = acts <= max(thr, 1e-6)
        mature = self.age[specialised] > self.cfg.min_age_before_revert
        cand = specialised[quiet & mature]
        n = min(max(1, int(self.dim * self.cfg.reversion_rate)), len(cand))
        if n > 0:
            # topk on the negated history picks the n lowest-activity candidates.
            sel = cand[(-self.activation_history[cand]).topk(n)[1]]
            self.neuron_types[sel] = 0
            self.age[sel] = 0
            self.activation_history[sel] *= .5

    def _specialise_stems(self):
        """Give the oldest STEM channels a sampled non-STEM type, keeping a STEM floor."""
        stem = (self.neuron_types == 0).nonzero().view(-1)
        floor = max(2, int(self.dim * self.cfg.min_stem_ratio))
        if len(stem) <= floor:
            return
        n = min(max(1, int(self.dim * self.cfg.stem_plasticity)), len(stem) - floor)
        sel = stem[self.age[stem].float().topk(n)[1]]
        # Loop-invariant: sampling weights and the matching list of non-STEM types.
        w = torch.tensor([TARGET_DISTRIBUTION[t] for t in NeuronType if t != NeuronType.STEM])
        probs = w / w.sum()
        kinds = list(NeuronType)[1:]
        for ni in sel:
            self.neuron_types[ni] = kinds[torch.multinomial(probs, 1).item()].value
            self.age[ni] = 0
class GentleInhibition(nn.Module):
    """Damps weak activations that the noise detector also scores low, on masked channels."""

    def __init__(self, d, c):
        super().__init__()
        self.k = c.inhibition_strength
        self.noise_detector = nn.Sequential(
            nn.Linear(d, d // 4), nn.ReLU(), nn.Linear(d // 4, d), nn.Sigmoid())
        self.threshold = nn.Parameter(torch.tensor(.15))

    def forward(self, x, m):
        """Scale ``x`` by ``1 - k`` where it is suppressed and ``m`` is 1; elsewhere unchanged."""
        weak = x.abs() < self.threshold
        flagged = self.noise_detector(x) < .3
        sup = weak & flagged
        return x * (1 - sup.float() * m.view(1, 1, -1) * self.k)
class StrongExcitation(nn.Module):
    """Adds a gated, learned update to the masked channels."""

    def __init__(self, d, c):
        super().__init__()
        self.k = c.excitation_strength
        self.integrator = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, d))
        self.importance = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())

    def forward(self, x, m):
        """Return ``x`` plus ``integrator(x) * importance(x)``, masked by ``m`` and scaled by ``k``."""
        boost = self.integrator(x) * self.importance(x)
        return x + boost * m.view(1, 1, -1) * self.k
class GentleModulation(nn.Module):
    """Adds a context signal computed from the running mean over the sequence axis."""

    def __init__(self, d, c):
        super().__init__()
        self.k = c.modulation_strength
        self.context = nn.Sequential(
            nn.Linear(d, d // 4), nn.GELU(), nn.Linear(d // 4, d), nn.Tanh())

    def forward(self, x, m):
        """Return ``x`` plus ``context(prefix mean of x)``, masked by ``m`` and scaled by ``k``."""
        _, T, _ = x.shape
        # Position t divides the cumulative sum by t + 1, giving the mean of x[:, :t + 1].
        den = torch.arange(1, T + 1, device=x.device, dtype=x.dtype).view(1, T, 1)
        running_mean = x.cumsum(1) / den
        return x + self.context(running_mean) * m.view(1, 1, -1) * self.k
class PatternDetection(nn.Module):
    """Left-padded grouped 1-D convolutions (kernels 3 and 5) over the sequence axis."""

    def __init__(self, d):
        super().__init__()
        self.conv3 = nn.Conv1d(d, d // 2, 3, groups=d // 2)
        self.conv5 = nn.Conv1d(d, d // 2, 5, groups=d // 2)
        self.combine = nn.Linear(d, d)

    def forward(self, x, m):
        """Return ``x`` plus the combined conv features, masked by ``m`` and scaled by 0.2."""
        xt = x.transpose(1, 2)
        # Padding only on the left keeps the output length equal to the input
        # length and means position t sees no later positions.
        short = self.conv3(F.pad(xt, (2, 0))).transpose(1, 2)
        wide = self.conv5(F.pad(xt, (4, 0))).transpose(1, 2)
        p = torch.cat([short, wide], -1)
        return x + self.combine(p) * m.view(1, 1, -1) * .2
class RelayNetwork(nn.Module):
    """Adds a sigmoid-gated linear transform of the input on masked channels."""

    def __init__(self, d):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
        self.transform = nn.Linear(d, d)

    def forward(self, x, m):
        """Return ``x`` plus ``transform(x) * gate(x)``, masked by ``m`` and scaled by 0.2."""
        relayed = self.transform(x) * self.gate(x)
        return x + relayed * m.view(1, 1, -1) * .2
class BalancedDiversityLayer(nn.Module):
    """Residual layer that applies one sub-module per neuron type, in a fixed order."""

    def __init__(self, d, c):
        super().__init__()
        self.type_manager = DynamicTypeManager(d, c)
        self.inhibition = GentleInhibition(d, c)
        self.excitation = StrongExcitation(d, c)
        self.modulation = GentleModulation(d, c)
        self.pattern = PatternDetection(d)
        self.relay = RelayNetwork(d)
        self.norm = nn.LayerNorm(d)
        self.output = nn.Linear(d, d)

    def forward(self, x):
        """Normalise, update type statistics, run the typed stages, add half the output back."""
        residual = x
        h = self.norm(x)
        types = self.type_manager
        types.step(h)
        stages = (
            (self.excitation, NeuronType.EXCITATORY),
            (self.pattern, NeuronType.PATTERN),
            (self.relay, NeuronType.RELAY),
            (self.modulation, NeuronType.MODULATORY),
            (self.inhibition, NeuronType.INHIBITORY),
        )
        for stage, kind in stages:
            h = stage(h, types.get_type_mask(kind))
        return residual + self.output(h) * .5
class SparseGate(nn.Module):
    """Top-k channel gate with a straight-through soft surrogate for gradients."""

    def __init__(self, d, k):
        super().__init__()
        self.k = k
        self.scorer = nn.Sequential(nn.Linear(d, d // 4), nn.SiLU(), nn.Linear(d // 4, d))
        # Zero-initialised last layer: every score starts at sigmoid(0).
        last = self.scorer[-1]
        nn.init.zeros_(last.weight)
        nn.init.zeros_(last.bias)

    def forward(self, x):
        """Multiply ``x`` by a gate that is 1 where a channel's score reaches the k-th largest."""
        scores = torch.sigmoid(self.scorer(x))
        k = min(self.k, x.shape[-1])
        kth = scores.topk(k, -1)[0][..., -1:]
        hard = (scores >= kth).float()
        soft = torch.sigmoid((scores - kth) * 10)
        # Forward value equals ``hard``; the gradient flows through ``soft``.
        return x * (hard - soft.detach() + soft)
class SparseAttn(nn.Module):
    """Causal multi-head attention that keeps only the top-k scores per query."""

    def __init__(self, d, h, tk):
        super().__init__()
        self.h = h
        self.hd = d // h
        self.tk = tk
        self.qkv = nn.Linear(d, 3 * d)
        self.out = nn.Linear(d, d)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, D) and return a tensor of the same shape."""
        B, T, D = x.shape
        packed = self.qkv(x).reshape(B, T, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = packed
        scores = (q @ k.transpose(-2, -1)) * self.hd ** -.5
        future = torch.triu(torch.ones(T, T, device=x.device), 1).bool()
        scores = scores.masked_fill(future, float("-inf"))
        keep_idx = scores.topk(min(self.tk, T), -1)[1]
        keep = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, keep_idx, True)
        probs = F.softmax(scores.masked_fill(~keep, float("-inf")), -1)
        # Replace any NaN produced by the softmax with 0.
        probs = torch.nan_to_num(probs, 0.)
        return self.out((probs @ v).transpose(1, 2).reshape(B, T, D))
class SparseFFN(nn.Module):
    """Feed-forward block: up-projection, SiLU, sparse channel gate, down-projection."""

    def __init__(self, d, m, ck):
        super().__init__()
        hidden = d * m
        self.up = nn.Linear(d, hidden)
        self.gate = SparseGate(hidden, ck)
        self.down = nn.Linear(hidden, d)

    def forward(self, x):
        """Return ``down(gate(silu(up(x))))``."""
        h = F.silu(self.up(x))
        return self.down(self.gate(h))
class Block(nn.Module):
    """Pre-norm transformer block: sparse attention, optional diversity layer, sparse FFN.

    ``dim``, ``heads``, ``tk`` and ``ck`` override the matching ``Config``
    values when truthy. ``i`` is accepted but not used by this block.
    """

    def __init__(self, c, i, dim=None, heads=None, tk=None, ck=None, div=False):
        super().__init__()
        d = dim or c.dim
        h = heads or c.n_heads
        self.n1 = nn.LayerNorm(d)
        self.attn = SparseAttn(d, h, tk or c.token_top_k)
        self.n2 = nn.LayerNorm(d)
        # The channel budget is scaled by ffn_mult to match the widened hidden size.
        self.ffn = SparseFFN(d, c.ffn_mult, (ck or c.channel_top_k) * c.ffn_mult)
        self.drop = nn.Dropout(c.dropout)
        self.div = div
        if div:
            self.diversity = BalancedDiversityLayer(d, c)

    def forward(self, x):
        """Apply attention and FFN residual branches, with the diversity layer in between."""
        x = x + self.drop(self.attn(self.n1(x)))
        if self.div:
            x = self.diversity(x)
        ffn_out = self.ffn(self.n2(x))
        return x + self.drop(ffn_out)
# =============================================================================
# Encoders + PhysicsCore
# =============================================================================
class VisionEncoder(nn.Module):
    """Patchify an RGB image, run ``c.v_layers`` blocks, project to ``c.dim``.

    Each image yields ``(c.img // c.v_patch) ** 2`` tokens.
    """

    def __init__(self, c):
        super().__init__()
        n_patches = (c.img // c.v_patch) ** 2
        self.patch = nn.Conv2d(3, c.v_dim, c.v_patch, c.v_patch)
        self.pos = nn.Parameter(torch.randn(1, n_patches, c.v_dim) * .02)
        channel_k = int(c.v_dim * .375)
        self.blocks = nn.ModuleList([
            Block(c, i, dim=c.v_dim, heads=8, tk=n_patches, ck=channel_k)
            for i in range(c.v_layers)
        ])
        self.norm = nn.LayerNorm(c.v_dim)
        self.proj = nn.Linear(c.v_dim, c.dim)

    def forward(self, img):
        """Return patch tokens of shape (batch, n_patches, c.dim)."""
        tokens = self.patch(img).flatten(2).transpose(1, 2) + self.pos
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.proj(self.norm(tokens))
class AudioEncoder(nn.Module):
    """Encode a (B, 1, mel, T) spectrogram into a token sequence of width ``c.dim``.

    Two stride-2 convolutions shrink both axes about 4x; the frequency axis is
    then averaged away, leaving roughly T/4 tokens.
    """

    def __init__(self, c):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, 3, 2, 1), nn.GELU(),
            nn.Conv2d(32, c.a_dim, 3, 2, 1), nn.GELU(),
        )
        channel_k = int(c.a_dim * .375)
        self.blocks = nn.ModuleList([
            Block(c, i, dim=c.a_dim, heads=8, tk=64, ck=channel_k)
            for i in range(c.a_layers)
        ])
        self.norm = nn.LayerNorm(c.a_dim)
        self.proj = nn.Linear(c.a_dim, c.dim)

    def forward(self, mel):
        """Return audio tokens of shape (B, ~T/4, c.dim)."""
        feats = self.stem(mel)                     # B, a_dim, mel/4, T/4
        tokens = feats.mean(2).transpose(1, 2)     # B, T/4, a_dim
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.proj(self.norm(tokens))
class PhysicsCore(nn.Module):
    """Recurrent latent-state module over per-frame visual summaries.

    A GRU cell consumes one frame embedding per step, with an action term added
    to the state. After each step the state predicts an embedding, which is
    scored against the next frame to form the auxiliary loss. The final state
    is read out into ``phys_slots`` learned slots by attention.
    """

    def __init__(self, c):
        super().__init__()
        self.slots = nn.Parameter(torch.randn(1, c.phys_slots, c.dim) * .02)
        self.read = nn.MultiheadAttention(c.dim, 4, batch_first=True)
        self.cell = nn.GRUCell(c.dim, c.phys_dim)
        self.act = nn.Linear(c.dim, c.phys_dim)
        self.pred = nn.Sequential(
            nn.Linear(c.phys_dim, c.dim), nn.GELU(), nn.Linear(c.dim, c.dim))
        self.to_seq = nn.Linear(c.phys_dim, c.dim)
        self.pd = c.phys_dim

    def forward(self, frames, action):
        """Return ``(slots, aux)``.

        frames: (B, T, dim) per-frame embeddings; action: (B, dim).
        slots:  (B, phys_slots, dim).
        aux:    MSE + (1 - mean cosine) between predictions and the following
                frames (targets detached); the float 0.0 when T <= 1.
        """
        B, T, _ = frames.shape
        state = frames.new_zeros(B, self.pd)
        push = self.act(action)
        preds = []
        for t in range(T):
            state = self.cell(frames[:, t] + 0., state) + 0.1 * push
            preds.append(self.pred(state))
        pred = torch.stack(preds, 1)              # B, T, dim (step t predicts t+1)
        aux = 0.
        if T > 1:
            tgt = frames[:, 1:].detach()
            p = pred[:, :-1]
            aux = F.mse_loss(p, tgt) + (1 - F.cosine_similarity(p, tgt, -1).mean())
        query = self.slots.expand(B, -1, -1)
        slots, _ = self.read(query, self.to_seq(state).unsqueeze(1),
                             self.to_seq(state).unsqueeze(1))
        return slots, aux
class SAMMM(nn.Module):
    """SAM-MM language model with optional vision, physics and audio prefixes.

    Prefix tokens are concatenated in front of the text embeddings; logits are
    returned for the text positions only, using the tied token embedding.
    """

    def __init__(self, c):
        super().__init__()
        self.cfg = c
        self.tok_emb = nn.Embedding(c.vocab_size, c.dim)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.dim)
        self.drop = nn.Dropout(c.dropout)
        # Even-indexed blocks carry a diversity layer when enabled in the config.
        self.blocks = nn.ModuleList([
            Block(c, i, div=(i % 2 == 0 and c.use_diversity))
            for i in range(c.n_layers)
        ])
        self.norm = nn.LayerNorm(c.dim)
        self.vision = VisionEncoder(c)
        self.audio = AudioEncoder(c)
        self.phys = PhysicsCore(c)
        self.mode = nn.Embedding(3, c.dim)  # 0=[VIS] 1=[AUD] 2=[PHYS] separators
        self.apply(self._init)
        self.n_params = sum(p.numel() for p in self.parameters())
        print(f"\n{MODEL_NAME} by {ORGANIZATION}: {self.n_params:,} params")

    @staticmethod
    def _init(mod):
        """Normal(0, 0.02) weights for Linear/Conv/Embedding; zero biases."""
        if isinstance(mod, nn.Embedding):
            nn.init.normal_(mod.weight, std=0.02)
        elif isinstance(mod, (nn.Linear, nn.Conv2d, nn.Conv1d)):
            nn.init.normal_(mod.weight, std=0.02)
            if mod.bias is not None:
                nn.init.zeros_(mod.bias)

    def fuse(self, ids, frames=None, mel=None):
        """Build the input sequence; return ``(x, n_pref, aux)``.

        With ``frames`` (B, T, C, H, W): a [VIS] separator, the first frame's
        patch tokens, a [PHYS] separator and the physics slots are prepended.
        With ``mel``: an [AUD] separator and the audio tokens are prepended.
        ``n_pref`` is the number of prefix positions before the text tokens.
        ``aux`` is the physics auxiliary loss (a zero scalar without frames).
        """
        B = ids.shape[0]
        aux = ids.new_zeros(1, dtype=torch.float32).squeeze()
        parts = []
        if frames is not None:
            B_, n_frames, C, H, W = frames.shape
            vis = self.vision(frames.reshape(B_ * n_frames, C, H, W))
            vis = vis.reshape(B_, n_frames, -1, self.cfg.dim)
            per = vis.mean(2)                      # one summary vector per frame
            act = self.tok_emb(ids).mean(1)        # mean text embedding as the action
            slots, aux = self.phys(per, act)
            parts.append(self.mode.weight[0].expand(B, 1, -1))
            parts.append(vis[:, 0])
            parts.append(self.mode.weight[2].expand(B, 1, -1))
            parts.append(slots)
        if mel is not None:
            parts.append(self.mode.weight[1].expand(B, 1, -1))
            parts.append(self.audio(mel))
        parts.append(self.tok_emb(ids))
        x = torch.cat(parts, 1)
        n_pref = x.shape[1] - ids.shape[1]
        return x, n_pref, aux

    def forward(self, ids, targets=None, frames=None, mel=None):
        """Return ``(logits, loss, aux)``; ``loss`` is None without ``targets``.

        The loss is cross-entropy (ignoring ``pad_id``) plus
        ``aux_phys_w * aux``.
        """
        x, n_pref, aux = self.fuse(ids, frames, mel)
        positions = torch.arange(x.shape[1], device=x.device)
        x = self.drop(x + self.pos_emb(positions))
        for blk in self.blocks:
            x = blk(x)
        logits = F.linear(self.norm(x), self.tok_emb.weight)[:, n_pref:]
        if targets is None:
            return logits, None, aux
        lm = F.cross_entropy(logits.reshape(-1, self.cfg.vocab_size),
                             targets.reshape(-1), ignore_index=self.cfg.pad_id)
        loss = lm + self.cfg.aux_phys_w * aux
        return logits, loss, aux
# Side length in pixels of every rendered frame.
IMG = 96
# Number of frames in every rendered clip.
T = 8
def render(x, y, img=96, r=4):
    """Return a (3, img, img) frame with a 2r x 2r square in channel 0.

    The centre (x, y) is clamped to [r, img - r - 1] on each axis and then
    truncated to an integer.
    """
    lo, hi = r, img - r - 1
    cx = int(max(lo, min(hi, x)))
    cy = int(max(lo, min(hi, y)))
    frame = torch.zeros(3, img, img)
    frame[0, cy - r:cy + r, cx - r:cx + r] = 1.0
    return frame
def render_world(kind, seed):
    """Render a T-frame clip of one physics scene; returns a (T, 3, IMG, IMG) tensor.

    ``kind`` is "ball", "spring" or "bounce"; any other value renders the
    two-body scene. All randomness comes from a private RNG seeded with ``seed``.
    """
    rng = random.Random(seed)
    clip = []
    if kind == "ball":
        x, y = 10., 10.
        vx = rng.uniform(5, 9)
        vy = rng.uniform(-2, 0)
        g = 1.1
        for _ in range(T):
            clip.append(render(x, y, IMG))
            x += vx
            vy += g
            y += vy
    elif kind == "spring":
        c = IMG // 2
        A = rng.uniform(15, 30)
        w = rng.uniform(.5, 1.0)
        clip.extend(render(c + A * math.sin(w * t), c, IMG) for t in range(T))
    elif kind == "bounce":
        x, y = 12., IMG // 2
        vx = rng.uniform(7, 11)
        for _ in range(T):
            clip.append(render(x, y, IMG))
            x += vx
            if x > IMG - 12:
                vx = -vx
    else:  # twobody
        x1, x2, y = 15., float(IMG - 15), IMG // 2
        v = rng.uniform(4, 7)
        for _ in range(T):
            frame = render(x1, y, IMG)
            # The second body is drawn into channel 1 of the same frame.
            frame[1] = render(x2, y, IMG)[0]
            clip.append(frame)
            if abs(x2 - x1) > 10:
                x1 += v
                x2 -= v
    return torch.stack(clip)
# 7-row x 5-column bitmap glyphs for the digits "0".."9", each stored as a
# float tensor of shape (7, 5); the flat lists below are in row-major order.
_DIG = {d: torch.tensor(b).reshape(7, 5).float() for d, b in {
"0": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
"1": [0,0,1,0,0,0,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,0],
"2": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1],
"3": [1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
"4": [1,0,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1],
"5": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
"6": [1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
"7": [1,1,1,1,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0],
"8": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1],
"9": [1,1,1,1,1,1,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1],
}.items()}
def render_ocr(num, img=96):
    """Draw the digit string ``num`` and repeat it as a T-frame clip.

    Each glyph from ``_DIG`` is upscaled 4x and written to all three channels,
    laid out left to right starting at x=6 with a 4-pixel gap. Returns a
    (T, 3, img, img) tensor.
    """
    canvas = torch.zeros(3, img, img)
    left, top, scale = 6, img // 2 - 14, 4
    glyph_w, glyph_h = 5 * scale, 7 * scale
    for pos, ch in enumerate(num):
        glyph = F.interpolate(_DIG[ch][None, None], scale_factor=scale).squeeze()
        x = left + pos * (glyph_w + 4)
        canvas[:, top:top + glyph_h, x:x + glyph_w] = glyph
    return canvas.unsqueeze(0).repeat(T, 1, 1, 1)
def render_robot(seed, img=96):
    """Render one square at a seed-determined position, repeated as a T-frame clip."""
    rng = random.Random(seed)
    # x is drawn before y; both lie in [20, img - 20].
    x = rng.uniform(20, img - 20)
    y = rng.uniform(20, img - 20)
    still = render(x, y, img)
    return still.unsqueeze(0).repeat(T, 1, 1, 1)
def _trailing_digits(label):
d = ""
for ch in reversed(label):
if ch.isdigit(): d = ch + d
elif d: break
return d or "0"
def render_from_spec(spec):
    """Render the clip described by ``spec``; unknown kinds give an all-zero clip.

    ``spec["kind"]`` selects the renderer. Physics and robot kinds read
    ``spec["seed"]``; "ocr" reads the digits at the end of ``spec["label"]``.
    """
    kind = spec["kind"]
    if kind == "ocr":
        return render_ocr(_trailing_digits(spec.get("label", "0")))
    if kind == "robot":
        return render_robot(spec["seed"])
    if kind in ("ball", "spring", "bounce", "twobody"):
        return render_world(kind, spec["seed"])
    return torch.zeros(T, 3, IMG, IMG)
# --- audio: synthetic mel (stable) + real ESC-50 mel + runtime pool -----------
import hashlib
def _stable_hash(s): return int(hashlib.md5(s.encode()).hexdigest(), 16)
def synth_audio(sound, n_mels=64, n_t=64, generator=None):
    """Return a pseudo-mel tensor of shape (1, n_mels, n_t) for a sound class.

    The signal is an 8-bin band whose position is a stable hash of ``sound``
    and whose level ramps from 0.2 to 1.0 over time. Gaussian noise (std 0.05)
    is added on top, so the output is only repeatable when the torch RNG state
    is controlled.

    generator: optional ``torch.Generator`` used for the noise. The default
        ``None`` draws from the global torch RNG, as before.
    """
    base = _stable_hash(sound) % 32
    m = torch.zeros(1, n_mels, n_t)
    m[0, base:base + 8] = torch.linspace(0.2, 1.0, n_t)
    m += torch.randn(1, n_mels, n_t, generator=generator) * 0.05
    return m
def wav_to_mel(wav, sr=16000, n_mels=64, n_t=64):
    """Turn a waveform into a normalised log-magnitude spectrogram of shape (1, n_mels, n_t).

    Multi-channel input is averaged over its first axis and inputs shorter than
    512 samples are zero-padded. The STFT magnitude (n_fft=400, hop=160, Hann
    window) is passed through log1p, bilinearly resized, then standardised to
    zero mean and unit variance.

    NOTE(review): ``sr`` is accepted but never read, so no resampling happens.
    """
    if wav.dim() > 1:
        wav = wav.mean(0)
    shortfall = 512 - wav.numel()
    if shortfall > 0:
        wav = F.pad(wav, (0, shortfall))
    n_fft, hop = 400, 160
    window = torch.hann_window(n_fft)
    magnitude = torch.stft(wav, n_fft=n_fft, hop_length=hop,
                           window=window, return_complex=True).abs()
    log_spec = torch.log1p(magnitude).unsqueeze(0).unsqueeze(0)
    mel = F.interpolate(log_spec, size=(n_mels, n_t), mode="bilinear", align_corners=False)[0]
    return (mel - mel.mean()) / (mel.std() + 1e-5)
def load_esc50(n=600):
    """Stream up to ``n`` ESC-50 clips and return a list of ``(mel, category)`` pairs.

    Any failure (missing ``datasets`` package, no network, unexpected schema)
    is reported on stdout and an empty list is returned so callers can fall
    back to synthetic audio.
    """
    try:
        from datasets import load_dataset
        stream = load_dataset("ashraq/esc50", split="train", streaming=True)
        pool = []
        for i, ex in enumerate(stream):
            if i >= n:
                break
            audio = ex["audio"]
            wav = torch.tensor(audio["array"]).float()
            cat = ex.get("category") or str(ex.get("target", "sound"))
            mel = wav_to_mel(wav, audio.get("sampling_rate", 16000))
            pool.append((mel, str(cat)))
        if pool:
            print(f"[esc50] {len(pool)} real clips loaded", flush=True)
            return pool
    except Exception as e:
        print(f"[esc50] unreachable ({type(e).__name__}) β€” synthetic-only audio", flush=True)
    return []
def render_av(spec):
    """Return ``(frames, mel)`` for a sample spec.

    ``mel`` is a synthetic spectrogram when the spec names a "sound",
    otherwise an all-zero (1, 64, 64) tensor.
    """
    frames = render_from_spec(spec)
    if "sound" in spec:
        mel = synth_audio(spec["sound"])
    else:
        mel = torch.zeros(1, 64, 64)
    return frames, mel
def esc_sample(pool):
    """Pick a random ``(mel, category)`` from ``pool``; return ``(prompt, answer, mel)``."""
    mel, cat = random.choice(pool)
    prompt = "[AUD] what is this sound? [CHAT]"
    answer = f"step 1: classify the sound. Answer: {cat}"
    return prompt, answer, mel
# =============================================================================
class Tok:
    """Thin wrapper around a SentencePiece model: text <-> token ids."""

    def __init__(self, token=None):
        self.sp = spm.SentencePieceProcessor()
        model_path = resolve_tokenizer(token)
        self.sp.Load(model_path)
        self.vocab = self.sp.GetPieceSize()

    def enc(self, t):
        """Encode text ``t`` to a list of token ids."""
        return self.sp.EncodeAsIds(t)

    def dec(self, ids):
        """Decode a list of token ids back to text."""
        return self.sp.DecodeIds(ids)
# End-of-sequence token id appended to every answer.
EOS = 2
# Padding token id; also the value written over unsupervised target positions.
PAD = 0
# Fixed number of text tokens per training row.
L = 80
def make_batch(tok, rows, idx, esc_pool=None, p_esc=0.22):
    """Build one batch; return ``(input_ids, targets, frames, mels)`` on ``device``.

    For each index in ``idx`` the row is rendered from its spec. When
    ``esc_pool`` is non-empty, a real-audio sample with blank frames is
    substituted with probability ``p_esc``. Sequences are truncated or padded
    to ``L`` tokens and prompt positions in the targets are set to PAD.
    """
    ids_in, tgts, frs, mls = [], [], [], []
    for j in idx:
        if esc_pool and random.random() < p_esc:
            prompt, answer, mel = esc_sample(esc_pool)
            frames = torch.zeros(T, 3, IMG, IMG)
        else:
            row = rows[j]
            prompt, answer = row["prompt"], row["answer"]
            frames, mel = render_av(row["spec"])
        p = tok.enc(prompt)
        a = tok.enc(" " + answer) + [EOS]
        full = (p + a)[:L + 1]
        # A non-positive repeat count yields [], so full-length rows are unchanged.
        full = full + [PAD] * (L + 1 - len(full))
        inp = full[:L]
        shifted = full[1:L + 1]
        cut = len(p) - 1  # supervise only answer tokens
        tgt = [tok_id if pos >= cut else PAD for pos, tok_id in enumerate(shifted)]
        ids_in.append(inp)
        tgts.append(tgt)
        frs.append(frames)
        mls.append(mel)
    ii = torch.tensor(ids_in, device=device)
    tt = torch.tensor(tgts, device=device)
    ff = torch.stack(frs).to(device)
    mm = torch.stack(mls).to(device)
    return ii, tt, ff, mm
# =============================================================================
# Eval — per family; CHAT matches the Answer span, ACTION matches the plan
# =============================================================================
def _extract_json(s):
i = s.find("{")
if i < 0: return None
depth = 0
for k in range(i, len(s)):
if s[k] == "{": depth += 1
elif s[k] == "}":
depth -= 1
if depth == 0:
try: return json.loads(s[i:k + 1])
except Exception: return None
return None
@torch.no_grad()
def generate(model, tok, prompt, frames, mel, max_new=48):
    """Greedy-decode up to ``max_new`` tokens for one sample; return the decoded text.

    Switches ``model`` to eval mode. Decoding stops early at EOS, which is not
    included in the output. The full sequence is re-run for each new token.
    """
    model.eval()
    ids = torch.tensor([tok.enc(prompt)], device=device)
    fb = frames.unsqueeze(0).to(device)
    mb = mel.unsqueeze(0).to(device)
    generated = []
    for _ in range(max_new):
        logits = model(ids, None, fb, mb)[0]
        nxt = int(logits[0, -1].argmax())
        if nxt == EOS:
            break
        generated.append(nxt)
        ids = torch.cat([ids, torch.tensor([[nxt]], device=device)], 1)
    return tok.dec(generated)
def _chat_match(pred, gold):
g = gold.split("Answer:")[-1].strip()
p = pred.split("Answer:")[-1].strip() if "Answer:" in pred else pred.strip()
return p.startswith(g) or g in p
def _action_match(pred, gold):
    """True when ``pred`` contains a JSON object equal to the one in ``gold``."""
    predicted = _extract_json(pred)
    expected = _extract_json(gold)
    if predicted is None:
        return False
    return predicted == expected
# =============================================================================
# Eval generators — INLINED (no external file / no --data needed)
# =============================================================================
def _aj(o): return json.dumps(o, separators=(",", ":"))
def _fr(): return random.random() < 0.30
# ---------------------------------------------------------------------------
# Deterministic physics simulation — returns ground-truth facts (no torch here;
# the finetune renders pixels, the generator only needs the trajectory facts).
# Mirrors gen_world() in samg_mm_train.py kind-for-kind so frames match.
# ---------------------------------------------------------------------------
def simulate_facts(kind, seed):
    """Replay a scene's trajectory and return the facts used to build its answer.

    Uses a private ``random.Random(seed)``, so the global RNG stream is not
    touched. Draws are made in the same order as ``render_world`` for the same
    ``kind``. Any ``kind`` other than "ball", "spring" or "bounce" is treated
    as the two-body scene.
    """
    rng = random.Random(seed)
    if kind == "ball":
        x = 10.
        vx = rng.uniform(5, 9)
        rng.uniform(-2, 0)  # vertical velocity draw; no returned fact depends on it
        for _ in range(T):
            x += vx
        # NOTE(review): with x0=10, T=8 and vx <= 9 the final x is at most 82,
        # below IMG - 12 = 84, so this is always False for the current constants.
        return {"dynamic": "gravity", "direction": "right",
                "reaches_right": x > IMG - 12, "dx": vx}
    if kind == "spring":
        A = rng.uniform(15, 30)
        rng.uniform(.5, 1.0)  # angular frequency draw; not reported
        return {"dynamic": "oscillation", "direction": "oscillating",
                "amplitude": A, "reaches_right": False}
    if kind == "bounce":
        x = 12.
        vx = rng.uniform(7, 11)
        bounced = False
        for _ in range(T):
            x += vx
            if x > IMG - 12:
                vx = -vx
                bounced = True
        return {"dynamic": "collision", "direction": "right then left",
                "bounces": bounced, "reaches_right": True}
    # twobody: the speed is drawn but every returned fact is constant.
    rng.uniform(4, 7)
    return {"dynamic": "collision", "direction": "converging",
            "collides": True, "reaches_right": False}
# Scene kinds the physics renderer and simulator understand.
PHYS_KINDS = [
    "ball",
    "spring",
    "bounce",
    "twobody",
]
# Label prefixes placed in front of the digits of an OCR sample.
OCR_PREFIX = [
    "speed=",
    "temp=",
    "qos=",
    "zone ",
    "dock ",
    "id=",
]
def v_motion():
    """Generate one "direction of motion" CHAT sample from a random physics scene."""
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    facts = simulate_facts(kind, seed)
    fr = _fr()
    if fr:
        q = "dans quel sens se dΓ©place l'objet ?"
    else:
        q = "which way does the object move?"
    d = facts["direction"]
    trace = (f"step 1: track the bright object across frames. "
             f"step 2: its horizontal position evolves -> {d}. Answer: {d}")
    return dict(family="v_motion", fmt="CHAT", use_v=True, use_a=False, use_p=False,
                spec={"kind": kind, "seed": seed},
                prompt=f"[VIS] {q} [CHAT]", answer=trace)
def v_ocr():
    """Generate one "read the number" CHAT sample with a 2- or 3-digit label."""
    pre = random.choice(OCR_PREFIX)
    # The global RNG is consumed in this order: prefix, digit count, digits,
    # language flag, then the spec seed.
    n_digits = random.randint(2, 3)
    num = "".join([random.choice("0123456789") for _ in range(n_digits)])
    label = pre + num
    fr = _fr()
    q = "quel nombre est affichΓ© ?" if fr else "what number is shown?"
    trace = f"step 1: read the bitmap label. step 2: digits = {num}. Answer: {num}"
    spec = {"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label}
    return dict(family="v_ocr", fmt="CHAT", use_v=True, use_a=False, use_p=False,
                spec=spec, prompt=f"[VIS] [OCR] {q} [CHAT]", answer=trace)
# ---------------------------------------------------------------------------
# PHYSICS — [CHAT]
# ---------------------------------------------------------------------------
def p_identify():
    """Generate one "which dynamic governs this motion" CHAT sample."""
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    facts = simulate_facts(kind, seed)
    fr = _fr()
    if fr:
        q = "quelle dynamique rΓ©git ce mouvement ?"
    else:
        q = "what dynamic governs this motion?"
    dyn = facts["dynamic"]
    # Visual cue quoted in the reasoning trace for each dynamic.
    cues = {
        "gravity": "constant downward acceleration",
        "oscillation": "periodic back-and-forth around a center",
        "collision": "abrupt velocity reversal on contact",
    }
    cue = cues[dyn]
    trace = f"step 1: observe {cue}. step 2: that is {dyn}. Answer: {dyn}"
    return dict(family="p_identify", fmt="CHAT", use_v=True, use_a=False, use_p=True,
                spec={"kind": kind, "seed": seed},
                prompt=f"[VIS] [PHYS] {q} [CHAT]", answer=trace)
def p_predict():
    """Generate one yes/no "predict the outcome" CHAT sample.

    The question depends on the scene kind: reaching the right edge for
    ball/bounce, collision for twobody, leaving the centre for spring.
    """
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    facts = simulate_facts(kind, seed)
    fr = _fr()
    if kind in ("ball", "bounce"):
        ans = "yes" if facts.get("reaches_right") else "no"
        q_fr = "l'objet atteint-il le bord droit ?"
        q_en = "does the object reach the right edge?"
        if ans == "yes":
            reason = "its rightward velocity carries it to the wall"
        else:
            reason = "it falls or stops before the wall"
    elif kind == "twobody":
        ans = "yes"
        q_fr = "les deux corps vont-ils entrer en collision ?"
        q_en = "will the two bodies collide?"
        reason = "they approach from both sides and meet in the middle"
    else:
        ans = "no"
        q_fr = "l'objet quitte-t-il le centre durablement ?"
        q_en = "does the object leave the center permanently?"
        reason = "it oscillates and returns to the center each period"
    q = q_fr if fr else q_en
    trace = f"step 1: {reason}. Answer: {ans}"
    return dict(family="p_predict", fmt="CHAT", use_v=True, use_a=False, use_p=True,
                spec={"kind": kind, "seed": seed},
                prompt=f"[VIS] [PHYS] {q} [CHAT]", answer=trace)
# ---------------------------------------------------------------------------
# CROSS-MODAL — [ACTION] ({domain,op,params} schema, as MM base pretraining)
# ---------------------------------------------------------------------------
def x_robot():
    """Generate one "push toward the target" ACTION sample.

    Speed and angle come from a private RNG seeded with the sample seed; the
    target word and the language flag come from the global RNG.
    """
    seed = random.randint(0, 2**31 - 1)
    rng = random.Random(seed)
    fr = _fr()
    target = random.choice(["dock", "block", "marker", "exit"])
    speed = round(rng.uniform(0.2, 0.9), 2)
    angle = rng.randint(0, 359)
    if fr:
        q = f"pousse vers le {target}"
    else:
        q = f"push toward the {target}"
    params = {"speed": speed, "angle": angle, "duration_s": 1}
    action = {"domain": "ros", "op": "move", "params": params}
    return dict(family="x_robot", fmt="ACTION", use_v=True, use_a=False, use_p=False,
                spec={"kind": "robot", "seed": seed},
                prompt=f"[VIS] {q} [ACTION]", answer=_aj(action))
def x_sensor():
    """Generate one "read the speed, act on the limit" ACTION sample.

    The displayed speed is 10..99; above the fixed limit of 50 the gold action
    is ``set_speed`` to the limit, otherwise ``continue``.
    """
    val = random.randint(10, 99)
    label = "speed=" + str(val)
    limit = 50
    fr = _fr()
    if fr:
        q = f"si la vitesse dΓ©passe {limit}, ralentis"
    else:
        q = f"if speed exceeds {limit}, slow down"
    if val > limit:
        action = {"domain": "ros", "op": "set_speed", "params": {"value": limit}}
    else:
        action = {"domain": "ros", "op": "continue", "params": {}}
    spec = {"kind": "ocr", "seed": random.randint(0, 2**31 - 1), "label": label}
    return dict(family="x_sensor", fmt="ACTION", use_v=True, use_a=False, use_p=False,
                spec=spec, prompt=f"[VIS] [OCR] {q} [ACTION]", answer=_aj(action))
# Sound name -> the physical cause used as the gold answer by a_identify.
SOUND_CAUSE = {
"sharp impact": "collision",
"double impact": "collision",
"rhythmic creak": "oscillation",
"whoosh then thud": "falling object",
"servo whir": "motor",
}
# Physics scene kind -> the sound that counts as a match for it in a_match.
KIND_SOUND = {"ball": "whoosh then thud", "spring": "rhythmic creak",
"bounce": "sharp impact", "twobody": "double impact"}
# Sound name -> (gold action dict, phrase naming the sound in the prompt), used by a_event.
SOUND_ACTION = {
"alarm": ({"domain": "ros", "op": "stop", "params": {}}, "an alarm"),
"servo whir": ({"domain": "ros", "op": "continue", "params": {}}, "a servo whir"),
"sharp impact": ({"domain": "ros", "op": "halt", "params": {"reason": "collision"}}, "an impact"),
"rhythmic creak": ({"domain": "ros", "op": "slow", "params": {"value": 20}}, "a creak"),
}
def a_identify():
    """Generate one "what produced this sound" CHAT sample (audio only)."""
    snd = random.choice(list(SOUND_CAUSE.keys()))
    cause = SOUND_CAUSE[snd]
    fr = _fr()
    q = "qu'est-ce qui a produit ce son ?" if fr else "what produced this sound?"
    # Acoustic description quoted in the reasoning trace for each cause.
    descriptions = {
        "collision": "a sharp broadband transient",
        "oscillation": "a periodic rhythmic tone",
        "falling object": "a rising sweep followed by a thud",
        "motor": "a steady mechanical hum",
    }
    desc = descriptions[cause]
    trace = f"step 1: hear {desc}. step 2: that indicates {cause}. Answer: {cause}"
    return dict(family="a_identify", fmt="CHAT", use_v=False, use_a=True, use_p=False,
                spec={"kind": "audio", "sound": snd},
                prompt=f"[AUD] {q} [CHAT]", answer=trace)
def a_match():
    """Generate one "does the sound match the motion" yes/no CHAT sample.

    Half of the samples pair the scene with its own sound; the rest use a
    sound drawn from the other scene kinds.
    """
    kind = random.choice(PHYS_KINDS)
    seed = random.randint(0, 2**31 - 1)
    true_sound = KIND_SOUND[kind]
    fr = _fr()
    matched = random.random() < 0.5
    if matched:
        snd = true_sound
        ans = "yes"
        reason = "the sound fits the motion"
    else:
        others = [s for s in KIND_SOUND.values() if s != true_sound]
        snd = random.choice(others)
        ans = "no"
        reason = "the sound does not fit the motion"
    if fr:
        q = "le son correspond-il au mouvement ?"
    else:
        q = "does the sound match the motion?"
    trace = f"step 1: {reason}. Answer: {ans}"
    return dict(family="a_match", fmt="CHAT", use_v=True, use_a=True, use_p=True,
                spec={"kind": kind, "seed": seed, "sound": snd},
                prompt=f"[VIS] [AUD] [PHYS] {q} [CHAT]", answer=trace)
def a_event():
    """Generate one "if you hear X, do Y" ACTION sample (audio only)."""
    snd = random.choice(list(SOUND_ACTION.keys()))
    action, desc = SOUND_ACTION[snd]
    fr = _fr()
    # Instruction wording per action op, as (French, English).
    wording = {
        "stop": ("arrΓͺte le robot", "stop the robot"),
        "continue": ("continue", "keep going"),
        "halt": ("stoppe net", "halt immediately"),
        "slow": ("ralentis", "slow down"),
    }
    french, english = wording[action["op"]]
    instr = french if fr else english
    if fr:
        q = f"si tu entends {desc}, {instr}"
    else:
        q = f"if you hear {desc}, {instr}"
    return dict(family="a_event", fmt="ACTION", use_v=False, use_a=True, use_p=False,
                spec={"kind": "audio", "sound": snd},
                prompt=f"[AUD] {q} [ACTION]", answer=_aj(action))
# builder: produce a held-out eval set in-memory (disjoint seed, no files)
def build_eval(n=1800, seed=99991, families="auto"):
    """Build an in-memory eval set of ``n`` samples from a fixed seed.

    families: "reasoning" uses the six visual/physics/cross-modal generators;
        any other value (e.g. "auto", "audio") also includes the three audio
        generators.

    Each row is the generator's dict plus a ``"text"`` key holding prompt and
    answer joined by a space. The global ``random`` state is saved before
    seeding and restored afterwards, including when a generator raises.
    """
    gens = [v_motion, v_ocr, p_identify, p_predict, x_robot, x_sensor]
    if families != "reasoning":
        gens = gens + [a_identify, a_match, a_event]
    saved_state = random.getstate()
    random.seed(seed)
    try:
        rows = []
        for _ in range(n):
            s = random.choice(gens)()
            s["text"] = s["prompt"] + " " + s["answer"]
            rows.append(s)
    finally:
        # Restore even on failure so the caller's RNG stream is never left
        # pinned to the eval seed.
        random.setstate(saved_state)
    return rows
# =============================================================================
# Benchmark
# =============================================================================
def load_ckpt(model, path):
    """Load weights from ``path`` into ``model`` (strict); return ``(step, best)``.

    The file holds either a dict with a "model" entry or a bare state dict.
    ``step`` defaults to "?" and ``best`` to None when absent.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from sources you trust.
    """
    ck = torch.load(path, map_location=device)
    state = ck["model"] if "model" in ck else ck
    model.load_state_dict(state, strict=True)
    step = ck.get("step", "?")
    best = ck.get("best", None)
    return step, best
# =============================================================================
# SAM-MM — HuggingFace Space (self-contained; weights pulled from HF)
# Architecture inlined above. Set HF_TOKEN as a Space secret for private repos.
# =============================================================================
import io
try:
from PIL import Image
except ImportError:
os.system(f"{sys.executable} -m pip install -q Pillow --break-system-packages"); from PIL import Image
import gradio as gr
# Dropdown label -> checkpoint spec ("repo:file") that _load() passes to resolve_ckpt.
# Order matters: the first label is the dropdown default. run_one() treats any
# label that does not contain "audio" (case-insensitive) as the reasoning-only
# checkpoint, so keep that word in audio-capable labels.
# Labels were previously stored mis-encoded (mojibake dash); restored to "—".
CHECKPOINTS = {
    "Reasoning — vision + physics": "AMFORGE/sam-mm-reasoning-checkpoints:best.pt",
    "Audio-reasoning — + sound": "AMFORGE/sam-mm-audio-reasoning-checkpoints:best.pt",
}
# Dropdown label -> task family name (a key of FAMFUNC).
# Order matters: the first label is the dropdown default.
# Labels were previously stored mis-encoded (mojibake emoji/dashes); restored.
SCENES = {
    "🪐 Physics — identify the dynamic": "p_identify",
    "🎯 Physics — predict the outcome": "p_predict",
    "➡️ Vision — direction of motion": "v_motion",
    "🔢 Vision — read the number (OCR)": "v_ocr",
    "🛰️ Cross-modal — sensor → action": "x_sensor",
    "🔊 Audio — identify the sound": "a_identify",
    "🎬 Audio — match sight + sound": "a_match",
    "⚡ Audio — sound → action": "a_event",
}
# Family name -> generator function that builds one sample of that family.
FAMFUNC = dict(
    p_identify=p_identify,
    p_predict=p_predict,
    v_motion=v_motion,
    v_ocr=v_ocr,
    x_sensor=x_sensor,
    a_identify=a_identify,
    a_match=a_match,
    a_event=a_event,
)
# Families that run_one() flags when used with a non-audio checkpoint.
AUDIO_FAMS = set(("a_identify", "a_match", "a_event"))
# Single-slot cache used by _load(): the loaded model, its tokenizer, and the
# checkpoint spec they were loaded from.
_STATE = dict(model=None, tok=None, ckpt=None)
def _load(ckpt):
    """Return ``(model, tokenizer)`` for the checkpoint spec ``ckpt``.

    The pair is cached in ``_STATE``; it is rebuilt only when nothing is
    loaded yet or when a different ``ckpt`` is requested.
    """
    needs_load = _STATE["model"] is None or not (_STATE["ckpt"] == ckpt)
    if needs_load:
        hf_token = get_hf_token()
        tokenizer = Tok(token=hf_token)
        net = SAMMM(Config()).to(device)
        ckpt_path = resolve_ckpt(ckpt, hf_token)
        load_ckpt(net, ckpt_path)
        net.eval()
        _STATE.update(model=net, tok=tokenizer, ckpt=ckpt)
    return _STATE["model"], _STATE["tok"]
def _montage(frames):
    """Lay the frames out left-to-right in one strip and return it 4x upscaled.

    Each ``frames[i]`` is clamped to [0, 1], permuted with (1, 2, 0), scaled to
    bytes and pasted with a 2-pixel gap on a dark background.
    Presumably ``frames`` is (N, C, IMG, IMG) — TODO confirm against render_av.
    """
    gap = 2
    count = frames.shape[0]
    tiles = []
    for idx in range(count):
        hwc = frames[idx].clamp(0, 1).permute(1, 2, 0)
        tiles.append((hwc * 255).byte().cpu().numpy())
    strip_w = IMG * count + (count - 1) * gap
    canvas = Image.new("RGB", (strip_w, IMG), (17, 18, 26))
    for idx, tile in enumerate(tiles):
        canvas.paste(Image.fromarray(tile), (idx * (IMG + gap), 0))
    # Nearest-neighbour keeps the upscaled pixels crisp.
    return canvas.resize((strip_w * 4, IMG * 4), Image.NEAREST)
def _infer(ckpt_label, scene_label, max_new):
    """Generate one fresh sample for the chosen scene and run the model on it.

    Returns:
        tuple: ``(sample, frames, got, gold, ok, fam)`` where ``got``/``gold``
        are the text after the last ``"Answer:"`` marker (or the whole stripped
        string when the marker is absent) and ``ok`` is the ``_chat_match``
        verdict on the full prediction vs. the full reference answer.
    """
    ckpt_spec = CHECKPOINTS[ckpt_label]
    fam = SCENES[scene_label]
    sample = FAMFUNC[fam]()
    frames, mel = render_av(sample["spec"])
    model, tok = _load(ckpt_spec)
    pred = generate(model, tok, sample["prompt"], frames, mel, max_new=int(max_new))
    # rpartition yields the whole string in slot 2 when the marker is missing.
    gold = sample["answer"].rpartition("Answer:")[2].strip()
    got = pred.rpartition("Answer:")[2].strip()
    # NOTE(review): ACTION-format families (e.g. a_event) are scored with
    # _chat_match too — confirm it handles their JSON answers as intended.
    ok = _chat_match(pred, sample["answer"])
    return sample, frames, got, gold, ok, fam
def run_one(ckpt_label, scene_label, max_new):
    """Run a single freshly generated scene and format the result for the UI.

    Args:
        ckpt_label: key of ``CHECKPOINTS`` chosen in the dropdown.
        scene_label: key of ``SCENES`` chosen in the dropdown.
        max_new: maximum number of new tokens to decode.

    Returns:
        tuple: ``(image, markdown)`` — the frame montage from ``_montage`` and
        a markdown report with verdict, prompt, model answer and ground truth.
    """
    s, frames, got, gold, ok, fam = _infer(ckpt_label, scene_label, max_new)
    # User-facing symbols below were previously stored mis-encoded (mojibake).
    verdict = "✅ **correct**" if ok else "❌ **mismatch**"
    md = (f"### {verdict}\n"
          f"**Prompt → SAM-MM**\n```\n{s['prompt']}\n```\n"
          f"**Model answer:** `{got}` &nbsp;•&nbsp; **Ground truth:** `{gold}`")
    if "sound" in s["spec"]:
        md += (f"\n\n*Sound cue `{s['spec']['sound']}` → a deterministic log-mel "
               f"(no audible file; the encoder reads the spectrogram).*")
    # Audio scene on a checkpoint whose label lacks "audio": warn that any hit
    # is chance. The warning goes last, after the optional sound-cue note.
    if fam in AUDIO_FAMS and "audio" not in ckpt_label.lower():
        md += ("\n\n> ⚠️ This is an **audio** scene on the **reasoning** checkpoint — "
               "it never learned sound, so a correct answer here is chance. "
               "Switch the checkpoint to *Audio-reasoning* to test it for real.")
    return _montage(frames), md
def run_batch(ckpt_label, scene_label, max_new, n=20):
    """Run ``n`` freshly generated scenes and report exact-match accuracy.

    Args:
        ckpt_label: key of ``CHECKPOINTS`` chosen in the dropdown.
        scene_label: key of ``SCENES`` chosen in the dropdown.
        max_new: maximum number of new tokens to decode per scene.
        n: number of scenes to run; values below 1 yield an empty 0/0 report.

    Returns:
        str: markdown with a summary header followed by one line per scene.
    """
    # Clamp so n <= 0 cannot divide by zero or print a negative total.
    total = max(int(n), 0)
    hits = 0
    lines = []
    for _ in range(total):
        _, _, got, gold, ok, _ = _infer(ckpt_label, scene_label, max_new)
        hits += int(ok)
        lines.append(f"{'✅' if ok else '❌'} `{got}` vs `{gold}`")
    acc = 100 * hits / total if total else 0.0
    head = f"### {hits}/{total} correct &nbsp; → &nbsp; **{acc:.0f}%** exact-match\n\n"
    return head + "\n".join(lines)
# Stylesheet passed to gr.Blocks(css=...): caps the page width, centres the
# title, keeps the upscaled frames pixel-sharp, and hides the footer.
CSS = "\n".join((
    "",
    ".gradio-container {max-width: 980px !important}",
    "#title {text-align:center}",
    "#frames img {image-rendering: pixelated; border-radius: 10px}",
    "footer {display:none !important}",
    "",
))
# Gradio UI: controls on the left, rendered frames + single-run report on the
# right, batch report underneath. User-facing text was previously stored
# mis-encoded (mojibake); the intended characters are restored here.
# NOTE(review): the nesting below was reconstructed from a flattened paste —
# confirm that `batch` belongs below the row rather than in the right column.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"), css=CSS,
               title="SAM-MM · multimodal demo") as demo:
    gr.Markdown("# 🧠 SAM-MM — a 58M multimodal model that *reasons*", elem_id="title")
    gr.Markdown(
        "Pick a scene. SAM-MM perceives the **rendered frames** (and, for audio scenes, a "
        "**log-mel spectrogram**), then answers in `[CHAT]` text or a `[ACTION]` JSON record. "
        "Frames are synthetic — this is the model's native world. Nothing is hard-coded: each "
        "scene is freshly generated, the model decodes token-by-token, and the answer is checked "
        "against ground truth computed independently.")
    with gr.Row():
        with gr.Column(scale=1):
            ckpt = gr.Dropdown(list(CHECKPOINTS), value=list(CHECKPOINTS)[0], label="Checkpoint")
            scene = gr.Dropdown(list(SCENES), value=list(SCENES)[0], label="Scene")
            max_new = gr.Slider(24, 96, value=64, step=8, label="max new tokens")
            with gr.Row():
                b1 = gr.Button("Generate & run", variant="primary")
                b2 = gr.Button("Run 20 (accuracy)")
        with gr.Column(scale=2):
            img = gr.Image(label="What SAM-MM sees", elem_id="frames")
            md = gr.Markdown()
    batch = gr.Markdown()
    # run_batch's `n` is not wired to an input, so it uses its default of 20.
    b1.click(run_one, [ckpt, scene, max_new], [img, md])
    b2.click(run_batch, [ckpt, scene, max_new], [batch])
    gr.Markdown(
        "---\n**Honest notes.** Physics & motion are SAM-MM's strength (its world-model carries "
        "real dynamics). OCR generalizes to unseen numbers but isn't perfect. The cross-modal "
        "`[ACTION]` family is weaker. **Audio is the weak modality** — it was trained on synthetic "
        "pseudo-mel, so strong audio scores here partly reflect that, not true listening. "
        "Architecture internals are proprietary and not exposed.")

if __name__ == "__main__":
    demo.launch()