File size: 16,610 Bytes
b14638e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
"""Yaz POC architecture.

Standard byte-level causal transformer (3 blocks, d=128) with one twist:
just before unembed, a top-k=1 "fact atom" dictionary projects the
residual into d_dict=512 atoms, picks the single most-activated one,
and adds that atom's learnable decoder vector back into the residual.

This is the CRUD-target layer:
  - W_dec[:, atom_id]  ⟵  edit this single column = edit that fact
  - zero it             ⟵  delete the fact
  - append a new column ⟵  add a fact
The encoder's bias for atom_id controls "is this fact accessible?".
"""
from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


@dataclass
class YazConfig:
    """Hyper-parameters for the Yaz byte-level LM and its fact-atom layer.

    The first group sizes the transformer backbone; the remaining fields are
    forwarded to the fact-atom layer or used as a multiplier on its output.

    NOTE(review): the module docstring mentions d=128 / d_dict=512, while the
    defaults below are d_model=64 / d_dict=128 — presumably experiment configs
    override these; confirm which values are canonical.
    """

    # --- Transformer backbone ---
    vocab_size: int = 256     # size of the token embedding table / logit dimension
    d_model: int = 64         # residual width; must be divisible by n_heads
    n_layers: int = 3         # number of transformer blocks
    n_heads: int = 4          # attention heads per block
    max_seq_len: int = 128    # size of the learned positional embedding table
    ffn_expand: int = 4       # FFN hidden width = d_model * ffn_expand
    # NOTE(review): no module visible in this file reads `dropout`; it looks
    # like a placeholder — confirm before relying on it.
    dropout: float = 0.0
    # Fact-atom layer:
    d_dict: int = 128        # number of addressable fact atoms
    fact_top_k: int = 1      # v4: strict 1, paired with anti-collapse machinery
    fact_gain: float = 1.0   # multiplier on the fact atom's contribution
    # Phase 9 multi-byte: each atom owns d_phase value vectors; a shared phase
    # head selects which one fires (by within-answer byte offset). d_phase=1 is
    # exactly the single-vector model (byte-identical), so all Phase 1-8 configs
    # are unaffected.
    d_phase: int = 1
    # Semantic re-keying: a learnable per-atom activation gain used by forward_routed
    # so a forced (Engram-routed) atom can DOMINATE the residual — restoring the edit
    # efficacy that a fixed activation=1.0 destroyed (backbone co-memorization). Only
    # created when use_atom_gain=True, so surface (route_atom=None) models are unchanged.
    use_atom_gain: bool = False
    # Initial value every entry of the per-atom gain vector is filled with.
    atom_gain_init: float = 1.0


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a strict causal (lower-triangular) mask.

    A single bias-free projection produces queries, keys and values; a second
    bias-free projection mixes the concatenated heads back into d_model.
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        width = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = width // cfg.n_heads
        # Fused q/k/v projection followed by the output projection.
        self.qkv = nn.Linear(width, 3 * width, bias=False)
        self.out = nn.Linear(width, width, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over `x` of shape (B, T, D); returns a tensor of the same shape."""
        batch, seq_len, width = x.shape
        # (B, T, 3*D) -> (3, B, H, T, d_head), then split along the leading axis.
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.d_head)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True strictly above the diagonal: position t must not see t+1, t+2, ...
        future = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool).triu(diagonal=1)
        weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
        heads = torch.matmul(weights, values)  # (B, H, T, d_head)
        merged = heads.transpose(1, 2).reshape(batch, seq_len, width)
        return self.out(merged)


class FFN(nn.Module):
    """Position-wise feed-forward net: Linear -> GELU -> Linear.

    The hidden width is d_model * ffn_expand; input and output width are d_model.
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        hidden_width = cfg.d_model * cfg.ffn_expand
        self.fc1 = nn.Linear(cfg.d_model, hidden_width)
        self.fc2 = nn.Linear(hidden_width, cfg.d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the two-layer MLP to the last dimension of `x`."""
        expanded = self.fc1(x)
        activated = F.gelu(expanded)
        return self.fc2(activated)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then FFN, each with a residual add."""

    def __init__(self, cfg: YazConfig):
        super().__init__()
        # Sub-modules are registered in the order they are applied.
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.ffn = FFN(cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` after the attention and feed-forward residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ffn(self.ln2(after_attn))


class FactAtomLayer(nn.Module):
    """Top-k=1 atom dictionary. Each atom is an addressable fact slot.

    Shapes:
        encoder  W_enc: (d_dict, d_model)  — pick which atom fires
        decoder  W_dec: (d_model, d_dict)  — its contribution direction
        pre_bias        (d_model,)         — centering offset

    Optional parameters, registered only when requested so that default
    state_dicts keep the legacy key set:
        atom_gain   (d_dict,)                       — when use_atom_gain=True
        W_dec_extra (d_phase - 1, d_model, d_dict)  — when d_phase > 1
        phase_head  Linear(d_model, d_phase)        — when d_phase > 1
    """

    def __init__(self, d_model: int, d_dict: int, top_k: int = 1, d_phase: int = 1,
                 use_atom_gain: bool = False, atom_gain_init: float = 1.0):
        """Build the dictionary with tied, unit-norm encoder/decoder directions.

        Raises:
            ValueError: if `top_k` is outside [1, d_dict] or `d_phase` < 1.
        """
        super().__init__()
        if top_k < 1 or top_k > d_dict:
            raise ValueError(f"top_k must be in [1, {d_dict}], got {top_k}")
        if d_phase < 1:
            raise ValueError(f"d_phase must be >= 1, got {d_phase}")
        self.d_model = d_model
        self.d_dict = d_dict
        self.top_k = top_k
        self.d_phase = d_phase
        # Learnable per-atom gain for forward_routed (semantic re-keying). Absent by
        # default so legacy/surface state_dicts are byte-identical.
        self.atom_gain = (nn.Parameter(torch.full((d_dict,), float(atom_gain_init)))
                          if use_atom_gain else None)
        self.W_enc = nn.Linear(d_model, d_dict, bias=True)
        self.W_dec = nn.Linear(d_dict, d_model, bias=False)
        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        # Tie initial weights — encoder = decoder.T, unit-norm columns.
        with torch.no_grad():
            w = torch.randn(d_dict, d_model)
            w = w / w.norm(dim=1, keepdim=True).clamp_min(1e-6)
            self.W_dec.weight.copy_(w.t())
            self.W_enc.weight.copy_(w)
            self.W_enc.bias.zero_()
        # Phase 9: extra decoder columns for phases 1..d_phase-1 (phase 0 = W_dec),
        # plus a shared phase head. Created ONLY when d_phase>1 so the d_phase=1
        # state_dict is identical to the legacy model. Extra columns start at zero
        # so an untrained multi-byte model == the single-vector model.
        if d_phase > 1:
            self.W_dec_extra = nn.Parameter(torch.zeros(d_phase - 1, d_model, d_dict))
            self.phase_head = nn.Linear(d_model, d_phase, bias=True)

    def _topk_sparse(self, z_dense: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Keep the `top_k` largest gate scores per position and zero the rest.

        Returns (sparse, top_vals, top_idx); `sparse` has the shape of `z_dense`.
        """
        top_vals, top_idx = z_dense.topk(self.top_k, dim=-1)
        sparse = torch.zeros_like(z_dense).scatter_(-1, top_idx, top_vals)
        return sparse, top_vals, top_idx

    def encode(self, x: Tensor) -> Tensor:
        """Returns sparse activations (..., d_dict) with only top_k nonzero entries."""
        sparse, _, _ = self._topk_sparse(self.encode_dense(x))
        return sparse

    def encode_dense(self, x: Tensor) -> Tensor:
        """Returns the PRE-topk gate scores (..., d_dict), for load-balancing loss."""
        return F.relu(self.W_enc(x - self.pre_bias))

    def encode_logits(self, x: Tensor) -> Tensor:
        """Returns the PRE-ReLU encoder logits (..., d_dict), for supervised
        target-atom CE loss (Exp 3). Gradient flows freely through every
        atom dim (no ReLU dead-zone)."""
        return self.W_enc(x - self.pre_bias)

    def decode(self, z: Tensor) -> Tensor:
        """Map atom activations (..., d_dict) back to d_model and add `pre_bias`."""
        return self.W_dec(z) + self.pre_bias

    @torch.no_grad()
    def resurrect(self, dead_idx: Tensor, source: Tensor) -> int:
        """Re-init the encoder row + decoder col for each dead atom from a
        random row of `source` (current residuals, shape (N, d_model)).

        The sampled rows are unit-normalised, written to both W_enc and W_dec,
        and the corresponding encoder biases are reset to zero.

        NOTE(review): `W_dec_extra` (d_phase > 1) and `atom_gain` entries of the
        resurrected atoms are left untouched here — confirm that is intended.

        Returns the number of atoms actually resurrected.
        """
        if dead_idx.numel() == 0 or source.numel() == 0:
            return 0
        n_rows = source.size(0)
        # Pick |dead_idx| random source rows (with replacement).
        pick = torch.randint(0, n_rows, (dead_idx.numel(),), device=source.device)
        # Match the weights' dtype/device: indexed assignment refuses to mix
        # dtypes, and `source` may come from a different precision or device.
        v = source[pick].to(self.W_enc.weight)    # (k_dead, d_model)
        v = v / v.norm(dim=1, keepdim=True).clamp_min(1e-6)   # unit-norm rows
        self.W_enc.weight[dead_idx] = v
        self.W_dec.weight[:, dead_idx] = v.t()    # decoder col = unit dir
        self.W_enc.bias[dead_idx] = 0.0
        return int(dead_idx.numel())

    def orthogonality_loss(self) -> Tensor:
        """Mean squared off-diagonal of normalized W_dec column Gram matrix.

        0 when all decoder columns are mutually orthogonal; bounded by 1.
        The diagonal is masked out rather than compared against 1, so a zeroed
        ("deleted") column — whose normalised self-product is 0, not 1 — does
        not register as a violation. The mean is still taken over all
        d_dict**2 entries, keeping the scale of the loss unchanged.
        """
        W = self.W_dec.weight                     # (d_model, d_dict)
        norms = W.norm(dim=0, keepdim=True).clamp_min(1e-6)
        Wn = W / norms
        G = Wn.t() @ Wn                           # (d_dict, d_dict)
        diag = torch.eye(self.d_dict, device=G.device, dtype=torch.bool)
        off = G.masked_fill(diag, 0.0)
        return (off ** 2).mean()

    def decode_phased(self, x: Tensor):
        """Phase 9 multi-byte path. Pick the top-1 atom (unchanged), then let the
        shared phase head choose which of the atom's d_phase value vectors fires.

        Requires a layer built with d_phase > 1: `W_dec_extra` and `phase_head`
        only exist in that case. Only the single strongest atom contributes to
        `out`, even though `z_sparse` keeps `top_k` entries.

        Returns (out, z_sparse, z_dense, phase_logits). When the phase head
        selects phase 0 the contribution equals the legacy decode() of the
        top-1 atom.
        """
        z_dense = self.encode_dense(x)                          # (..., d_dict)
        z_sparse, top_vals, top_idx = self._topk_sparse(z_dense)
        a = top_vals[..., 0]                                    # (...) activation
        k = top_idx[..., 0]                                     # (...) atom id
        phase_logits = self.phase_head(x)                       # (..., d_phase)
        v = phase_logits.argmax(dim=-1)                         # (...) chosen phase
        # Stacked decoder: D[0] = legacy W_dec, D[1:] = extra phase columns.
        D = torch.cat([self.W_dec.weight.unsqueeze(0), self.W_dec_extra], dim=0)  # (V, d_model, d_dict)
        shp = k.shape
        # Gather, per position, the (phase, atom) column of the stacked decoder.
        col = D[v.reshape(-1), :, k.reshape(-1)]                # (N, d_model)
        contrib = (a.reshape(-1, 1) * col).reshape(*shp, self.d_model)
        out = contrib + self.pre_bias
        return out, z_sparse, z_dense, phase_logits

    def forward(self, x: Tensor, return_dense: bool = False) -> tuple[Tensor, ...]:
        """Encode `x`, keep the top-k atoms and decode them.

        Returns (out, z_sparse), or (out, z_sparse, z_dense) when
        `return_dense` is True.
        """
        z_dense = self.encode_dense(x)
        z, _, _ = self._topk_sparse(z_dense)
        out = self.decode(z)
        if return_dense:
            # Use z_dense so the LB-loss has gradient through W_enc.
            return out, z, z_dense
        return out, z

    def forward_routed(self, x: Tensor, route_atom: Tensor, route_pos: Tensor):
        """SEMANTIC RE-KEYING path. At positions where `route_pos` is True, FORCE the
        atom given by `route_atom` (one per sequence) to fire instead of the learned
        argmax(ReLU(W_enc·…)) selection. The forced activation is 1.0, or the atom's
        learnable `atom_gain` when the layer was built with use_atom_gain=True.
        Elsewhere, learned routing is kept (so the language prefix is unaffected and
        W_enc still trains on non-fact tokens).

        This decouples WHICH atom fires from the byte-transformer surface activations:
        the caller picks the atom via a frozen semantic embedding (Engram), so paraphrases
        route to the same atom. W_dec / pre_bias / the additive contribution / CRUD are
        unchanged — only the selection is overridden.

        route_atom: (B,) int64 atom ids.  route_pos: (B, T) bool.
        Returns (out, z_sparse, z_dense) like forward(return_dense=True).
        """
        z_dense = self.encode_dense(x)                                   # (B,T,d_dict)
        z, _, _ = self._topk_sparse(z_dense)                             # learned top-k
        B, T, _ = z.shape
        forced = torch.zeros_like(z)
        idx = route_atom.view(B, 1, 1).expand(B, T, 1)
        if self.atom_gain is not None:
            # learnable per-atom magnitude so the forced atom can dominate the residual
            gain = self.atom_gain[route_atom].view(B, 1, 1).expand(B, T, 1)
            # scatter_ requires src and destination dtypes to match; the activations
            # may be in a different precision than the parameter.
            forced.scatter_(-1, idx, gain.to(forced.dtype))
        else:
            forced.scatter_(-1, idx, 1.0)                                # legacy: act=1.0
        z = torch.where(route_pos.unsqueeze(-1), forced, z)
        out = self.decode(z)
        return out, z, z_dense


class YazLM(nn.Module):
    """Byte-level causal LM: embeddings -> transformer blocks -> final LayerNorm,
    with the fact-atom layer's output added to the residual just before unembed.
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = nn.LayerNorm(cfg.d_model)
        self.fact_layer = FactAtomLayer(
            cfg.d_model,
            cfg.d_dict,
            top_k=cfg.fact_top_k,
            d_phase=cfg.d_phase,
            use_atom_gain=cfg.use_atom_gain,
            atom_gain_init=cfg.atom_gain_init,
        )
        # Output projection to vocabulary logits. It is a separate bias-free Linear;
        # no weight sharing with tok_embed is set up in this constructor.
        self.unembed = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, ids: Tensor, return_fact_z: bool = False, return_dense: bool = False,
                return_atoms_only: bool = False, return_phase: bool = False,
                route_atom: Tensor | None = None, route_pos: Tensor | None = None):
        """Run the model on `ids` of shape (B, T).

        With no flags set, returns the logits (B, T, vocab) alone. Otherwise a
        tuple is returned whose elements appear in this fixed order:

            logits
            fact_z             — if return_dense or return_fact_z (sparse atom activations)
            fact_dense         — if return_dense (pre-topk gate scores, for the LB loss)
            atoms_only_logits  — if return_atoms_only (unembed of the fact contribution alone)
            phase_logits       — if return_phase (None unless cfg.d_phase > 1)

        `route_atom` / `route_pos` select the semantic re-keying path; they are
        only honoured when cfg.d_phase == 1. When `route_pos` is omitted, the
        last position of every sequence is routed.
        """
        B, T = ids.shape
        assert T <= self.cfg.max_seq_len, f"T={T} > max_seq_len={self.cfg.max_seq_len}"
        positions = torch.arange(T, device=ids.device)
        h = self.tok_embed(ids) + self.pos_embed(positions)[None, :, :]
        for block in self.blocks:
            h = block(h)
        h = self.ln_final(h)

        phase_logits = None
        if self.cfg.d_phase > 1:
            # Phase 9 multi-byte path; single-phase models never take this branch.
            fact_contrib, fact_z, fact_dense, phase_logits = self.fact_layer.decode_phased(h)
        elif route_atom is not None and self.cfg.d_phase == 1:
            # Semantic re-keying: force the caller-chosen atom at the routed positions.
            if route_pos is None:
                route_pos = torch.zeros(B, T, dtype=torch.bool, device=ids.device)
                route_pos[:, -1] = True
            fact_contrib, fact_z, fact_dense = self.fact_layer.forward_routed(h, route_atom, route_pos)
        elif return_dense:
            fact_contrib, fact_z, fact_dense = self.fact_layer(h, return_dense=True)
        else:
            fact_contrib, fact_z = self.fact_layer(h)

        # The fact layer is purely additive on top of the transformer residual,
        # so editing a decoder column changes logits only through this term.
        scaled_fact = self.cfg.fact_gain * fact_contrib
        logits = self.unembed(h + scaled_fact)

        outputs = [logits]
        if return_dense:
            outputs.extend((fact_z, fact_dense))
        elif return_fact_z:
            outputs.append(fact_z)
        if return_atoms_only:
            # v5 atoms-only auxiliary head: logits from the fact contribution alone
            # (which already carries pre_bias), ignoring the transformer residual.
            outputs.append(self.unembed(self.cfg.fact_gain * fact_contrib))
        if return_phase:
            outputs.append(phase_logits)
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def count_params(self) -> int:
        """Total number of scalar parameters in the model."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total


@torch.no_grad()
def greedy_generate(model: YazLM, prompt_ids: Tensor, n_new: int, stop_id: int | None = None) -> Tensor:
    """Greedy (argmax) decoding from `prompt_ids` of shape (1, T0).

    Puts the model in eval mode, then appends up to `n_new` tokens, feeding
    at most the last `max_seq_len` tokens back in at each step. Stops early
    once `stop_id` has been emitted (the stop token is kept in the output).
    Returns the prompt followed by the generated tokens.
    """
    model.eval()
    window = model.cfg.max_seq_len
    sequence = prompt_ids
    for _ in range(n_new):
        # Slicing to the trailing window is a no-op while the sequence is short.
        last_logits = model(sequence[:, -window:])[:, -1, :]
        next_id = last_logits.argmax(dim=-1, keepdim=True)
        sequence = torch.cat([sequence, next_id], dim=1)
        if stop_id is not None and int(next_id.item()) == stop_id:
            break
    return sequence