File size: 7,844 Bytes
3d57ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""mars_e4b.py — the FIRST MARS-on-Gemma-4-E4B: re-fit the Apuesta-1 ChannelInjection (cross-attn+ReZero)
to E4B dims (hidden 2560, 42 layers), wire it onto the frozen base, and MEASURE the wiring:
  1. bit-exact at alpha=0   (the residual add is a clean no-op → base unchanged)   [correctness]
  2. trainable param count   (the adapter is small; base frozen)
  3. gradient-flow           (one backward → ReZero alphas receive grad → trainable end-to-end)
  4. channel counterfactual  (bump alpha; same prompt + DIFFERENT channel → DIFFERENT logits)  [the MARS property]
Ported verbatim from training/train_apuesta1.py (iter-6 gated cross-attention). CPU, c-60.
"""
import sys, time, torch, torch.nn as nn, torch.nn.functional as F
from dataclasses import dataclass, field
# Fixed global seed so the adapters' random initialisation is reproducible run to run.
torch.manual_seed(0)

# Local checkpoint directory passed to the Hugging Face `from_pretrained` loaders in main().
MP = "/root/gemma-4-e4b-it"
torch.set_num_threads(30)

# (channel name, source vector width) pairs. The order here fixes the channel
# index used by every per-channel parameter in ChannelInjectionDelta.
CHANNEL_ORDER = [
    ("memory", 1024),
    ("affect", 2),
    ("time", 16),
    ("ethics", 24),
    ("identity", 1024),
    ("continuity", 1024),
]

@dataclass
class InjectionConfig:
    """Sizes needed to build one channel-injection adapter."""

    # Width of the base model's hidden states (input/output width of the adapter).
    hidden_dim: int
    # Inner width of the adapter's query/key/value space.
    channel_inner_dim: int = 192
    # Channel name -> width of that channel's source vector; defaults to a
    # fresh dict built from CHANNEL_ORDER for every instance.
    channel_source_dims: dict = field(
        default_factory=lambda: {name: width for name, width in CHANNEL_ORDER}
    )

class ChannelInjectionDelta(nn.Module):
    """Gated cross-attention delta: hidden states attend over per-channel key/value slots.

    Each channel vector is pooled to the inner width, expanded into a fixed
    number of key/value slots, attended to by a projection of the hidden
    states, scaled by a per-channel gate (``alphas``, initialised to zero) and
    projected back to the hidden width. With all gates at zero the returned
    delta is zero.
    """

    # Number of key/value slots synthesised for each channel.
    K_PER_CHANNEL = {"memory": 8, "identity": 8, "continuity": 8, "time": 3, "ethics": 1, "affect": 1}

    def __init__(self, cfg):
        """Build the adapter from ``cfg`` (hidden_dim, channel_inner_dim, channel_source_dims)."""
        super().__init__()
        names = list(cfg.channel_source_dims.keys())
        source_dims = [cfg.channel_source_dims[name] for name in names]
        slot_counts = [self.K_PER_CHANNEL[name] for name in names]
        hidden_dim, inner_dim = cfg.hidden_dim, cfg.channel_inner_dim

        self.channel_names = names
        self.n_ch = len(names)
        self.src_dims = source_dims
        self.max_src = max(source_dims)
        self.k_per = slot_counts
        self.max_k = max(slot_counts)
        self.c = inner_dim
        self.scale = inner_dim ** -0.5

        # Channels are batched into shared tensors, padded up to the widest
        # source (max_src) and the largest slot count (max_k).
        self.pool_weight = nn.Parameter(torch.zeros(self.n_ch, inner_dim, self.max_src))
        self.expand_k = nn.Parameter(torch.zeros(self.n_ch, self.max_k, inner_dim, inner_dim))
        self.expand_v = nn.Parameter(torch.zeros(self.n_ch, self.max_k, inner_dim, inner_dim))
        self.q_proj = nn.Parameter(torch.zeros(inner_dim, hidden_dim))
        self.o_proj = nn.Parameter(torch.zeros(hidden_dim, inner_dim))
        self.alphas = nn.Parameter(torch.zeros(self.n_ch))

        # Additive attention mask: 0 on a channel's real slots, -inf on padding slots.
        slot_index = torch.arange(self.max_k).unsqueeze(0)
        is_real_slot = slot_index < torch.tensor(slot_counts).unsqueeze(1)
        key_mask = torch.full((self.n_ch, self.max_k), float("-inf")).masked_fill(is_real_slot, 0.0)
        self.register_buffer("key_mask", key_mask, persistent=False)

        # Only the first `dim` input columns of each channel's pooling matrix
        # are initialised; the padding columns stay zero.
        for idx, dim in enumerate(source_dims):
            nn.init.normal_(self.pool_weight.data[idx, :, :dim], std=dim ** -0.5)
        nn.init.normal_(self.expand_k, std=inner_dim ** -0.5)
        nn.init.normal_(self.expand_v, std=inner_dim ** -0.5)
        nn.init.normal_(self.q_proj, std=hidden_dim ** -0.5)
        nn.init.normal_(self.o_proj, std=inner_dim ** -0.5)

    def forward(self, hidden, channels):
        """Return the delta to add to ``hidden``.

        Args:
            hidden: tensor indexed as (batch, tokens, hidden_dim) by the einsums below.
            channels: mapping from channel name to a (batch, source_dim) tensor.

        Returns:
            Tensor with the same leading shape and dtype as ``hidden``.
        """
        dtype = hidden.dtype

        # Right-pad every channel vector to max_src and stack: (batch, n_ch, max_src).
        padded = []
        for name, dim in zip(self.channel_names, self.src_dims):
            padded.append(F.pad(channels[name].to(dtype), (0, self.max_src - dim)))
        stacked = torch.stack(padded, dim=1)

        # Compute in the dtype of the incoming hidden states.
        pool_w = self.pool_weight.to(dtype)
        expand_k = self.expand_k.to(dtype)
        expand_v = self.expand_v.to(dtype)
        q_w = self.q_proj.to(dtype)
        o_w = self.o_proj.to(dtype)
        gates = self.alphas.to(dtype)
        mask = self.key_mask.to(dtype)

        pooled = torch.einsum("bns,ncs->bnc", stacked, pool_w)
        keys = torch.einsum("bnc,nkdc->bnkd", pooled, expand_k)
        values = torch.einsum("bnc,nkdc->bnkd", pooled, expand_v)
        queries = torch.einsum("bth,ch->btc", hidden, q_w)

        logits = torch.einsum("btc,bnkc->bntk", queries, keys) * self.scale
        logits = logits + mask.view(1, self.n_ch, 1, self.max_k)
        weights = logits.softmax(dim=-1)
        context = torch.einsum("bntk,bnkc->bntc", weights, values)

        # Per-channel gate, then sum the channels together.
        mixed = (context * gates.view(1, self.n_ch, 1, 1)).sum(dim=1)
        return torch.einsum("btc,hc->bth", mixed, o_w)

class ChannelHolder:
    """Mutable slot for the current channel tensors, shared by reference between layers."""

    def __init__(self):
        # None means "no channels supplied"; assign a dict of tensors to enable injection.
        self.channels = None

class ChanneledLayer(nn.Module):
    """Wrap one decoder layer and add a channel-conditioned residual to its output.

    The wrapped layer runs unchanged. If the shared holder currently carries
    channels, the layer's hidden-state output is RMS-normalised (in float32),
    passed through a ChannelInjectionDelta, and the resulting delta is added
    to the hidden state. If the holder's channels are None the wrapper is a
    pure pass-through.

    Attribute lookups that the wrapper cannot satisfy itself are forwarded to
    the wrapped layer, so code that reads per-layer metadata from the layer
    stack keeps working after wrapping.
    """

    def __init__(self, base_layer, cfg, holder):
        """
        Args:
            base_layer: the original decoder layer module to wrap.
            cfg: config object providing ``hidden_dim`` (and whatever
                ChannelInjectionDelta reads from it).
            holder: object with a ``channels`` attribute (None or a dict of
                channel tensors) that is consulted on every forward call.
        """
        super().__init__()
        self.base = base_layer
        self.norm_channels = nn.RMSNorm(cfg.hidden_dim)
        self.channel_inj = ChannelInjectionDelta(cfg)
        self._holder = holder

    def __getattr__(self, name):
        """Fall back to the wrapped layer for attributes the wrapper does not define.

        Some decoder stacks read attributes off each layer in their forward
        loop (e.g. an ``attention_type`` tag); without this fallback those
        reads raise AttributeError once the layer is wrapped.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Never delegate `base` itself (it may not be set yet) or dunder
            # protocol lookups such as copy/pickle hooks.
            if name == "base" or name.startswith("__"):
                raise
            return getattr(super().__getattr__("base"), name)

    def forward(self, *args, **kwargs):
        """Call the wrapped layer, then add the channel delta to its hidden-state output.

        Returns the same structure the wrapped layer returned: a tuple whose
        first element is the (possibly modified) hidden state, or a bare tensor.
        """
        out = self.base(*args, **kwargs)
        is_tuple = isinstance(out, tuple)
        hidden = out[0] if is_tuple else out
        chans = self._holder.channels
        if chans is not None:
            # Normalise in float32, then go back to the layer's own dtype.
            normed = self.norm_channels(hidden.float()).to(hidden.dtype)
            hidden = hidden + self.channel_inj(normed, chans).to(hidden.dtype)
        return (hidden, *out[1:]) if is_tuple else hidden

def find_layers(model):
    """Locate the longest ``nn.ModuleList`` in ``model`` that has at least 20 entries.

    Returns:
        A ``(qualified_name, module_list)`` tuple for the longest qualifying
        list (the first one encountered wins ties), or ``None`` if there is none.
    """
    import torch.nn as nn

    candidates = (
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, nn.ModuleList) and len(module) >= 20
    )
    # max() keeps the first of several equally long lists, matching a strict ">" scan.
    return max(candidates, key=lambda entry: len(entry[1]), default=None)

def main():
    """Load the frozen base model, wrap every decoder layer, and print the wiring checks.

    Checks printed to stdout:
      * logits with all gates at zero and a NON-zero channel are exactly equal
        to the logits of the un-injected model (holder.channels = None);
      * the trainable parameter count (the base is frozen before wrapping, so
        only the wrappers' parameters require grad);
      * one backward pass delivers a non-zero gradient to the gates;
      * with every gate set to 0.5, two different "identity" channels on the
        same prompt produce different logits.

    Raises:
        RuntimeError: if no decoder-layer ``nn.ModuleList`` can be located.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tok = AutoTokenizer.from_pretrained(MP)
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(MP, dtype=torch.bfloat16, low_cpu_mem_usage=True).eval()
    H = model.config.text_config.hidden_size if hasattr(model.config, "text_config") else model.config.hidden_size
    found = find_layers(model)
    if found is None:
        # Fail with a readable message instead of an unpacking TypeError.
        raise RuntimeError("could not locate the decoder layer stack (no nn.ModuleList with >= 20 entries)")
    name, layers = found
    print(f"loaded in {time.time()-t0:.0f}s | hidden_dim={H} | layers='{name}' n={len(layers)}")

    # Freeze the base BEFORE wrapping so only the adapter parameters stay trainable,
    # and measure the base size from the loaded model rather than hard-coding it.
    for p in model.parameters():
        p.requires_grad_(False)
    n_base = sum(p.numel() for p in model.parameters())
    holder = ChannelHolder()
    cfg = InjectionConfig(hidden_dim=H)
    # Iterate over a snapshot: the ModuleList is rewritten in place.
    for i, layer in enumerate(list(layers)):
        layers[i] = ChanneledLayer(layer, cfg, holder)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_tr = sum(p.numel() for p in trainable)
    print(f"MARS wired: {len(layers)} ChannelInjection blocks | trainable params {n_tr/1e6:.2f}M "
          f"({n_tr/n_base*100:.3f}% of base) | params f32: {all(p.dtype==torch.float32 for p in trainable)}")

    ids = tok("def fibonacci(n):", return_tensors="pt").input_ids
    B = ids.shape[0]

    def chans(seed):
        """All-zero channels except a seeded random "identity" vector."""
        g = torch.Generator().manual_seed(seed)
        d = {n: torch.zeros(B, dim) for n, dim in CHANNEL_ORDER}
        d["identity"] = torch.randn(B, 1024, generator=g)
        return d

    # 1. bit-exact at alpha=0.
    # A non-zero channel is used on purpose: with all-zero channels the delta is
    # zero for ANY alpha (there are no bias terms), so that would not test the gate.
    with torch.no_grad():
        holder.channels = None
        base_logits = model(ids).logits
        holder.channels = chans(1)
        mars_logits = model(ids).logits
    holder.channels = None
    diff = (base_logits - mars_logits).abs().max().item()
    exact = torch.equal(base_logits, mars_logits)  # exact equality, not a tolerance
    print(f"\n1. BIT-EXACT @alpha=0: max|Δlogits| = {diff:.2e}  -> {'BIT-EXACT ✓' if exact else 'NOT bit-exact ✗'}")

    # 2/3. gradient-flow: one backward with channels -> alphas receive grad
    holder.channels = chans(1)
    out = model(ids).logits
    loss = F.cross_entropy(out[:, -1].float(), torch.tensor([100]*B))
    loss.backward()
    galpha = sum(l.channel_inj.alphas.grad.abs().sum().item()
                 for l in layers if l.channel_inj.alphas.grad is not None)
    print(f"2. GRAD-FLOW: Σ|alpha.grad| across {len(layers)} gates = {galpha:.3e}  -> {'gates trainable ✓' if galpha>0 else 'NO grad ✗'}")
    model.zero_grad(set_to_none=True)

    # 4. channel counterfactual: bump alpha, SAME prompt + DIFFERENT channel -> DIFFERENT logits
    with torch.no_grad():
        for l in layers:
            l.channel_inj.alphas.fill_(0.5)
        holder.channels = chans(1)
        logA = model(ids).logits
        holder.channels = chans(2)
        logB = model(ids).logits
    holder.channels = None
    cf = (logA - logB).abs().max().item()
    flip = (logA[:, -1].argmax(-1) != logB[:, -1].argmax(-1)).float().mean().item()
    print(f"3. COUNTERFACTUAL (alpha=0.5): max|Δlogits| chanA-vs-chanB = {cf:.3f} | next-token flip = {flip:.0%} "
          f"-> {'channel drives output ✓' if cf>1e-2 else 'channel inert ✗'}")
    print("\nVERDICT: first MARS-on-Gemma-4-E4B wired —",
          "bit-exact base, gates trainable, channel measurably drives output." if exact and galpha>0 and cf>1e-2
          else "CHECK above.")

# Run the wiring checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()