File size: 7,844 Bytes
"""mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero)
to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring:
  1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged)      [correctness]
  2. trainable param count (the adapter is small; base frozen)
  3. gradient-flow (one backward → ReZero alphas receive grad → trainable end-to-end)
  4. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property]
Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60.
"""
import sys, time, torch, torch.nn as nn, torch.nn.functional as F
from dataclasses import dataclass, field
# Seed the default generator so the adapter initialisation is reproducible.
torch.manual_seed(0)
torch.set_num_threads(30)

# Checkpoint location handed to `from_pretrained` in main().
MP = "/root/gemma-4-e4b-it"

# (channel name, source feature width) pairs; the order here is the order in
# which channels are stacked inside the injection module.
CHANNEL_ORDER = [
    ("memory", 1024),
    ("affect", 2),
    ("time", 16),
    ("ethics", 24),
    ("identity", 1024),
    ("continuity", 1024),
]
@dataclass
class InjectionConfig:
    """Sizes needed to build one channel-injection adapter.

    Attributes:
        hidden_dim: Width of the hidden states the adapter reads and writes.
        channel_inner_dim: Width of the adapter's internal projection space.
        channel_source_dims: Mapping of channel name -> feature width of that
            channel's input vector. Defaults to a fresh dict built from
            ``CHANNEL_ORDER`` for every instance.

    Raises:
        ValueError: If a dimension is not positive or no channel is declared.
    """

    hidden_dim: int
    channel_inner_dim: int = 192
    channel_source_dims: dict[str, int] = field(default_factory=lambda: dict(CHANNEL_ORDER))

    def __post_init__(self):
        # Fail here with a readable message rather than later inside the
        # adapter (c**-0.5 on zero, max() of an empty list).
        if self.hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {self.hidden_dim}")
        if self.channel_inner_dim <= 0:
            raise ValueError(f"channel_inner_dim must be positive, got {self.channel_inner_dim}")
        if not self.channel_source_dims:
            raise ValueError("channel_source_dims must declare at least one channel")
class ChannelInjectionDelta(nn.Module):
    """Gated cross-attention from hidden states onto per-channel key/value slots.

    Each channel vector is pooled to ``channel_inner_dim``, expanded into a
    fixed number of key/value slots, and attended to by a query projected from
    the hidden states. Per-channel contexts are scaled by a learned scalar
    gate (``alphas``, initialised to zero) and summed, so a freshly built
    module returns an all-zero delta.
    """

    # Number of key/value slots each named channel is expanded into.
    K_PER_CHANNEL = {"memory": 8, "identity": 8, "continuity": 8, "time": 3, "ethics": 1, "affect": 1}

    def __init__(self, cfg):
        """Build the adapter parameters.

        Args:
            cfg: Object exposing ``hidden_dim``, ``channel_inner_dim`` and
                ``channel_source_dims`` (channel name -> input width).

        Raises:
            KeyError: If a configured channel has no ``K_PER_CHANNEL`` entry.
        """
        super().__init__()
        self.channel_names = list(cfg.channel_source_dims.keys())
        self.n_ch = len(self.channel_names)
        unknown = [n for n in self.channel_names if n not in self.K_PER_CHANNEL]
        if unknown:
            raise KeyError(f"no K_PER_CHANNEL entry for channel(s): {unknown}")
        self.src_dims = [cfg.channel_source_dims[n] for n in self.channel_names]
        self.max_src = max(self.src_dims)
        self.k_per = [self.K_PER_CHANNEL[n] for n in self.channel_names]
        self.max_k = max(self.k_per)
        h, c = cfg.hidden_dim, cfg.channel_inner_dim
        self.c = c
        self.scale = c**-0.5

        # All channels share padded tensors: narrower inputs and channels with
        # fewer slots occupy a prefix of the max_src / max_k axes.
        self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, c, self.max_src))
        self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
        self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
        self.q_proj = nn.Parameter(torch.zeros(c, h))
        self.o_proj = nn.Parameter(torch.zeros(h, c))
        self.alphas = nn.Parameter(torch.zeros(self.n_ch))

        # Additive mask: 0 on a channel's real slots, -inf on the padding slots
        # so softmax assigns them zero weight.
        mask = torch.full((self.n_ch, self.max_k), float("-inf"))
        for i, k in enumerate(self.k_per):
            mask[i, :k] = 0.0
        self.register_buffer("key_mask", mask, persistent=False)

        # Initialise only the columns a channel actually uses; the padded
        # columns stay zero. Done under no_grad instead of through `.data`.
        with torch.no_grad():
            for i, d in enumerate(self.src_dims):
                nn.init.normal_(self.pool_weight[i, :, :d], std=d**-0.5)
        nn.init.normal_(self.expand_k, std=c**-0.5)
        nn.init.normal_(self.expand_v, std=c**-0.5)
        nn.init.normal_(self.q_proj, std=h**-0.5)
        nn.init.normal_(self.o_proj, std=c**-0.5)

    def forward(self, hidden, channels):
        """Compute the residual delta for ``hidden`` given the channel vectors.

        Args:
            hidden: Tensor of shape (batch, tokens, hidden_dim).
            channels: Mapping of channel name -> tensor of shape
                (batch, source_dim) for every configured channel. Extra keys
                are ignored.

        Returns:
            Tensor of shape (batch, tokens, hidden_dim) in ``hidden``'s dtype.

        Raises:
            KeyError: If a configured channel is absent from ``channels``.
        """
        missing = [n for n in self.channel_names if n not in channels]
        if missing:
            raise KeyError(f"missing channel tensor(s): {missing}")

        dt = hidden.dtype
        # Right-pad every channel to max_src so they stack into (b, n, s).
        ch = torch.stack(
            [
                F.pad(channels[n].to(dt), (0, self.max_src - self.src_dims[i]))
                for i, n in enumerate(self.channel_names)
            ],
            dim=1,
        )
        # Parameters are cast to the activation dtype on every call.
        pw, ek, ev = self.pool_weight.to(dt), self.expand_k.to(dt), self.expand_v.to(dt)
        qp, op, al = self.q_proj.to(dt), self.o_proj.to(dt), self.alphas.to(dt)
        km = self.key_mask.to(dt)

        s = torch.einsum("bns,ncs->bnc", ch, pw)  # pooled channel summaries
        kt = torch.einsum("bnc,nkdc->bnkd", s, ek)  # keys per channel slot
        vt = torch.einsum("bnc,nkdc->bnkd", s, ev)  # values per channel slot
        q = torch.einsum("bth,ch->btc", hidden, qp)
        scores = torch.einsum("btc,bnkc->bntk", q, kt) * self.scale + km.view(1, self.n_ch, 1, self.max_k)
        attn = scores.softmax(dim=-1)
        ctx = torch.einsum("bntk,bnkc->bntc", attn, vt)
        # Per-channel scalar gate, then sum over channels.
        gated = (ctx * al.view(1, self.n_ch, 1, 1)).sum(dim=1)
        return torch.einsum("btc,hc->bth", gated, op)
class ChannelHolder:
    """Mutable slot through which the current channel tensors are shared.

    ``channels`` is either ``None`` or the mapping of channel tensors to use
    for the next forward pass.
    """

    def __init__(self):
        self.channels = None

    def __repr__(self):
        # Show only the channel names; the values may be large tensors.
        names = None if self.channels is None else list(self.channels)
        return f"{type(self).__name__}(channels={names})"
class ChanneledLayer(nn.Module):
def __init__(self, base_layer, cfg, holder):
super().__init__()
self.base = base_layer
self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
self.channel_inj = ChannelInjectionDelta(cfg); self._holder = holder
def forward(self, *args, **kwargs):
out = self.base(*args, **kwargs)
hidden = out[0] if isinstance(out, tuple) else out
chans = self._holder.channels
if chans is not None:
normed = self.norm_channels(hidden.float()).to(hidden.dtype)
hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
return (hidden, *out[1:]) if isinstance(out, tuple) else hidden
def find_layers(model, min_layers=20):
    """Locate the longest ``nn.ModuleList`` inside ``model``.

    Args:
        model: Module whose ``named_modules()`` are searched.
        min_layers: Minimum length a ``ModuleList`` must have to be considered.

    Returns:
        ``(qualified_name, module_list)`` for the longest qualifying list (the
        first one encountered wins ties), or ``None`` when none qualifies.
    """
    candidates = (
        (name, mod)
        for name, mod in model.named_modules()
        if isinstance(mod, nn.ModuleList) and len(mod) >= min_layers
    )
    # max() keeps the first of several equally long lists.
    return max(candidates, key=lambda item: len(item[1]), default=None)
def main():
    """Wire channel injection onto the frozen base model and print the checks.

    Loads the tokenizer and model from ``MP``, freezes every base parameter,
    wraps each layer of the longest module list in a ``ChanneledLayer``, then
    prints: the trainable parameter count, the logit difference with zeroed
    channels at alpha=0, the summed gate gradient after one backward pass, and
    the logit difference between two different ``identity`` channel vectors
    at alpha=0.5.

    Raises:
        RuntimeError: If no layer list could be located in the model.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tok = AutoTokenizer.from_pretrained(MP)
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
    H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
    found = find_layers(model)
    if found is None:
        raise RuntimeError("could not locate a layer ModuleList in the loaded model")
    name, layers = found
    print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}")

    # Freeze the base and count its parameters before any adapter is attached,
    # so the reported percentage is relative to the model actually loaded.
    n_base = 0
    for p in model.parameters():
        p.requires_grad_(False)
        n_base += p.numel()

    holder = ChannelHolder()
    cfg = InjectionConfig(hidden_dim=H)
    # Iterate over a snapshot because the list is rewritten in place.
    for i, layer in enumerate(list(layers)):
        layers[i] = ChanneledLayer(layer, cfg, holder)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_tr = sum(p.numel() for p in trainable)
    print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
          f"({n_tr/n_base*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")

    ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
    B = ids.shape[0]

    def chans(seed):
        """Return all-zero channels except a seeded random `identity` vector."""
        g = torch.Generator().manual_seed(seed)
        d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
        d["identity"] = torch.randn(B, 1024, generator=g)
        return d

    zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}

    # 1. bit-exact at alpha=0
    with torch.no_grad():
        holder.channels = None
        base_logits = model(ids).logits
        holder.channels = zeros
        mars_logits = model(ids).logits
        holder.channels = None
    diff = (base_logits - mars_logits).abs().max().item()
    # NOTE(review): the label says "bit-exact" but the check tolerates 1e-4.
    print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if diff < 1e-4 else 'NOT bit-exact ✗'}")

    # 2/3. gradient-flow: one backward with channels -> alphas receive grad
    holder.channels = chans(1)
    out = model(ids).logits
    loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100]*B))
    loss.backward()
    holder.channels = None
    galpha = sum(
        layer.channel_inj.alphas.grad.abs().sum().item()
        for layer in layers
        if layer.channel_inj.alphas.grad is not None
    )
    print(f"2. GRAD-FLOW: Σ|alpha.grad| across {len(layers)} gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
    model.zero_grad(set_to_none=True)

    # 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits
    with torch.no_grad():
        for layer in layers:
            layer.channel_inj.alphas.fill_(0.5)
        holder.channels = chans(1)
        logA = model(ids).logits
        holder.channels = chans(2)
        logB = model(ids).logits
        holder.channels = None
    cf = (logA - logB).abs().max().item()
    flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
    print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
          f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")
    print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
          "bit-exact base, gates trainable, channel measurably drives output." if diff<1e-4 and galpha>0 and cf>1e-2
          else "CHECK above.")
# Run the wiring checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|