celiumsAI commited on
Commit
3d57ce4
·
verified ·
1 Parent(s): 0cfaf0d

Upload modeling_channels.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_channels.py +140 -0
modeling_channels.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero)
2
+ to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring:
3
+ 1. bit-exact at alpha=0 (the residual add is a clean no-op → base unchanged) [correctness]
4
+ 2. trainable param count (the adapter is small; base frozen)
5
+ 3. gradient-flow (one backward → ReZero alphas receive grad → trainable end-to-end)
6
+ 4. channel counterfactual (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits) [the MARS property]
7
+ Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60.
8
+ """
9
+ import sys, time, torch, torch.nn as nn, torch.nn.functional as F
10
+ from dataclasses import dataclass, field
11
+ torch.manual_seed(0)
12
+ MP = "/root/gemma-4-e4b-it"; torch.set_num_threads(30)
13
+
14
+ CHANNEL_ORDER = [("memory",1024),("affect",2),("time",16),("ethics",24),("identity",1024),("continuity",1024)]
15
+
16
+ @dataclass
17
+ class InjectionConfig:
18
+ hidden_dim: int
19
+ channel_inner_dim: int = 192
20
+ channel_source_dims: dict = field(default_factory=lambda: dict(CHANNEL_ORDER))
21
+
22
+ class ChannelInjectionDelta(nn.Module):
23
+ K_PER_CHANNEL = {"memory":8,"identity":8,"continuity":8,"time":3,"ethics":1,"affect":1}
24
+ def __init__(self, cfg):
25
+ super().__init__()
26
+ self.channel_names = list(cfg.channel_source_dims.keys()); self.n_ch = len(self.channel_names)
27
+ self.src_dims = [cfg.channel_source_dims[n] for n in self.channel_names]; self.max_src = max(self.src_dims)
28
+ self.k_per = [self.K_PER_CHANNEL[n] for n in self.channel_names]; self.max_k = max(self.k_per)
29
+ h, c = cfg.hidden_dim, cfg.channel_inner_dim; self.c = c; self.scale = c**-0.5
30
+ self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, c, self.max_src))
31
+ self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
32
+ self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, c, c))
33
+ self.q_proj = nn.Parameter(torch.zeros(c, h)); self.o_proj = nn.Parameter(torch.zeros(h, c))
34
+ self.alphas = nn.Parameter(torch.zeros(self.n_ch))
35
+ mask = torch.full((self.n_ch, self.max_k), float("-inf"))
36
+ for i, k in enumerate(self.k_per): mask[i, :k] = 0.0
37
+ self.register_buffer("key_mask", mask, persistent=False)
38
+ for i, d in enumerate(self.src_dims): nn.init.normal_(self.pool_weight.data[i, :, :d], std=d**-0.5)
39
+ nn.init.normal_(self.expand_k, std=c**-0.5); nn.init.normal_(self.expand_v, std=c**-0.5)
40
+ nn.init.normal_(self.q_proj, std=h**-0.5); nn.init.normal_(self.o_proj, std=c**-0.5)
41
+ def forward(self, hidden, channels):
42
+ dt = hidden.dtype
43
+ ch = torch.stack([F.pad(channels[n].to(dt), (0, self.max_src - self.src_dims[i]))
44
+ for i, n in enumerate(self.channel_names)], dim=1)
45
+ pw,ek,ev = self.pool_weight.to(dt),self.expand_k.to(dt),self.expand_v.to(dt)
46
+ qp,op,al = self.q_proj.to(dt),self.o_proj.to(dt),self.alphas.to(dt); km = self.key_mask.to(dt)
47
+ s = torch.einsum("bns,ncs->bnc", ch, pw)
48
+ kt = torch.einsum("bnc,nkdc->bnkd", s, ek); vt = torch.einsum("bnc,nkdc->bnkd", s, ev)
49
+ q = torch.einsum("bth,ch->btc", hidden, qp)
50
+ scores = torch.einsum("btc,bnkc->bntk", q, kt) * self.scale + km.view(1,self.n_ch,1,self.max_k)
51
+ attn = scores.softmax(dim=-1); ctx = torch.einsum("bntk,bnkc->bntc", attn, vt)
52
+ gated = (ctx * al.view(1,self.n_ch,1,1)).sum(dim=1)
53
+ return torch.einsum("btc,hc->bth", gated, op)
54
+
55
+ class ChannelHolder:
56
+ def __init__(self): self.channels = None
57
+
58
+ class ChanneledLayer(nn.Module):
59
+ def __init__(self, base_layer, cfg, holder):
60
+ super().__init__()
61
+ self.base = base_layer
62
+ self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
63
+ self.channel_inj = ChannelInjectionDelta(cfg); self._holder = holder
64
+ def forward(self, *args, **kwargs):
65
+ out = self.base(*args, **kwargs)
66
+ hidden = out[0] if isinstance(out, tuple) else out
67
+ chans = self._holder.channels
68
+ if chans is not None:
69
+ normed = self.norm_channels(hidden.float()).to(hidden.dtype)
70
+ hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
71
+ return (hidden, *out[1:]) if isinstance(out, tuple) else hidden
72
+
73
+ def find_layers(model):
74
+ import torch.nn as nn
75
+ best = None
76
+ for name, mod in model.named_modules():
77
+ if isinstance(mod, nn.ModuleList) and len(mod) >= 20:
78
+ if best is None or len(mod) > len(best[1]): best = (name, mod)
79
+ return best
80
+
81
+ def main():
82
+ from transformers import AutoTokenizer, AutoModelForCausalLM
83
+ tok = AutoTokenizer.from_pretrained(MP)
84
+ t0 = time.time()
85
+ model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
86
+ H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
87
+ name, layers = find_layers(model)
88
+ print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}")
89
+ for p in model.parameters(): p.requires_grad_(False)
90
+ holder = ChannelHolder(); cfg = InjectionConfig(hidden_dim=H)
91
+ for i in range(len(layers)):
92
+ layers[i] = ChanneledLayer(layers[i], cfg, holder)
93
+ trainable = [p for p in model.parameters() if p.requires_grad]
94
+ n_tr = sum(p.numel() for p in trainable)
95
+ print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
96
+ f"({n_tr/7.94e9*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")
97
+
98
+ ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
99
+ B, T = ids.shape
100
+ def chans(seed):
101
+ g = torch.Generator().manual_seed(seed)
102
+ d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
103
+ d["identity"] = torch.randn(B, 1024, generator=g); return d
104
+ zeros = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
105
+
106
+ # 1. bit-exact at alpha=0
107
+ with torch.no_grad():
108
+ holder.channels = None; base_logits = model(ids).logits
109
+ holder.channels = zeros; mars_logits = model(ids).logits
110
+ holder.channels = None
111
+ diff = (base_logits - mars_logits).abs().max().item()
112
+ print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e} -> {'BIT-EXACT ✓' if diff < 1e-4 else 'NOT bit-exact ✗'}")
113
+
114
+ # 2/3. gradient-flow: one backward with channels -> alphas receive grad
115
+ holder.channels = chans(1)
116
+ out = model(ids).logits
117
+ loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100]*B))
118
+ loss.backward()
119
+ galpha = sum((l.channel_inj.alphas.grad.abs().sum().item()
120
+ for l in layers if l.channel_inj.alphas.grad is not None))
121
+ print(f"2. GRAD-FLOW: Σ|alpha.grad| across 42 gates = {galpha:.3e} -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
122
+ model.zero_grad(set_to_none=True)
123
+
124
+ # 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits
125
+ for l in layers:
126
+ with torch.no_grad(): l.channel_inj.alphas.fill_(0.5)
127
+ with torch.no_grad():
128
+ holder.channels = chans(1); logA = model(ids).logits
129
+ holder.channels = chans(2); logB = model(ids).logits
130
+ holder.channels = None
131
+ cf = (logA - logB).abs().max().item()
132
+ flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
133
+ print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
134
+ f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")
135
+ print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
136
+ "bit-exact base, gates trainable, channel measurably drives output." if diff<1e-4 and galpha>0 and cf>1e-2
137
+ else "CHECK above.")
138
+
139
+ if __name__ == "__main__":
140
+ main()