| """mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero) |
| to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring: |
  - trainable param count (the adapter is small; the base is frozen) — printed first, unnumbered
  1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged) [correctness]
  2. gradient flow (one backward → ReZero alphas receive grad → trainable end-to-end)
  3. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property]
The injection module is ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention);
only its dimensions are re-fit to E4B. CPU, c-60.
| """ |
| import sys, time, torch, torch.nn as nn, torch.nn.functional as F |
| from dataclasses import dataclass, field |
# Global seed: the adapter's nn.init.normal_ calls draw from torch's default RNG.
torch.manual_seed(0)
# Model path handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MP = "/root/gemma-4-e4b-it"
torch.set_num_threads(30)


# (channel name, raw source width). InjectionConfig turns this into its default
# mapping, and ChannelInjectionDelta stacks the channels in exactly this order.
CHANNEL_ORDER = [
    ("memory", 1024),
    ("affect", 2),
    ("time", 16),
    ("ethics", 24),
    ("identity", 1024),
    ("continuity", 1024),
]
|
|
def _default_channel_source_dims():
    """Build a fresh ``{channel name: source width}`` dict from CHANNEL_ORDER."""
    return dict(CHANNEL_ORDER)


@dataclass
class InjectionConfig:
    """Size hyper-parameters for one ChannelInjectionDelta block.

    Each instance gets its own ``channel_source_dims`` dict, because the default
    is produced by a factory rather than shared.
    """

    # Width of the base model's residual stream that the injection adds into.
    hidden_dim: int
    # Shared inner width ``c`` of the channel cross-attention.
    channel_inner_dim: int = 192
    # Ordered mapping channel name -> raw source feature width.
    channel_source_dims: dict = field(default_factory=_default_channel_source_dims)
|
|
class ChannelInjectionDelta(nn.Module):
    """Gated cross-attention from a layer's hidden states onto per-channel key/value tokens.

    Each channel's source vector is pooled to ``c`` dims and expanded into up to
    ``max_k`` key/value tokens (channels with fewer tokens are masked). Each
    position's hidden state is projected into a query that attends over each
    channel's tokens. The per-channel contexts are scaled by learnable ReZero
    gates (``alphas``, initialised to zero), summed over channels, and projected
    back to the hidden width. With every gate at zero the returned delta is zero
    whenever the attention context is finite.
    """

    # Number of key/value tokens each channel expands into.
    K_PER_CHANNEL = {"memory": 8, "identity": 8, "continuity": 8, "time": 3, "ethics": 1, "affect": 1}

    def __init__(self, cfg):
        super().__init__()
        names = list(cfg.channel_source_dims.keys())
        self.channel_names = names
        self.n_ch = len(names)
        self.src_dims = [cfg.channel_source_dims[name] for name in names]
        self.max_src = max(self.src_dims)
        self.k_per = [self.K_PER_CHANNEL[name] for name in names]
        self.max_k = max(self.k_per)

        hid = cfg.hidden_dim
        inner = cfg.channel_inner_dim
        self.c = inner
        self.scale = inner ** -0.5

        # Registration order (pool_weight, expand_k, expand_v, q_proj, o_proj, alphas)
        # is significant for parameters() / state_dict ordering.
        self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, inner, self.max_src))
        self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, inner, inner))
        self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, inner, inner))
        self.q_proj = nn.Parameter(torch.zeros(inner, hid))
        self.o_proj = nn.Parameter(torch.zeros(hid, inner))
        self.alphas = nn.Parameter(torch.zeros(self.n_ch))  # ReZero gates, start closed

        # Additive key bias: 0 for channel i's first k_per[i] slots, -inf for padding slots.
        slot = torch.arange(self.max_k).unsqueeze(0)   # (1, max_k)
        limit = torch.tensor(self.k_per).unsqueeze(1)  # (n_ch, 1)
        bias = torch.zeros(self.n_ch, self.max_k).masked_fill(slot >= limit, float("-inf"))
        self.register_buffer("key_mask", bias, persistent=False)

        # Random init; the call order fixes which RNG draws each tensor receives.
        # Pool weights are drawn only over each channel's real source width, so the
        # zero-padded tail stays exactly 0.
        for idx, width in enumerate(self.src_dims):
            nn.init.normal_(self.pool_weight.data[idx, :, :width], std=width ** -0.5)
        nn.init.normal_(self.expand_k, std=inner ** -0.5)
        nn.init.normal_(self.expand_v, std=inner ** -0.5)
        nn.init.normal_(self.q_proj, std=hid ** -0.5)
        nn.init.normal_(self.o_proj, std=inner ** -0.5)

    def forward(self, hidden, channels):
        """Compute the injection delta for ``hidden``.

        Args:
            hidden: hidden states, contracted as (batch, tokens, hidden_dim) by the einsums below.
            channels: mapping channel name -> tensor whose last dim is that channel's
                source width; it is right-padded with zeros to ``max_src`` here.

        Returns:
            A (batch, tokens, hidden_dim) tensor in ``hidden.dtype``.
        """
        dt = hidden.dtype
        padded = [
            F.pad(channels[name].to(dt), (0, self.max_src - self.src_dims[idx]))
            for idx, name in enumerate(self.channel_names)
        ]
        src = torch.stack(padded, dim=1)  # (b, n, max_src)

        # Run all math in the hidden states' dtype.
        pool = self.pool_weight.to(dt)
        exp_k = self.expand_k.to(dt)
        exp_v = self.expand_v.to(dt)
        w_q = self.q_proj.to(dt)
        w_o = self.o_proj.to(dt)
        gates = self.alphas.to(dt)
        key_bias = self.key_mask.to(dt)

        pooled = torch.einsum("bns,ncs->bnc", src, pool)        # (b, n, c)
        keys = torch.einsum("bnc,nkdc->bnkd", pooled, exp_k)    # (b, n, k, c)
        values = torch.einsum("bnc,nkdc->bnkd", pooled, exp_v)  # (b, n, k, c)
        query = torch.einsum("bth,ch->btc", hidden, w_q)         # (b, t, c)

        logits = torch.einsum("btc,bnkc->bntk", query, keys) * self.scale
        logits = logits + key_bias.view(1, self.n_ch, 1, self.max_k)
        weights = logits.softmax(dim=-1)
        context = torch.einsum("bntk,bnkc->bntc", weights, values)  # (b, n, t, c)

        mixed = (context * gates.view(1, self.n_ch, 1, 1)).sum(dim=1)  # (b, t, c)
        return torch.einsum("btc,hc->bth", mixed, w_o)
|
|
class ChannelHolder:
    """Mutable slot shared by every wrapped layer.

    Assign a dict of channel tensors to ``channels`` to switch injection on for
    all wrapped layers at once; set it back to ``None`` to run the base
    unmodified.
    """

    def __init__(self):
        self.channels: dict | None = None
|
|
class ChanneledLayer(nn.Module):
    """Wrap one frozen base decoder layer and add a channel-injection residual.

    When ``holder.channels`` is ``None``, the base layer's hidden-state output is
    returned unchanged. Otherwise ``channel_inj(RMSNorm(hidden), channels)`` is
    added to it. The hidden state is the first tuple element if the base returns
    a tuple, or the bare tensor. Any other tuple elements are passed through.

    Attributes not defined on the wrapper itself are looked up on the wrapped
    base layer, so the wrapper stays a drop-in replacement for code that reads
    per-layer metadata off each layer object.
    """

    def __init__(self, base_layer, cfg, holder):
        super().__init__()
        self.base = base_layer
        # Normalises the layer output before it queries the channels (applied in float32 in forward).
        self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
        self.channel_inj = ChannelInjectionDelta(cfg)
        # Plain attribute (not a submodule); the same holder is shared by every wrapped layer.
        self._holder = holder

    def __getattr__(self, name):
        # nn.Module.__getattr__ resolves registered parameters, buffers and submodules.
        # Anything else is forwarded to the wrapped base layer, e.g. attributes a
        # parent model's decoder loop may read from each layer.
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "base":
                # `base` is not registered (yet): stop here instead of recursing forever.
                raise
            return getattr(self.base, name)

    def forward(self, *args, **kwargs):
        """Run the base layer, then add the channel delta when channels are set."""
        out = self.base(*args, **kwargs)
        hidden = out[0] if isinstance(out, tuple) else out
        chans = self._holder.channels
        if chans is not None:
            normed = self.norm_channels(hidden.float()).to(hidden.dtype)
            hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
        return (hidden, *out[1:]) if isinstance(out, tuple) else hidden
|
|
def find_layers(model, min_len=20):
    """Locate the decoder-layer stack of ``model`` by a size heuristic.

    The longest ``nn.ModuleList`` with at least ``min_len`` entries is taken to be
    the transformer's layer stack. On a tie, the first one in ``named_modules()``
    order wins.

    Args:
        model: any ``nn.Module``.
        min_len: minimum ModuleList length to be considered (default 20).

    Returns:
        ``(qualified_name, module_list)`` for the best match, or ``None`` when no
        ``nn.ModuleList`` has at least ``min_len`` entries.
    """
    candidates = [
        (name, mod)
        for name, mod in model.named_modules()
        if isinstance(mod, nn.ModuleList) and len(mod) >= min_len
    ]
    # max() returns the first maximal element, matching the strict ">" tie rule.
    return max(candidates, key=lambda item: len(item[1]), default=None)
|
|
def main():
    """Wire a ChannelInjection block onto every decoder layer of the frozen base and check the wiring.

    Prints the trainable-parameter budget, then three checks:
      1. BIT-EXACT: with all alphas at 0, all-zero channels must reproduce the
         base logits (max |Δ| < 1e-4).
      2. GRAD-FLOW: one backward pass must give the ReZero alphas a non-zero
         total |grad|.
      3. COUNTERFACTUAL: with alpha=0.5, two different identity channels must
         move the logits by more than 1e-2.
    The loaded model is mutated in place: its layers are wrapped and its alphas
    are set to 0.5.

    Raises:
        RuntimeError: if no decoder-layer ``nn.ModuleList`` can be located.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tok = AutoTokenizer.from_pretrained(MP)
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
    H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
    found = find_layers(model)
    if found is None:
        # find_layers returns None when no ModuleList is long enough; fail with a clear
        # message instead of "cannot unpack non-iterable NoneType".
        raise RuntimeError(f"no decoder-layer nn.ModuleList (>= 20 entries) found in {type(model).__name__}")
    name, layers = found
    n_layers = len(layers)
    print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={n_layers}")

    for p in model.parameters():
        p.requires_grad_(False)
    # Count the frozen base before any adapter exists, so "% of base" is measured
    # from the checkpoint itself rather than a hard-coded parameter total.
    n_base = sum(p.numel() for p in model.parameters())

    # One shared holder: setting holder.channels switches injection on for every layer at once.
    holder = ChannelHolder()
    cfg = InjectionConfig(hidden_dim=H)
    for i in range(n_layers):
        layers[i] = ChanneledLayer(layers[i], cfg, holder)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_tr = sum(p.numel() for p in trainable)
    print(f"MARS wired: {n_layers} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
          f"({n_tr/n_base*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")

    ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
    B = ids.shape[0]
    channel_dims = dict(CHANNEL_ORDER)

    def chans(seed):
        """All-zero channels except a seeded random 'identity' vector."""
        g = torch.Generator().manual_seed(seed)
        d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
        d["identity"] = torch.randn(B, channel_dims["identity"], generator=g)
        return d

    zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}

    # 1. Alphas are still 0: injecting all-zero channels must leave the logits unchanged.
    with torch.no_grad():
        holder.channels = None
        base_logits = model(ids).logits
        holder.channels = zeros
        mars_logits = model(ids).logits
        holder.channels = None
    diff = (base_logits - mars_logits).abs().max().item()
    print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if diff < 1e-4 else 'NOT bit-exact ✗'}")

    # 2. One backward pass with a real channel: the (zero) alphas must still receive gradient.
    holder.channels = chans(1)
    out = model(ids).logits
    loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100] * B))
    loss.backward()
    holder.channels = None
    galpha = sum(layer.channel_inj.alphas.grad.abs().sum().item()
                 for layer in layers if layer.channel_inj.alphas.grad is not None)
    print(f"2. GRAD-FLOW: Σ|alpha.grad| across {n_layers} gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
    model.zero_grad(set_to_none=True)

    # 3. Open every gate to 0.5 and swap only the identity channel between two seeds.
    with torch.no_grad():
        for layer in layers:
            layer.channel_inj.alphas.fill_(0.5)
        holder.channels = chans(1)
        logA = model(ids).logits
        holder.channels = chans(2)
        logB = model(ids).logits
        holder.channels = None
    cf = (logA - logB).abs().max().item()
    flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
    print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
          f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")
    print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
          "bit-exact base, gates trainable, channel measurably drives output." if diff<1e-4 and galpha>0 and cf>1e-2
          else "CHECK above.")


if __name__ == "__main__":
    main()
|
|