# Source: tinymars-proprioceptive-channels / modeling_channels.py
# Uploaded by celiumsAI with huggingface_hub (commit 3d57ce4, verified)
"""mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero)
to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring:
1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged) [correctness]
2. trainable param count (the adapter is small; base frozen)
3. gradient-flow (one backward → ReZero alphas receive grad → trainable end-to-end)
4. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property]
Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60.
"""
import sys, time, torch, torch.nn as nn, torch.nn.functional as F
from dataclasses import dataclass, field
# Seed the global RNG so the randomly initialised adapter weights are the same on every run.
torch.manual_seed(0)

# Local checkpoint directory passed to from_pretrained() in main().
MP = "/root/gemma-4-e4b-it"
# Number of intra-op CPU threads torch may use for this run.
torch.set_num_threads(30)

# (channel name, source feature width) pairs. dict(CHANNEL_ORDER) is the default
# InjectionConfig.channel_source_dims, so this order also fixes the channel order.
CHANNEL_ORDER = [
    ("memory", 1024),
    ("affect", 2),
    ("time", 16),
    ("ethics", 24),
    ("identity", 1024),
    ("continuity", 1024),
]
@dataclass
class InjectionConfig:
hidden_dim: int
channel_inner_dim: int = 192
channel_source_dims: dict = field(default_factory=lambda: dict(CHANNEL_ORDER))
class ChannelInjectionDelta(nn.Module):
    """Gated cross-attention delta from token hidden states to channel-derived key/value slots.

    Processing steps:
      1. Each named channel in ``cfg.channel_source_dims`` is linearly pooled to
         ``channel_inner_dim`` features.
      2. The pooled features are expanded into ``K_PER_CHANNEL[name]`` key/value slots.
      3. A query projected from each token of ``hidden`` attends over each channel's
         slots separately.
      4. Each channel's context is scaled by a learnable per-channel gate ``alphas``
         (zero-initialised, ReZero-style).
      5. The gated contexts are summed over channels and projected back to ``hidden_dim``.

    While every gate is zero the returned delta is exactly zero. Parameters keep their
    storage dtype and are cast to ``hidden.dtype`` on each call.
    """

    # Key/value slots per channel. Channels with fewer than max_k slots have their
    # unused slots masked to -inf, so they receive zero attention weight.
    K_PER_CHANNEL = {"memory": 8, "identity": 8, "continuity": 8, "time": 3, "ethics": 1, "affect": 1}

    def __init__(self, cfg):
        """Build parameters from ``cfg``.

        Reads ``cfg.hidden_dim``, ``cfg.channel_inner_dim`` and ``cfg.channel_source_dims``.

        Raises:
            KeyError: if a channel name has no entry in ``K_PER_CHANNEL``.
        """
        super().__init__()
        self.channel_names = list(cfg.channel_source_dims.keys())
        self.n_ch = len(self.channel_names)
        self.src_dims = [cfg.channel_source_dims[n] for n in self.channel_names]
        self.max_src = max(self.src_dims)
        self.k_per = [self.K_PER_CHANNEL[n] for n in self.channel_names]
        self.max_k = max(self.k_per)
        h, c = cfg.hidden_dim, cfg.channel_inner_dim
        self.c = c
        self.scale = c**-0.5  # 1/sqrt(d) attention scaling

        # Per-channel pooling matrix from the zero-padded source width to c features.
        self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, c, self.max_src))
        # Per-channel, per-slot maps from pooled features to key / value vectors.
        self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
        self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
        self.q_proj = nn.Parameter(torch.zeros(c, h))
        self.o_proj = nn.Parameter(torch.zeros(h, c))
        # One gate per channel. Left at zero so the delta starts as an exact no-op.
        self.alphas = nn.Parameter(torch.zeros(self.n_ch))

        # Additive attention mask: 0 for real slots, -inf for padding slots.
        mask = torch.full((self.n_ch, self.max_k), float("-inf"))
        for i, k in enumerate(self.k_per):
            mask[i, :k] = 0.0
        self.register_buffer("key_mask", mask, persistent=False)

        # Initialise only the real source columns of each pooling matrix; padded
        # columns stay zero. Keep this order: it fixes the RNG draws under a seed.
        for i, d in enumerate(self.src_dims):
            nn.init.normal_(self.pool_weight.data[i, :, :d], std=d**-0.5)
        nn.init.normal_(self.expand_k, std=c**-0.5)
        nn.init.normal_(self.expand_v, std=c**-0.5)
        nn.init.normal_(self.q_proj, std=h**-0.5)
        nn.init.normal_(self.o_proj, std=c**-0.5)

    def forward(self, hidden, channels):
        """Return the injection delta for ``hidden``.

        Args:
            hidden: token states, treated by the einsums below as (B, T, hidden_dim).
            channels: mapping from every name in ``self.channel_names`` to a 2-D tensor
                (B, source_width) for that channel. It may live on a different device
                or dtype than ``hidden``; it is moved and cast here.

        Returns:
            Tensor of shape (B, T, hidden_dim) in ``hidden.dtype``.
        """
        dt, dev = hidden.dtype, hidden.device
        # Move each channel to hidden's device/dtype, right-pad it to max_src and
        # stack the channels -> (B, n_ch, max_src).
        ch = torch.stack(
            [
                F.pad(channels[n].to(device=dev, dtype=dt), (0, self.max_src - self.src_dims[i]))
                for i, n in enumerate(self.channel_names)
            ],
            dim=1,
        )
        pw, ek, ev = self.pool_weight.to(dt), self.expand_k.to(dt), self.expand_v.to(dt)
        qp, op, al = self.q_proj.to(dt), self.o_proj.to(dt), self.alphas.to(dt)
        km = self.key_mask.to(dt)
        s = torch.einsum("bns,ncs->bnc", ch, pw)  # pooled features (B, n_ch, c)
        kt = torch.einsum("bnc,nkdc->bnkd", s, ek)  # keys   (B, n_ch, max_k, c)
        vt = torch.einsum("bnc,nkdc->bnkd", s, ev)  # values (B, n_ch, max_k, c)
        q = torch.einsum("bth,ch->btc", hidden, qp)  # queries (B, T, c)
        scores = torch.einsum("btc,bnkc->bntk", q, kt) * self.scale + km.view(1, self.n_ch, 1, self.max_k)
        attn = scores.softmax(dim=-1)  # normalised over each channel's own slots
        ctx = torch.einsum("bntk,bnkc->bntc", attn, vt)  # (B, n_ch, T, c)
        # Gate each channel's context, then sum over channels -> (B, T, c).
        gated = (ctx * al.view(1, self.n_ch, 1, 1)).sum(dim=1)
        return torch.einsum("btc,hc->bth", gated, op)
class ChannelHolder:
    """Mutable slot whose ``channels`` attribute is read by ChanneledLayer.forward.

    main() creates one holder and passes it to every wrapped layer. Assigning a channel
    dict turns injection on for all of them; assigning None turns it off.
    """

    def __init__(self):
        # Start with injection disabled.
        self.channels = None
class ChanneledLayer(nn.Module):
def __init__(self, base_layer, cfg, holder):
super().__init__()
self.base = base_layer
self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
self.channel_inj = ChannelInjectionDelta(cfg); self._holder = holder
def forward(self, *args, **kwargs):
out = self.base(*args, **kwargs)
hidden = out[0] if isinstance(out, tuple) else out
chans = self._holder.channels
if chans is not None:
normed = self.norm_channels(hidden.float()).to(hidden.dtype)
hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
return (hidden, *out[1:]) if isinstance(out, tuple) else hidden
def find_layers(model, min_len=20):
    """Heuristically locate the model's decoder-layer stack.

    Args:
        model: any ``nn.Module``.
        min_len: minimum number of entries an ``nn.ModuleList`` must have to qualify.

    Returns:
        ``(qualified_name, module_list)`` for the longest qualifying ``nn.ModuleList``
        in ``model.named_modules()`` order (the first one wins ties), or ``None`` if
        no ModuleList has at least ``min_len`` entries.
    """
    best = None
    for name, mod in model.named_modules():
        if not isinstance(mod, nn.ModuleList) or len(mod) < min_len:
            continue
        # Strict '>' keeps the earliest list on ties.
        if best is None or len(mod) > len(best[1]):
            best = (name, mod)
    return best
def main():
    """Wire ChannelInjection blocks onto the frozen base model and run the wiring checks.

    Steps:
      1. Load the tokenizer and model from ``MP`` (bfloat16) and freeze every base
         parameter.
      2. Wrap each decoder layer in ``ChanneledLayer``; all wrappers share one
         ``ChannelHolder``.
      3. Print the trainable-parameter count.
      4. Run and print the checks:
         (1) whether the logits with zero gates are bit-identical to the unwrapped base;
         (2) the total |grad| reaching the gates after one backward pass;
         (3) how much the logits change between two different identity channels once
             the gates are set to 0.5.
      5. Print a final verdict.

    Raises:
        RuntimeError: if no decoder-layer stack can be located in the model.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tok = AutoTokenizer.from_pretrained(MP)
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
    # Some configs (presumably multimodal ones) nest the language-model sizes under text_config.
    H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
    found = find_layers(model)
    if found is None:
        raise RuntimeError("could not locate the decoder-layer stack (no nn.ModuleList with >= 20 entries)")
    name, layers = found
    print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}")

    for p in model.parameters():
        p.requires_grad_(False)
    # Measure the frozen base before adapters are attached, so the percentage below
    # matches whichever checkpoint is loaded.
    n_base = sum(p.numel() for p in model.parameters())

    holder = ChannelHolder()
    cfg = InjectionConfig(hidden_dim=H)
    for i, base_layer in enumerate(list(layers)):
        layers[i] = ChanneledLayer(base_layer, cfg, holder)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_tr = sum(p.numel() for p in trainable)
    print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
          f"({n_tr/n_base*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")

    ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
    B, T = ids.shape
    identity_dim = dict(CHANNEL_ORDER)["identity"]

    def chans(seed):
        """Return a channel dict: seeded random identity vector, all other channels zero."""
        g = torch.Generator().manual_seed(seed)
        d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
        d["identity"] = torch.randn(B, identity_dim, generator=g)
        return d

    zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}

    # 1. bit-exact at alpha=0: every gate is still zero, so the injected delta is exactly 0.
    with torch.no_grad():
        holder.channels = None
        base_logits = model(ids).logits
        holder.channels = zeros
        mars_logits = model(ids).logits
    holder.channels = None
    diff = (base_logits - mars_logits).abs().max().item()
    # "Bit-exact" means identical, not merely close. A NaN diff also fails this.
    bit_exact = diff == 0.0
    print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if bit_exact else 'NOT bit-exact ✗'}")

    # 2/3. gradient-flow: one backward with channels -> alphas receive grad
    holder.channels = chans(1)
    out = model(ids).logits
    loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100] * B))
    loss.backward()
    galpha = sum(
        layer.channel_inj.alphas.grad.abs().sum().item()
        for layer in layers
        if layer.channel_inj.alphas.grad is not None
    )
    print(f"2. GRAD-FLOW: Σ|alpha.grad| across {len(layers)} gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
    model.zero_grad(set_to_none=True)

    # 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits
    with torch.no_grad():
        for layer in layers:
            layer.channel_inj.alphas.fill_(0.5)
        holder.channels = chans(1)
        logA = model(ids).logits
        holder.channels = chans(2)
        logB = model(ids).logits
    holder.channels = None
    cf = (logA - logB).abs().max().item()
    flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
    print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
          f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")

    all_ok = bit_exact and galpha > 0 and cf > 1e-2
    print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
          "bit-exact base, gates trainable, channel measurably drives output." if all_ok
          else "CHECK above.")


if __name__ == "__main__":
    main()