Upload modeling_channels.py with huggingface_hub
Browse files- modeling_channels.py +140 -0
modeling_channels.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero)
|
| 2 |
+
to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring:
|
| 3 |
+
1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged) [correctness]
|
| 4 |
+
2. trainable param count (the adapter is small; base frozen)
|
| 5 |
+
3. gradient-flow (one backward → ReZero alphas receive grad → trainable end-to-end)
|
| 6 |
+
4. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property]
|
| 7 |
+
Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60.
|
| 8 |
+
"""
|
| 9 |
+
import sys, time, torch, torch.nn as nn, torch.nn.functional as F
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
torch.manual_seed(0)
|
| 12 |
+
MP = "/root/gemma-4-e4b-it"; torch.set_num_threads(30)
|
| 13 |
+
|
| 14 |
+
CHANNEL_ORDER = [("memory",1024),("affect",2),("time",16),("ethics",24),("identity",1024),("continuity",1024)]
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class InjectionConfig:
|
| 18 |
+
hidden_dim: int
|
| 19 |
+
channel_inner_dim: int = 192
|
| 20 |
+
channel_source_dims: dict = field(default_factory=lambda: dict(CHANNEL_ORDER))
|
| 21 |
+
|
| 22 |
+
class ChannelInjectionDelta(nn.Module):
|
| 23 |
+
K_PER_CHANNEL = {"memory":8,"identity":8,"continuity":8,"time":3,"ethics":1,"affect":1}
|
| 24 |
+
def __init__(self, cfg):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.channel_names = list(cfg.channel_source_dims.keys()); self.n_ch = len(self.channel_names)
|
| 27 |
+
self.src_dims = [cfg.channel_source_dims[n] for n in self.channel_names]; self.max_src = max(self.src_dims)
|
| 28 |
+
self.k_per = [self.K_PER_CHANNEL[n] for n in self.channel_names]; self.max_k = max(self.k_per)
|
| 29 |
+
h, c = cfg.hidden_dim, cfg.channel_inner_dim; self.c = c; self.scale = c**-0.5
|
| 30 |
+
self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, c, self.max_src))
|
| 31 |
+
self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
|
| 32 |
+
self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
|
| 33 |
+
self.q_proj = nn.Parameter(torch.zeros(c, h)); self.o_proj = nn.Parameter(torch.zeros(h, c))
|
| 34 |
+
self.alphas = nn.Parameter(torch.zeros(self.n_ch))
|
| 35 |
+
mask = torch.full((self.n_ch, self.max_k), float("-inf"))
|
| 36 |
+
for i, k in enumerate(self.k_per): mask[i, :k] = 0.0
|
| 37 |
+
self.register_buffer("key_mask", mask, persistent=False)
|
| 38 |
+
for i, d in enumerate(self.src_dims): nn.init.normal_(self.pool_weight.data[i, :, :d], std=d**-0.5)
|
| 39 |
+
nn.init.normal_(self.expand_k, std=c**-0.5); nn.init.normal_(self.expand_v, std=c**-0.5)
|
| 40 |
+
nn.init.normal_(self.q_proj, std=h**-0.5); nn.init.normal_(self.o_proj, std=c**-0.5)
|
| 41 |
+
def forward(self, hidden, channels):
|
| 42 |
+
dt = hidden.dtype
|
| 43 |
+
ch = torch.stack([F.pad(channels[n].to(dt), (0, self.max_src - self.src_dims[i]))
|
| 44 |
+
for i, n in enumerate(self.channel_names)], dim=1)
|
| 45 |
+
pw,ek,ev = self.pool_weight.to(dt),self.expand_k.to(dt),self.expand_v.to(dt)
|
| 46 |
+
qp,op,al = self.q_proj.to(dt),self.o_proj.to(dt),self.alphas.to(dt); km = self.key_mask.to(dt)
|
| 47 |
+
s = torch.einsum("bns,ncs->bnc", ch, pw)
|
| 48 |
+
kt = torch.einsum("bnc,nkdc->bnkd", s, ek); vt = torch.einsum("bnc,nkdc->bnkd", s, ev)
|
| 49 |
+
q = torch.einsum("bth,ch->btc", hidden, qp)
|
| 50 |
+
scores = torch.einsum("btc,bnkc->bntk", q, kt) * self.scale + km.view(1,self.n_ch,1,self.max_k)
|
| 51 |
+
attn = scores.softmax(dim=-1); ctx = torch.einsum("bntk,bnkc->bntc", attn, vt)
|
| 52 |
+
gated = (ctx * al.view(1,self.n_ch,1,1)).sum(dim=1)
|
| 53 |
+
return torch.einsum("btc,hc->bth", gated, op)
|
| 54 |
+
|
| 55 |
+
class ChannelHolder:
|
| 56 |
+
def __init__(self): self.channels = None
|
| 57 |
+
|
| 58 |
+
class ChanneledLayer(nn.Module):
|
| 59 |
+
def __init__(self, base_layer, cfg, holder):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.base = base_layer
|
| 62 |
+
self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
|
| 63 |
+
self.channel_inj = ChannelInjectionDelta(cfg); self._holder = holder
|
| 64 |
+
def forward(self, *args, **kwargs):
|
| 65 |
+
out = self.base(*args, **kwargs)
|
| 66 |
+
hidden = out[0] if isinstance(out, tuple) else out
|
| 67 |
+
chans = self._holder.channels
|
| 68 |
+
if chans is not None:
|
| 69 |
+
normed = self.norm_channels(hidden.float()).to(hidden.dtype)
|
| 70 |
+
hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
|
| 71 |
+
return (hidden, *out[1:]) if isinstance(out, tuple) else hidden
|
| 72 |
+
|
| 73 |
+
def find_layers(model):
|
| 74 |
+
import torch.nn as nn
|
| 75 |
+
best = None
|
| 76 |
+
for name, mod in model.named_modules():
|
| 77 |
+
if isinstance(mod, nn.ModuleList) and len(mod) >= 20:
|
| 78 |
+
if best is None or len(mod) > len(best[1]): best = (name, mod)
|
| 79 |
+
return best
|
| 80 |
+
|
| 81 |
+
def main():
|
| 82 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 83 |
+
tok = AutoTokenizer.from_pretrained(MP)
|
| 84 |
+
t0 = time.time()
|
| 85 |
+
model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
|
| 86 |
+
H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
|
| 87 |
+
name, layers = find_layers(model)
|
| 88 |
+
print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}")
|
| 89 |
+
for p in model.parameters(): p.requires_grad_(False)
|
| 90 |
+
holder = ChannelHolder(); cfg = InjectionConfig(hidden_dim=H)
|
| 91 |
+
for i in range(len(layers)):
|
| 92 |
+
layers[i] = ChanneledLayer(layers[i], cfg, holder)
|
| 93 |
+
trainable = [p for p in model.parameters() if p.requires_grad]
|
| 94 |
+
n_tr = sum(p.numel() for p in trainable)
|
| 95 |
+
print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
|
| 96 |
+
f"({n_tr/7.94e9*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")
|
| 97 |
+
|
| 98 |
+
ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
|
| 99 |
+
B, T = ids.shape
|
| 100 |
+
def chans(seed):
|
| 101 |
+
g = torch.Generator().manual_seed(seed)
|
| 102 |
+
d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
|
| 103 |
+
d["identity"] = torch.randn(B, 1024, generator=g); return d
|
| 104 |
+
zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
|
| 105 |
+
|
| 106 |
+
# 1. bit-exact at alpha=0
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
holder.channels = None; base_logits = model(ids).logits
|
| 109 |
+
holder.channels = zeros; mars_logits = model(ids).logits
|
| 110 |
+
holder.channels = None
|
| 111 |
+
diff = (base_logits - mars_logits).abs().max().item()
|
| 112 |
+
print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if diff < 1e-4 else 'NOT bit-exact ✗'}")
|
| 113 |
+
|
| 114 |
+
# 2/3. gradient-flow: one backward with channels -> alphas receive grad
|
| 115 |
+
holder.channels = chans(1)
|
| 116 |
+
out = model(ids).logits
|
| 117 |
+
loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100]*B))
|
| 118 |
+
loss.backward()
|
| 119 |
+
galpha = sum((l.channel_inj.alphas.grad.abs().sum().item()
|
| 120 |
+
for l in layers if l.channel_inj.alphas.grad is not None))
|
| 121 |
+
print(f"2. GRAD-FLOW: Σ|alpha.grad| across 42 gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
|
| 122 |
+
model.zero_grad(set_to_none=True)
|
| 123 |
+
|
| 124 |
+
# 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits
|
| 125 |
+
for l in layers:
|
| 126 |
+
with torch.no_grad(): l.channel_inj.alphas.fill_(0.5)
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
holder.channels = chans(1); logA = model(ids).logits
|
| 129 |
+
holder.channels = chans(2); logB = model(ids).logits
|
| 130 |
+
holder.channels = None
|
| 131 |
+
cf = (logA - logB).abs().max().item()
|
| 132 |
+
flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
|
| 133 |
+
print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
|
| 134 |
+
f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")
|
| 135 |
+
print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
|
| 136 |
+
"bit-exact base, gates trainable, channel measurably drives output." if diff<1e-4 and galpha>0 and cf>1e-2
|
| 137 |
+
else "CHECK above.")
|
| 138 |
+
|
| 139 |
+
if __name__ == "__main__":
|
| 140 |
+
main()
|