"""Yaz POC architecture.
Standard byte-level causal transformer (YazConfig defaults: 3 blocks,
d_model=64) with one twist: just before unembed, a top-k=1 "fact atom"
dictionary projects the residual onto d_dict atoms (default 128), picks
the single most-activated one, and adds that atom's learnable decoder
vector back into the residual.
This is the CRUD-target layer:
- W_dec[:, atom_id] ⟵ edit this single column = edit that fact
- zero it ⟵ delete the fact
- append a new column ⟵ add a fact
The encoder's bias for atom_id controls "is this fact accessible?".
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@dataclass
class YazConfig:
    """Hyper-parameters for YazLM: the byte-level backbone plus the fact-atom layer."""

    # ---- byte-level transformer backbone ----
    vocab_size: int = 256
    d_model: int = 64
    n_layers: int = 3
    n_heads: int = 4
    max_seq_len: int = 128
    ffn_expand: int = 4
    # NOTE(review): no module in this file reads `dropout`; it is currently inert.
    dropout: float = 0.0

    # ---- fact-atom layer ----
    # How many addressable fact atoms the dictionary holds.
    d_dict: int = 128
    # Atoms kept per position; v4 pins this to a strict 1 and relies on
    # separate anti-collapse machinery.
    fact_top_k: int = 1
    # Scale applied to the fact layer's contribution before it joins the residual.
    fact_gain: float = 1.0

    # Phase 9 multi-byte: every atom owns d_phase value vectors and a shared
    # phase head picks one per position (by within-answer byte offset).
    # d_phase=1 reproduces the single-vector model exactly, so every
    # Phase 1-8 configuration keeps its behavior.
    d_phase: int = 1

    # Semantic re-keying: when enabled, forward_routed scales the forced
    # (Engram-routed) atom by a learnable per-atom gain instead of a fixed 1.0,
    # letting it dominate the residual (a fixed 1.0 lost edit efficacy to
    # backbone co-memorization). The parameter only exists when
    # use_atom_gain=True, so surface (route_atom=None) models are unchanged.
    use_atom_gain: bool = False
    atom_gain_init: float = 1.0
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused, bias-free QKV projection.

    Position t attends only to positions <= t (strict upper triangle masked).
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        # Raise rather than assert so the check survives `python -O`; matches
        # the ValueError validation used by FactAtomLayer.
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model={cfg.d_model} must be divisible by n_heads={cfg.n_heads}"
            )
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """x: (B, T, d_model) -> (B, T, d_model)."""
        B, T, D = x.shape
        # (B, T, 3*D) -> (3, B, H, T, d_head)
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, T, d_head)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True above the diagonal = future positions, which must not be attended.
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        att = att.masked_fill(mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, H, T, d_head)
        y = y.transpose(1, 2).reshape(B, T, D)
        return self.out(y)
class FFN(nn.Module):
    """Position-wise MLP: Linear(d_model -> ffn_expand*d_model) -> GELU -> Linear back."""

    def __init__(self, cfg: YazConfig):
        super().__init__()
        hidden_width = cfg.ffn_expand * cfg.d_model
        self.fc1 = nn.Linear(cfg.d_model, hidden_width)
        self.fc2 = nn.Linear(hidden_width, cfg.d_model)

    def forward(self, x: Tensor) -> Tensor:
        activated = F.gelu(self.fc1(x))
        return self.fc2(activated)
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Two residual sublayers applied in sequence: causal self-attention on
    LayerNorm(x), then the FFN on LayerNorm of the updated stream.
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        # Creation order is kept (ln1, attn, ln2, ffn) so parameter init draws
        # from the RNG in the same sequence.
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.ffn = FFN(cfg)

    def forward(self, x: Tensor) -> Tensor:
        attn_update = self.attn(self.ln1(x))
        stream = x + attn_update
        ffn_update = self.ffn(self.ln2(stream))
        return stream + ffn_update
class FactAtomLayer(nn.Module):
    """Top-k atom dictionary (top_k=1 by default). Each atom is an addressable fact slot.

    Parameter shapes:
        W_enc.weight: (d_dict, d_model)  - scores which atom fires
        W_enc.bias:   (d_dict,)
        W_dec.weight: (d_model, d_dict)  - column k is atom k's contribution direction
        pre_bias:     (d_model,)         - centering offset: subtracted before
                                           encoding, added back after decoding
        W_dec_extra:  (d_phase-1, d_model, d_dict) - only when d_phase > 1
        phase_head:   Linear(d_model -> d_phase)   - only when d_phase > 1
        atom_gain:    (d_dict,)          - only when use_atom_gain=True (else None)
    """

    def __init__(self, d_model: int, d_dict: int, top_k: int = 1, d_phase: int = 1,
                 use_atom_gain: bool = False, atom_gain_init: float = 1.0):
        super().__init__()
        if top_k < 1 or top_k > d_dict:
            raise ValueError(f"top_k must be in [1, {d_dict}], got {top_k}")
        if d_phase < 1:
            raise ValueError(f"d_phase must be >= 1, got {d_phase}")
        self.d_model = d_model
        self.d_dict = d_dict
        self.top_k = top_k
        self.d_phase = d_phase
        # Learnable per-atom gain for forward_routed (semantic re-keying). Absent by
        # default so legacy/surface state_dicts are byte-identical.
        self.atom_gain = (nn.Parameter(torch.full((d_dict,), float(atom_gain_init)))
                          if use_atom_gain else None)
        self.W_enc = nn.Linear(d_model, d_dict, bias=True)
        self.W_dec = nn.Linear(d_dict, d_model, bias=False)
        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        # Tie initial weights: encoder = decoder.T, unit-norm atom directions.
        with torch.no_grad():
            w = torch.randn(d_dict, d_model)
            w = w / w.norm(dim=1, keepdim=True).clamp_min(1e-6)
            self.W_dec.weight.copy_(w.t())
            self.W_enc.weight.copy_(w)
            self.W_enc.bias.zero_()
        # Phase 9: extra decoder columns for phases 1..d_phase-1 (phase 0 = W_dec),
        # plus a shared phase head. Created ONLY when d_phase>1 so the d_phase=1
        # state_dict is identical to the legacy model. The extra columns start at
        # zero, so before training a phase>=1 selection contributes only pre_bias;
        # the output matches the single-vector model wherever phase 0 is selected.
        if d_phase > 1:
            self.W_dec_extra = nn.Parameter(torch.zeros(d_phase - 1, d_model, d_dict))
            self.phase_head = nn.Linear(d_model, d_phase, bias=True)

    def _select_topk(self, z_dense: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Keep the top_k entries of z_dense along the last dim and zero the rest.

        Returns (z_sparse, top_vals, top_idx); the last two have last dim top_k.
        """
        top_vals, top_idx = z_dense.topk(self.top_k, dim=-1)
        z_sparse = torch.zeros_like(z_dense).scatter_(-1, top_idx, top_vals)
        return z_sparse, top_vals, top_idx

    def encode(self, x: Tensor) -> Tensor:
        """Returns sparse activations (..., d_dict) with only top_k nonzero entries."""
        return self._select_topk(self.encode_dense(x))[0]

    def encode_dense(self, x: Tensor) -> Tensor:
        """Returns the PRE-topk gate scores (..., d_dict), for load-balancing loss."""
        return F.relu(self.W_enc(x - self.pre_bias))

    def encode_logits(self, x: Tensor) -> Tensor:
        """Returns the PRE-ReLU encoder logits (..., d_dict), for supervised
        target-atom CE loss (Exp 3). Gradient flows freely through every
        atom dim (no ReLU dead-zone)."""
        return self.W_enc(x - self.pre_bias)

    def decode(self, z: Tensor) -> Tensor:
        """Map activations (..., d_dict) back to the residual: W_dec @ z + pre_bias."""
        return self.W_dec(z) + self.pre_bias

    @torch.no_grad()
    def resurrect(self, dead_idx: Tensor, source: Tensor) -> int:
        """Re-initialize each dead atom from a random row of `source`.

        Args:
            dead_idx: 1-D integer tensor of atom ids to re-initialize.
            source: (N, d_model) current residuals to sample directions from.

        The encoder row and phase-0 decoder column are both set to the
        unit-normalized sampled row, and the encoder bias is reset to 0. When
        d_phase > 1, the atom's extra phase columns are reset to their zero
        init so no stale values from the dead atom survive.

        Returns:
            The number of atoms actually resurrected (0 if either input is empty).
        """
        if dead_idx.numel() == 0 or source.numel() == 0:
            return 0
        N = source.size(0)
        # One random source row per dead atom.
        pick = torch.randint(0, N, (dead_idx.numel(),), device=source.device)
        v = source[pick]  # (k_dead, d_model)
        v = v / v.norm(dim=1, keepdim=True).clamp_min(1e-6)  # unit-norm rows
        self.W_enc.weight[dead_idx] = v
        self.W_dec.weight[:, dead_idx] = v.t()  # decoder col = unit dir
        self.W_enc.bias[dead_idx] = 0.0
        if self.d_phase > 1:
            self.W_dec_extra[:, :, dead_idx] = 0.0
        return int(dead_idx.numel())

    def orthogonality_loss(self) -> Tensor:
        """Mean squared off-diagonal of normalized W_dec column Gram matrix.
        0 when all decoder columns are orthonormal; bounded by 1.
        """
        W = self.W_dec.weight  # (d_model, d_dict)
        Wn = W / W.norm(dim=0, keepdim=True).clamp_min(1e-6)
        G = Wn.t() @ Wn  # (d_dict, d_dict)
        eye = torch.eye(self.d_dict, device=G.device, dtype=G.dtype)
        return ((G - eye) ** 2).mean()

    def decode_phased(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Phase 9 multi-byte path.

        Pick the top_k atoms exactly as forward() does. The shared phase head
        then chooses one phase per position, and each selected atom contributes
        its value vector for that phase, weighted by its activation.

        For d_phase=1 there is no phase head: phase_logits is all-zeros with
        last dim 1, phase 0 (= W_dec) is always used, and the output equals
        forward()/decode(). With d_phase>1 and phase 0 selected, the output
        likewise equals decode().

        Returns:
            (out, z_sparse, z_dense, phase_logits), where out is (..., d_model)
            and phase_logits is (..., d_phase).
        """
        z_dense = self.encode_dense(x)  # (..., d_dict)
        z_sparse, top_vals, top_idx = self._select_topk(z_dense)
        if self.d_phase > 1:
            phase_logits = self.phase_head(x)  # (..., d_phase)
            # Stacked decoder: D[0] = legacy W_dec, D[1:] = extra phase columns.
            D = torch.cat([self.W_dec.weight.unsqueeze(0), self.W_dec_extra], dim=0)
        else:
            phase_logits = x.new_zeros(*x.shape[:-1], 1)
            D = self.W_dec.weight.unsqueeze(0)  # (1, d_model, d_dict)
        phase = phase_logits.argmax(dim=-1)  # (...) chosen phase
        lead = top_idx.shape[:-1]
        k = top_idx.shape[-1]
        atom_ids = top_idx.reshape(-1, k)  # (N, K)
        acts = top_vals.reshape(-1, k)  # (N, K)
        phases = phase.reshape(-1, 1).expand(-1, k)  # (N, K): one phase per position
        # Advanced indices on dims 0 and 2 around a slice -> (N, K, d_model).
        cols = D[phases, :, atom_ids]
        contrib = (acts.unsqueeze(-1) * cols).sum(dim=1).reshape(*lead, self.d_model)
        out = contrib + self.pre_bias
        return out, z_sparse, z_dense, phase_logits

    def forward(self, x: Tensor, return_dense: bool = False
                ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
        """Select the top_k atoms per position and decode them.

        Returns (out, z_sparse), or (out, z_sparse, z_dense) when
        return_dense=True. z_dense is returned so the LB-loss has a gradient
        path through W_enc.
        """
        z_dense = self.encode_dense(x)
        z, _, _ = self._select_topk(z_dense)
        out = self.decode(z)
        if return_dense:
            return out, z, z_dense
        return out, z

    def forward_routed(self, x: Tensor, route_atom: Tensor, route_pos: Tensor):
        """SEMANTIC RE-KEYING path.

        At positions where `route_pos` is True, FORCE the atom given by
        `route_atom` (one per sequence) to fire. Its activation is
        atom_gain[route_atom] when atom_gain exists, else 1.0. This replaces the
        learned argmax(ReLU(W_enc·…)) selection. Elsewhere learned routing is
        kept, so the language prefix is unaffected and W_enc still trains on
        non-fact tokens.

        This decouples WHICH atom fires from the byte-transformer surface
        activations: the caller picks the atom via a frozen semantic embedding
        (Engram), so paraphrases route to the same atom. W_dec / pre_bias / the
        additive contribution / CRUD are unchanged; only the selection is
        overridden.

        route_atom: (B,) int64 atom ids. route_pos: (B, T) bool.
        Returns (out, z_sparse, z_dense) like forward(return_dense=True).
        """
        z_dense = self.encode_dense(x)  # (B, T, d_dict)
        z, _, _ = self._select_topk(z_dense)  # learned top-k
        B, T, _ = z.shape
        idx = route_atom.view(B, 1, 1).expand(B, T, 1)
        if self.atom_gain is not None:
            # Learnable per-atom magnitude so the forced atom can dominate the residual.
            gain = self.atom_gain[route_atom].view(B, 1, 1).expand(B, T, 1)
            forced = torch.zeros_like(z).scatter_(-1, idx, gain)
        else:
            forced = torch.zeros_like(z).scatter_(-1, idx, 1.0)  # legacy: act=1.0
        z = torch.where(route_pos.unsqueeze(-1), forced, z)
        return self.decode(z), z, z_dense
class YazLM(nn.Module):
    """Byte-level causal LM with an additive fact-atom layer.

    Pipeline: token + position embeddings -> Blocks -> final LayerNorm ->
    FactAtomLayer (its output, scaled by cfg.fact_gain, is added to the
    residual) -> unembed.
    """

    def __init__(self, cfg: YazConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = nn.LayerNorm(cfg.d_model)
        self.fact_layer = FactAtomLayer(cfg.d_model, cfg.d_dict, top_k=cfg.fact_top_k,
                                        d_phase=cfg.d_phase, use_atom_gain=cfg.use_atom_gain,
                                        atom_gain_init=cfg.atom_gain_init)
        # Output projection. NOTE: this is an independent Linear; it is NOT
        # weight-tied to tok_embed.
        self.unembed = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, ids: Tensor, return_fact_z: bool = False, return_dense: bool = False,
                return_atoms_only: bool = False, return_phase: bool = False,
                route_atom: Tensor | None = None, route_pos: Tensor | None = None):
        """Run the model on ids: (B, T) token ids.

        Optional outputs:
            fact_z:            sparse fact activations (inspection / CRUD address lookup)
            fact_dense:        dense pre-topk gate scores (load-balancing loss)
            atoms_only_logits: unembed of the fact contribution alone (v5 aux loss)
            phase_logits:      per-position phase logits (Phase 9; None unless d_phase > 1)

        Routing: when route_atom (B,) is given and cfg.d_phase == 1, that atom
        is forced at route_pos (B, T) bool; route_pos defaults to the last
        position only. route_atom is ignored when d_phase > 1.

        Returns:
            logits (B, T, vocab) alone if no optional output is requested,
            otherwise (logits, *extras), with extras appended in this order:
                fact_z, fact_dense   if return_dense
                fact_z               elif return_fact_z
                atoms_only_logits    if return_atoms_only
                phase_logits         if return_phase

        Raises:
            ValueError: if T exceeds cfg.max_seq_len.
        """
        B, T = ids.shape
        # Explicit raise (not assert) so the check survives `python -O`.
        if T > self.cfg.max_seq_len:
            raise ValueError(f"T={T} > max_seq_len={self.cfg.max_seq_len}")
        pos = torch.arange(T, device=ids.device)
        x = self.tok_embed(ids) + self.pos_embed(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_final(x)

        phase_logits = None
        fact_dense = None
        if route_atom is not None and self.cfg.d_phase == 1:
            # SEMANTIC RE-KEYING path: force the caller-supplied atom (chosen by
            # an Engram embedding) at route_pos.
            if route_pos is None:
                route_pos = torch.zeros(B, T, dtype=torch.bool, device=ids.device)
                route_pos[:, -1] = True
            fact_contrib, fact_z, fact_dense = self.fact_layer.forward_routed(x, route_atom, route_pos)
        elif self.cfg.d_phase > 1:
            # Phase 9 multi-byte path. d_phase=1 never reaches here, so the
            # legacy paths below are unchanged for all Phase 1-8 models.
            fact_contrib, fact_z, fact_dense, phase_logits = self.fact_layer.decode_phased(x)
        elif return_dense:
            fact_contrib, fact_z, fact_dense = self.fact_layer(x, return_dense=True)
        else:
            fact_contrib, fact_z = self.fact_layer(x)

        # Additive fact: the transformer's residual carries language / style,
        # and the fact layer adds its decoded atom contribution per token.
        # Edits to W_dec[:, k] reach the next-byte logits through unembed; the
        # transformer's contribution is unaffected.
        fact_term = self.cfg.fact_gain * fact_contrib
        logits = self.unembed(x + fact_term)

        extras: list = []
        if return_dense:
            extras += [fact_z, fact_dense]
        elif return_fact_z:
            extras.append(fact_z)
        if return_atoms_only:
            # v5 atoms-only path: unembed the fact contribution alone (it
            # already includes the fact layer's pre_bias). Training CE on this
            # pushes fact prediction THROUGH the atom dictionary instead of
            # co-memorizing it in the transformer weights.
            extras.append(self.unembed(fact_term))
        if return_phase:
            extras.append(phase_logits)
        return (logits, *extras) if extras else logits

    def count_params(self) -> int:
        """Total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())
@torch.no_grad()
def greedy_generate(model: YazLM, prompt_ids: Tensor, n_new: int, stop_id: int | None = None) -> Tensor:
    """Greedy (argmax) generation.

    Args:
        model: called as model(ctx) and expected to return (B, T, vocab)
            logits; must expose cfg.max_seq_len, which bounds the context window.
        prompt_ids: (B, T0) token ids (typically B == 1).
        n_new: maximum number of tokens to append.
        stop_id: if given, stop early once every row's newest token equals
            it. The stop token is kept in the output.

    Returns:
        (B, T0 + generated) token ids.

    The model runs in eval mode and its previous train/eval mode is restored
    on exit, so calling this mid-training does not leak eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        out = prompt_ids
        max_ctx = model.cfg.max_seq_len
        for _ in range(n_new):
            # Only the most recent max_ctx tokens fit in the positional table.
            ctx = out[:, -max_ctx:]
            logits = model(ctx)
            nxt = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (B, 1)
            out = torch.cat([out, nxt], dim=1)
            if stop_id is not None and bool((nxt == stop_id).all()):
                break
        return out
    finally:
        model.train(was_training)
|