"""mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero) to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring: 1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged) [correctness] 2. trainable param count (the adapter is small; base frozen) 3. gradient-flow (one backward → ReZero alphas receive grad → trainable end-to-end) 4. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property] Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60. """ import sys, time, torch, torch.nn as nn, torch.nn.functional as F from dataclasses import dataclass, field torch.manual_seed(0) MP = "/root/gemma-4-e4b-it"; torch.set_num_threads(30) CHANNEL_ORDER = [("memory",1024),("affect",2),("time",16),("ethics",24),("identity",1024),("continuity",1024)] @dataclass class InjectionConfig: hidden_dim: int channel_inner_dim: int = 192 channel_source_dims: dict = field(default_factory=lambda: dict(CHANNEL_ORDER)) class ChannelInjectionDelta(nn.Module): K_PER_CHANNEL = {"memory":8,"identity":8,"continuity":8,"time":3,"ethics":1,"affect":1} def __init__(self, cfg): super().__init__() self.channel_names = list(cfg.channel_source_dims.keys()); self.n_ch = len(self.channel_names) self.src_dims = [cfg.channel_source_dims[n] for n in self.channel_names]; self.max_src = max(self.src_dims) self.k_per = [self.K_PER_CHANNEL[n] for n in self.channel_names]; self.max_k = max(self.k_per) h, c = cfg.hidden_dim, cfg.channel_inner_dim; self.c = c; self.scale = c**-0.5 self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, c, self.max_src)) self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c)) self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c)) self.q_proj = nn.Parameter(torch.zeros(c, h)); self.o_proj = nn.Parameter(torch.zeros(h, c)) self.alphas = nn.Parameter(torch.zeros(self.n_ch)) mask = torch.full((self.n_ch, self.max_k), float("-inf")) for i, k in enumerate(self.k_per): mask[i, :k] = 0.0 self.register_buffer("key_mask", mask, persistent=False) for i, d in enumerate(self.src_dims): nn.init.normal_(self.pool_weight.data[i, :, :d], std=d**-0.5) nn.init.normal_(self.expand_k, std=c**-0.5); nn.init.normal_(self.expand_v, std=c**-0.5) nn.init.normal_(self.q_proj, std=h**-0.5); nn.init.normal_(self.o_proj, std=c**-0.5) def forward(self, hidden, channels): dt = hidden.dtype ch = torch.stack([F.pad(channels[n].to(dt), (0, self.max_src - self.src_dims[i])) for i, n in enumerate(self.channel_names)], dim=1) pw,ek,ev = self.pool_weight.to(dt),self.expand_k.to(dt),self.expand_v.to(dt) qp,op,al = self.q_proj.to(dt),self.o_proj.to(dt),self.alphas.to(dt); km = self.key_mask.to(dt) s = torch.einsum("bns,ncs->bnc", ch, pw) kt = torch.einsum("bnc,nkdc->bnkd", s, ek); vt = torch.einsum("bnc,nkdc->bnkd", s, ev) q = torch.einsum("bth,ch->btc", hidden, qp) scores = torch.einsum("btc,bnkc->bntk", q, kt) * self.scale + km.view(1,self.n_ch,1,self.max_k) attn = scores.softmax(dim=-1); ctx = torch.einsum("bntk,bnkc->bntc", attn, vt) gated = (ctx * al.view(1,self.n_ch,1,1)).sum(dim=1) return torch.einsum("btc,hc->bth", gated, op) class ChannelHolder: def __init__(self): self.channels = None class ChanneledLayer(nn.Module): def __init__(self, base_layer, cfg, holder): super().__init__() self.base = base_layer self.norm_channels = nn.RMSNorm(cfg.hidden_dim) self.channel_inj = ChannelInjectionDelta(cfg); self._holder = holder def forward(self, *args, **kwargs): out = self.base(*args, **kwargs) hidden = out[0] if isinstance(out, tuple) else out chans = self._holder.channels if chans is not None: normed = self.norm_channels(hidden.float()).to(hidden.dtype) hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype) return (hidden, *out[1:]) if isinstance(out, tuple) else hidden def find_layers(model): import torch.nn as nn best = None for name, mod in model.named_modules(): if isinstance(mod, nn.ModuleList) and len(mod) >= 20: if best is None or len(mod) > len(best[1]): best = (name, mod) return best def main(): from transformers import AutoTokenizer, AutoModelForCausalLM tok = AutoTokenizer.from_pretrained(MP) t0 = time.time() model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval() H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size name, layers = find_layers(model) print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}") for p in model.parameters(): p.requires_grad_(False) holder = ChannelHolder(); cfg = InjectionConfig(hidden_dim=H) for i in range(len(layers)): layers[i] = ChanneledLayer(layers[i], cfg, holder) trainable = [p for p in model.parameters() if p.requires_grad] n_tr = sum(p.numel() for p in trainable) print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M " f"({n_tr/7.94e9*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}") ids = tok("def fibonacci(n):", return_tensors="pt").input_ids B, T = ids.shape def chans(seed): g = torch.Generator().manual_seed(seed) d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER} d["identity"] = torch.randn(B, 1024, generator=g); return d zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER} # 1. bit-exact at alpha=0 with torch.no_grad(): holder.channels = None; base_logits = model(ids).logits holder.channels = zeros; mars_logits = model(ids).logits holder.channels = None diff = (base_logits - mars_logits).abs().max().item() print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if diff < 1e-4 else 'NOT bit-exact ✗'}") # 2/3. gradient-flow: one backward with channels -> alphas receive grad holder.channels = chans(1) out = model(ids).logits loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100]*B)) loss.backward() galpha = sum((l.channel_inj.alphas.grad.abs().sum().item() for l in layers if l.channel_inj.alphas.grad is not None)) print(f"2. GRAD-FLOW: Σ|alpha.grad| across 42 gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}") model.zero_grad(set_to_none=True) # 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits for l in layers: with torch.no_grad(): l.channel_inj.alphas.fill_(0.5) with torch.no_grad(): holder.channels = chans(1); logA = model(ids).logits holder.channels = chans(2); logB = model(ids).logits holder.channels = None cf = (logA - logB).abs().max().item() flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item() print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} " f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}") print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —", "bit-exact base, gates trainable, channel measurably drives output." if diff<1e-4 and galpha>0 and cf>1e-2 else "CHECK above.") if __name__ == "__main__": main()