| |
"""
cdm_model_v2.py — Competitive Docking Memory V2

V1 finding: non-causal slots_final trick gives identical gradient signal to all
slots at every position → winner-take-all collapse (6/8 slots dead, K_eff=2).

V2 fixes:
  1. CAUSAL slots: position t uses slots_t (summary of h[0..t-1]), not slots_final.
     Each position gets a different gradient signal → routing diversifies.

  2. DUAL attention path:
     - Standard causal self-attention (sequence tokens only, no slots in KV)
     - Slot cross-attention: each pos t attends to its K causal slot vectors
     These two paths are summed before the residual, keeping KV cache clean.

  3. MARGINAL ENTROPY REGULARIZATION:
     Maximize entropy of marginal slot distribution across positions.
     Within-position: concentrated (one slot wins per token = specialization)
     Across-position: diverse (different tokens → different slots = no collapse)
     Loss: -lambda_ent * H(E_t[g_k(t)]) where H = entropy

  4. K=16 default (optimal from V1 ablation: K=16 beats K=8 by 17%, K=32 degrades)

Architecture: Archon (DuoNeural)
Math analysis (parallel scan, entropy reg derivation): Aura (DuoNeural)
Date: 2026-06-11
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass, field |
|
|
|
|
@dataclass
class CDMConfigV2:
    """Hyperparameters for the CDM V2 language model.

    Field order is part of the interface (the dataclass accepts positional
    arguments), so new fields should only ever be appended.
    """

    # --- token / backbone sizes ---
    vocab_size: int = 50257    # rows of the embedding table (tied with the output head)
    n_layers: int = 8          # number of stacked CDMBlockV2 blocks
    d_model: int = 384         # residual-stream width
    n_heads: int = 8           # query heads per attention module
    n_kv_heads: int = 4        # key/value heads (grouped-query attention)
    d_ff: int = 1024           # hidden width of the gated FFN

    # --- memory / sequence ---
    K: int = 16                # number of memory slots per block
    max_len: int = 512         # RoPE table length; generation crops context to this

    # --- regularisation ---
    dropout: float = 0.1       # dropout probability used in the blocks and the FFN
    entropy_reg: float = 0.02  # weight on the marginal-entropy auxiliary loss
|
|
|
|
class RoPE(nn.Module):
    """Rotary position embedding with a precomputed cos/sin table.

    The last dimension of the input is split into two halves (x1, x2) and each
    (x1[i], x2[i]) pair is rotated by ``position * theta_i``.

    Args:
        d_head: per-head feature size; must be even so it splits into halves.
        max_len: number of positions stored in the table.

    Raises:
        ValueError: if ``d_head`` is odd.
    """

    def __init__(self, d_head: int, max_len: int):
        super().__init__()
        if d_head % 2 != 0:
            # An odd size yields a table with ceil(d/2) columns but halves of
            # floor(d/2) columns, which can only fail later with a shape error.
            raise ValueError(f"RoPE requires an even d_head, got {d_head}")
        theta = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
        t = torch.arange(max_len).float()
        freqs = torch.outer(t, theta)  # (max_len, d_head // 2)
        # Leading singleton dims broadcast over (batch, heads).
        self.register_buffer("cos", freqs.cos()[None, None, :, :])
        self.register_buffer("sin", freqs.sin()[None, None, :, :])

    def forward(self, x):
        """Rotate ``x`` of shape (B, H, T, d_head) using positions 0..T-1."""
        return self.forward_at(x, offset=0)

    def forward_at(self, x, offset: int = 0):
        """Rotate ``x`` of shape (B, H, T, d_head) using positions offset..offset+T-1.

        Used for cached generation, where a single new token sits at an
        absolute position greater than zero.

        Raises:
            ValueError: if the requested positions fall outside the table.
                Without this check the slice below silently comes back short
                (or empty) and broadcasting produces a wrong-shaped tensor.
        """
        T = x.shape[2]
        table_len = self.cos.shape[2]
        if offset < 0 or offset + T > table_len:
            raise ValueError(
                f"RoPE positions [{offset}, {offset + T}) exceed table length {table_len}"
            )
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        cos = self.cos[:, :, offset:offset + T, :]
        sin = self.sin[:, :, offset:offset + T, :]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Grouped-query causal self-attention over sequence tokens only.

    Memory slots never enter this module's keys/values; they are handled by
    the separate slot cross-attention path.
    """

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.d_head = cfg.d_model // cfg.n_heads
        # Each key/value head is shared by n_rep query heads.
        self.n_rep = cfg.n_heads // cfg.n_kv_heads

        q_width = cfg.n_heads * self.d_head
        kv_width = cfg.n_kv_heads * self.d_head
        self.q_proj = nn.Linear(cfg.d_model, q_width, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, cfg.d_model, bias=False)
        self.rope = RoPE(self.d_head, cfg.max_len)

    def _split_heads(self, proj: torch.Tensor, seq_len: int, heads: int) -> torch.Tensor:
        """(B, seq_len, heads * d_head) -> (B, heads, seq_len, d_head)."""
        batch = proj.shape[0]
        return proj.view(batch, seq_len, heads, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full-sequence causal attention. x: (B, T, d_model) -> (B, T, d_model)."""
        batch, seq_len, _ = x.shape
        queries = self._split_heads(self.q_proj(x), seq_len, self.n_heads)
        keys = self._split_heads(self.k_proj(x), seq_len, self.n_kv_heads)
        values = self._split_heads(self.v_proj(x), seq_len, self.n_kv_heads)

        queries = self.rope(queries)
        keys = self.rope(keys)

        # Expand the KV heads so every query head has a matching key/value head.
        keys = keys.repeat_interleave(self.n_rep, dim=1)
        values = values.repeat_interleave(self.n_rep, dim=1)

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(merged)

    def forward_cached(self, x_t: torch.Tensor, past_kv, position: int):
        """Single-token attention step against a key/value cache.

        Args:
            x_t: (B, 1, d_model) hidden state of the new token.
            past_kv: ``(K_cache, V_cache)``, each (B, n_kv_heads, T_past, d_head),
                or ``None`` when no tokens have been cached yet.
            position: absolute index of the new token, used for RoPE.

        Returns:
            ``(out, (K_full, V_full))`` where ``out`` is (B, 1, d_model) and the
            cache tensors have the new token's key/value appended.
        """
        batch = x_t.shape[0]
        query = self._split_heads(self.q_proj(x_t), 1, self.n_heads)
        new_key = self._split_heads(self.k_proj(x_t), 1, self.n_kv_heads)
        new_value = self._split_heads(self.v_proj(x_t), 1, self.n_kv_heads)

        query = self.rope.forward_at(query, offset=position)
        new_key = self.rope.forward_at(new_key, offset=position)

        if past_kv is None:
            K_full, V_full = new_key, new_value
        else:
            cached_keys, cached_values = past_kv
            K_full = torch.cat([cached_keys, new_key], dim=2)
            V_full = torch.cat([cached_values, new_value], dim=2)

        # The cache itself stays at n_kv_heads; only the attention inputs are expanded.
        keys = K_full.repeat_interleave(self.n_rep, dim=1)
        values = V_full.repeat_interleave(self.n_rep, dim=1)

        # One query that may see every cached key, so no causal mask is needed.
        attended = F.scaled_dot_product_attention(query, keys, values, is_causal=False)
        out = self.o_proj(attended.transpose(1, 2).contiguous().view(batch, 1, -1))
        return out, (K_full, V_full)
|
|
|
|
class SlotCrossAttention(nn.Module):
    """Per-position cross-attention from each token to its own memory slots.

    Position t attends to the K slot vectors ``slots_all[b, t]``. Positions are
    folded into the batch dimension so a single attention call handles them:

        Q:    (B*T, n_heads,    1, d_head)  one query per position
        K, V: (B*T, n_kv_heads, K, d_head)  that position's slots

    Output shape is (B, T, d_model).
    """

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.n_rep = cfg.n_heads // cfg.n_kv_heads
        # Not read by forward(), which relies on the attention call's default
        # scaling; retained so existing attribute access keeps working.
        self.scale = self.d_head ** -0.5

        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.d_head, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.d_head, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * self.d_head, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor, slots_all: torch.Tensor) -> torch.Tensor:
        """Attend each position to its slot vectors.

        Args:
            x: (B, T, d_model) token hidden states.
            slots_all: (B, T, K, d_model) slot states per position. May be
                non-contiguous (for example an expanded or sliced tensor).

        Returns:
            (B, T, d_model) attention output after the output projection.
        """
        B, T, d = x.shape
        K = slots_all.shape[2]
        n_pos = B * T

        Q = self.q_proj(x).reshape(n_pos, 1, self.n_heads, self.d_head).transpose(1, 2)

        # reshape, not view: view raises on non-contiguous inputs such as a
        # slot tensor that was expanded along the time dimension.
        slots_flat = slots_all.reshape(n_pos, K, d)
        Ks = self.k_proj(slots_flat).view(n_pos, K, self.n_kv_heads, self.d_head).transpose(1, 2)
        Vs = self.v_proj(slots_flat).view(n_pos, K, self.n_kv_heads, self.d_head).transpose(1, 2)

        # Expand KV heads so each query head has a matching key/value head.
        Ks = Ks.repeat_interleave(self.n_rep, dim=1)
        Vs = Vs.repeat_interleave(self.n_rep, dim=1)

        # No mask: every slot handed to a position is visible to that position.
        out = F.scaled_dot_product_attention(Q, Ks, Vs)  # (n_pos, n_heads, 1, d_head)

        out = out.squeeze(2).reshape(B, T, self.n_heads * self.d_head)
        return self.o_proj(out)
|
|
|
|
class FFN(nn.Module):
    """Gated feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``."""

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.gate = nn.Linear(width, hidden, bias=False)
        self.up = nn.Linear(width, hidden, bias=False)
        self.down = nn.Linear(hidden, width, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Apply the gated FFN along the last dimension of ``x``."""
        gate_act = F.silu(self.gate(x))
        hidden = gate_act * self.up(x)
        return self.dropout(self.down(hidden))
|
|
|
|
class CompetitiveDockingMemory(nn.Module):
    """Slot memory updated by a gated linear recurrence.

    Every token writes a projected value into K slots; slot k moves toward the
    value by its gate: ``s_k <- (1 - g_k) * s_k + g_k * v``. ``forward`` returns
    the slot state *before* each position together with the gates, so callers
    can use position-specific slots and regularise the routing.
    """

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        self.K = cfg.K
        self.d = cfg.d_model

        self.route = nn.Linear(cfg.d_model, cfg.K, bias=True)     # per-slot routing logits
        self.eta = nn.Linear(cfg.d_model, 1, bias=True)           # global write-intensity logit
        self.write_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.slot_init = nn.Parameter(torch.zeros(cfg.K, cfg.d_model))

        nn.init.zeros_(self.route.bias)
        # sigmoid(-2) is about 0.12, so write intensity starts low.
        nn.init.constant_(self.eta.bias, -2.0)
        nn.init.normal_(self.slot_init, std=0.02)

    def compute_gates(self, h: torch.Tensor):
        """Map h (B, T, d) to gates (B, T, K): softmax routing times sigmoid intensity."""
        routing = F.softmax(self.route(h), dim=-1)
        intensity = torch.sigmoid(self.eta(h))
        return routing * intensity

    @staticmethod
    def _sequential_scan(A: torch.Tensor, B: torch.Tensor,
                         init: torch.Tensor) -> torch.Tensor:
        """Run ``s_t = A_t * s_{t-1} + B_t`` step by step, returning the states
        seen *before* each position.

        Args:
            A: decay factors indexed as ``A[:, t]``, broadcastable against the state.
            B: (batch, T, K, d) write terms; its shape, device and dtype define the output.
            init: state before position 0.

        Returns:
            Tensor shaped like ``B`` with ``out[:, 0] = init`` and
            ``out[:, t] = A[:, t-1] * out[:, t-1] + B[:, t-1]``. The update from
            the final position is deliberately not applied.

        One state per timestep is stored, i.e. batch * T * K * d elements.
        """
        seq_len = B.shape[1]
        history = B.new_empty(B.shape)
        state = init
        history[:, 0] = state
        for pos in range(1, seq_len):
            state = A[:, pos - 1] * state + B[:, pos - 1]
            history[:, pos] = state
        return history

    def forward(self, h: torch.Tensor):
        """Compute causal slot states for a whole sequence.

        Args:
            h: (B, T, d) hidden states.

        Returns:
            slots_all: (B, T, K, d), slot state before each position.
            gates: (B, T, K), write gates per position and slot.
        """
        batch, seq_len, width = h.shape
        gates = self.compute_gates(h)
        values = self.write_proj(h)

        gate_col = gates.unsqueeze(-1)  # (B, T, K, 1)
        full_shape = (batch, seq_len, self.K, width)
        decay = (1.0 - gate_col).expand(*full_shape)
        writes = gate_col * values.unsqueeze(2).expand(*full_shape)
        start = self.slot_init.unsqueeze(0).expand(batch, self.K, width)

        return self._sequential_scan(decay, writes, start), gates

    def step(self, h_t: torch.Tensor, prev_state: torch.Tensor):
        """Apply one token's write, for cached generation.

        Args:
            h_t: (B, d) hidden state of the current token.
            prev_state: (B, K, d) slot state before this token.

        Returns:
            new_state: (B, K, d) slot state after this token's write.
            slots_for_sa: (B, 1, K, d) ``prev_state`` with a length-1 time axis,
                i.e. the slots this token itself should attend to.
            gates_t: (B, K) write gates of this token.
        """
        as_seq = h_t.unsqueeze(1)  # (B, 1, d)
        gates_t = self.compute_gates(as_seq)[:, 0, :]
        value_t = self.write_proj(as_seq)[:, 0, :]
        gate_col = gates_t.unsqueeze(-1)

        new_state = (1.0 - gate_col) * prev_state + gate_col * value_t.unsqueeze(1)
        slots_for_sa = prev_state.unsqueeze(1)
        return new_state, slots_for_sa, gates_t
|
|
|
|
def marginal_entropy_loss(gates: torch.Tensor) -> torch.Tensor:
    """Negative entropy of the time-averaged gate distribution.

    Individual positions may route sharply to a single slot; this term only
    rewards the *average* routing over a sequence for being spread across
    slots. Minimising the returned value maximises that entropy.

    Args:
        gates: (B, T, K) non-negative gate weights (routing weights, optionally
            scaled by write intensity).

    Returns:
        Scalar float32 tensor ``-mean_b H(normalise(mean_t gates[b]))``.
    """
    # Work in float32: the 1e-8 / 1e-12 stabilisers below round to zero in
    # float16, which turns an exactly-zero marginal into 0 * log(0) = NaN.
    gates = gates.float()

    marginal = gates.mean(dim=1)  # (B, K)
    # Gates scaled by a write intensity do not sum to one, so renormalise.
    marginal = marginal / (marginal.sum(dim=-1, keepdim=True) + 1e-8)
    log_marginal = torch.log(marginal + 1e-12)
    entropy = -(marginal * log_marginal).sum(dim=-1)  # (B,)
    return -entropy.mean()
|
|
|
|
class CDMBlockV2(nn.Module):
    """Transformer block with causal memory slots and two attention paths.

    Per call:
      1. The CDM produces the slot state seen before each position.
      2. Causal self-attention runs over the sequence tokens.
      3. Slot cross-attention lets each position read its own slots.
      4. Both attention outputs are summed into one residual update.
      5. A gated FFN adds the second residual update.
    """

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        self.cdm = CompetitiveDockingMemory(cfg)
        self.self_attn = CausalSelfAttention(cfg)
        self.slot_xattn = SlotCrossAttention(cfg)
        self.ffn = FFN(cfg)
        # Each consumer of the residual stream gets its own pre-norm.
        self.norm_sa = nn.RMSNorm(cfg.d_model)
        self.norm_sx = nn.RMSNorm(cfg.d_model)
        self.norm_cdm = nn.RMSNorm(cfg.d_model)
        self.norm_ff = nn.RMSNorm(cfg.d_model)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, return_slots: bool = False):
        """Process a full sequence.

        Args:
            x: (B, T, d) residual stream.
            return_slots: also return the per-position slot states.

        Returns:
            ``(x_out, gates)``, or ``(x_out, gates, slots_all)`` when
            ``return_slots`` is true. ``gates`` is (B, T, K) and ``slots_all``
            is (B, T, K, d).
        """
        slots_all, gates = self.cdm(self.norm_cdm(x))

        token_path = self.self_attn(self.norm_sa(x))
        slot_path = self.slot_xattn(self.norm_sx(x), slots_all)
        x = x + self.dropout(token_path + slot_path)

        x = x + self.ffn(self.norm_ff(x))
        return (x, gates, slots_all) if return_slots else (x, gates)

    def forward_step(self, x_t: torch.Tensor, slot_state: torch.Tensor,
                     past_kv, position: int):
        """Process one token using the slot and key/value caches.

        Args:
            x_t: (B, 1, d) residual stream for the new token.
            slot_state: (B, K, d) slot state before this token.
            past_kv: ``(K_cache, V_cache)`` from earlier tokens, or ``None``.
            position: absolute index of the new token.

        Returns:
            ``(x_out, new_slot_state, new_kv, gates)`` with shapes
            (B, 1, d), (B, K, d), a ``(K, V)`` tuple, and (B, K).
        """
        token_hidden = x_t[:, 0, :]
        new_slot_state, slots_before, gates_t = self.cdm.step(
            self.norm_cdm(token_hidden), slot_state
        )

        token_path, new_kv = self.self_attn.forward_cached(
            self.norm_sa(x_t), past_kv, position
        )
        # The token reads the slots as they were *before* its own write.
        slot_path = self.slot_xattn(self.norm_sx(x_t), slots_before)

        x_t = x_t + token_path + slot_path
        x_t = x_t + self.ffn(self.norm_ff(x_t))
        return x_t, new_slot_state, new_kv, gates_t
|
|
|
|
class CDMLanguageModelV2(nn.Module):
    """Decoder-only language model built from CDMBlockV2 blocks.

    Token embedding and output head share one weight matrix.
    """

    def __init__(self, cfg: "CDMConfigV2"):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([CDMBlockV2(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.embed.weight  # weight tying
        self._init_weights()

    def _init_weights(self):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero generic biases.

        Biases of Linear layers that live inside a block's CDM are left alone:
        the CDM constructor sets them on purpose (the write-intensity bias
        starts at -2.0) and zeroing them here would silently undo that.
        """
        cdm_linears = {
            id(sub)
            for block in self.blocks
            for sub in block.cdm.modules()
            if isinstance(sub, nn.Linear)
        }
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None and id(m) not in cdm_linears:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    @staticmethod
    def _sample_next(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
        """Pick the next token from last-position logits.

        Args:
            logits: (B, vocab) unnormalised scores.
            temperature: divisor applied to the logits; exactly 0 selects the
                argmax instead of sampling (dividing by zero would give NaNs).
            top_k: keep only the k highest logits when > 0; 0 disables the filter.

        Returns:
            (B, 1) tensor of token ids.
        """
        if temperature == 0:
            return logits.argmax(dim=-1, keepdim=True)
        scaled = logits / temperature
        if top_k > 0:
            kept = torch.topk(scaled, min(top_k, scaled.shape[-1])).values
            cutoff = kept[:, [-1]]  # smallest kept logit per row
            scaled = scaled.masked_fill(scaled < cutoff, float('-inf'))
        return torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)

    def forward(self, idx: torch.Tensor):
        """Run the model on token ids.

        Args:
            idx: (B, T) token ids.

        Returns:
            ``(logits, aux_loss)``. ``logits`` is (B, T, vocab). ``aux_loss`` is
            the entropy regulariser summed over layers and scaled by
            ``cfg.entropy_reg``; it is 0 outside training mode or when the
            weight is not positive. Add it to the cross-entropy loss.
        """
        x = self.embed(idx)
        aux_loss = torch.tensor(0.0, device=idx.device)
        use_reg = self.training and self.cfg.entropy_reg > 0

        for block in self.blocks:
            x, gates = block(x)
            if use_reg:
                aux_loss = aux_loss + self.cfg.entropy_reg * marginal_entropy_loss(gates)

        x = self.norm(x)
        return self.head(x), aux_loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new: int, temperature: float = 1.0,
                 top_k: int = 50) -> torch.Tensor:
        """Sample ``max_new`` tokens, re-running the full model for every token.

        The context is cropped to the last ``cfg.max_len`` tokens before each
        forward pass. Returns ``idx`` with the new tokens appended.
        """
        self.eval()
        for _ in range(max_new):
            idx_cond = idx[:, -self.cfg.max_len:]
            logits, _ = self(idx_cond)
            next_tok = self._sample_next(logits[:, -1, :], temperature, top_k)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx

    @torch.no_grad()
    def generate_with_slots(self, idx: torch.Tensor, max_new: int, tokenizer,
                            temperature: float = 1.0, top_k: int = 50):
        """Generate text and record routing gates alongside every new token.

        Args:
            idx: (B, T) prompt ids; only batch element 0 is recorded and decoded.
            max_new: number of tokens to generate.
            tokenizer: object with a ``decode(ids, ...)`` method.
            temperature, top_k: sampling controls, as in ``generate``.

        Returns:
            ``(generated_text, snapshots)``. Each snapshot is
            ``(token_str, all_layer_gates, winner_slot)``: ``all_layer_gates``
            holds one list of K gate weights per layer and ``winner_slot`` is
            the 0-indexed argmax of the last layer's gates.

        NOTE(review): the gates are read at the last *context* position, i.e.
        they describe the routing of the token that precedes the sampled one.
        Confirm that this pairing is what the visualisation expects.
        """
        self.eval()
        snapshots = []

        for _ in range(max_new):
            idx_cond = idx[:, -self.cfg.max_len:]
            x = self.embed(idx_cond)
            all_layer_gates = []
            for block in self.blocks:
                x, gates = block(x)
                all_layer_gates.append(gates[0, -1, :].tolist())

            # Only the last position is sampled from; the norm is per-position.
            logits_last = self.head(self.norm(x[:, -1:, :]))[:, 0, :]
            next_tok = self._sample_next(logits_last, temperature, top_k)

            tok_str = tokenizer.decode([next_tok[0, 0].item()]).strip()
            last_gates = all_layer_gates[-1]
            winner = max(range(len(last_gates)), key=last_gates.__getitem__)
            snapshots.append((tok_str, all_layer_gates, winner))

            idx = torch.cat([idx, next_tok], dim=1)

        generated_text = tokenizer.decode(idx[0].tolist(), skip_special_tokens=True)
        return generated_text, snapshots

    @torch.no_grad()
    def generate_fast(self, idx: torch.Tensor, max_new: int, temperature: float = 1.0,
                      top_k: int = 50) -> torch.Tensor:
        """Cache-aware generation: one prefix pass, then one-token steps.

        ``generate`` re-runs the whole context (including the sequential slot
        scan) for every new token. Here the prompt is processed once to build
        per-layer key/value caches and slot states; each later token needs a
        single slot update plus attention over the cache.

        Behaviour at the boundaries mirrors ``generate``:
          * ``max_new <= 0`` returns ``idx`` unchanged;
          * a prompt longer than ``cfg.max_len`` is cropped to its last
            ``cfg.max_len`` tokens for conditioning;
          * once the next position would fall outside the RoPE table, the
            remaining tokens are produced by ``generate`` (sliding window).

        Returns ``idx`` with the new tokens appended.
        """
        self.eval()
        if max_new <= 0:
            return idx

        max_len = self.cfg.max_len
        context = idx[:, -max_len:]
        B, T_prefix = context.shape

        # ---- prefix pass: build the caches --------------------------------
        x = self.embed(context)
        kv_caches = []
        slot_states = []

        for block in self.blocks:
            attn = block.self_attn

            h_cdm = block.norm_cdm(x)
            slots_all, _ = block.cdm(h_cdm)

            # Self-attention is unrolled here because the un-expanded K/V
            # tensors are needed for the cache.
            h_sa = block.norm_sa(x)
            Q = attn.q_proj(h_sa).view(B, T_prefix, attn.n_heads, attn.d_head).transpose(1, 2)
            K_ = attn.k_proj(h_sa).view(B, T_prefix, attn.n_kv_heads, attn.d_head).transpose(1, 2)
            V_ = attn.v_proj(h_sa).view(B, T_prefix, attn.n_kv_heads, attn.d_head).transpose(1, 2)
            Q = attn.rope(Q)
            K_ = attn.rope(K_)
            sa_out = F.scaled_dot_product_attention(
                Q,
                K_.repeat_interleave(attn.n_rep, dim=1),
                V_.repeat_interleave(attn.n_rep, dim=1),
                is_causal=True,
            )
            sa_out = attn.o_proj(sa_out.transpose(1, 2).contiguous().view(B, T_prefix, -1))
            kv_caches.append((K_, V_))

            sx_out = block.slot_xattn(block.norm_sx(x), slots_all)

            # slots_all[:, -1] is the state *before* the last prompt token.
            # Apply that token's write so the first decoded token sees every
            # prompt token, exactly as a full forward pass would.
            state_after_prefix, _, _ = block.cdm.step(h_cdm[:, -1, :], slots_all[:, -1])
            slot_states.append(state_after_prefix)

            x = x + sa_out + sx_out
            x = x + block.ffn(block.norm_ff(x))

        logits_last = self.head(self.norm(x[:, -1:, :]))[:, 0, :]
        next_tok = self._sample_next(logits_last, temperature, top_k)
        idx = torch.cat([idx, next_tok], dim=1)

        # ---- cached decode: one token per iteration -----------------------
        for step_i in range(max_new - 1):
            position = T_prefix + step_i  # absolute index of the token fed in
            if position >= max_len:
                # Out of RoPE positions: finish with the sliding-window path.
                remaining = max_new - 1 - step_i
                return self.generate(idx, max_new=remaining,
                                     temperature=temperature, top_k=top_k)

            x_t = self.embed(next_tok)
            for li, block in enumerate(self.blocks):
                x_t, slot_states[li], kv_caches[li], _ = block.forward_step(
                    x_t, slot_states[li], kv_caches[li], position
                )

            logits_last = self.head(self.norm(x_t))[:, 0, :]
            next_tok = self._sample_next(logits_last, temperature, top_k)
            idx = torch.cat([idx, next_tok], dim=1)

        return idx

    @torch.no_grad()
    def benchmark_throughput(self, prompt: str, tokenizer, max_new: int = 128,
                             device: str = 'cuda', n_runs: int = 3):
        """Time ``generate`` against ``generate_fast`` on one prompt.

        Args:
            prompt: text to encode as the starting context.
            tokenizer: object whose ``encode(prompt, return_tensors='pt')``
                returns a tensor of ids.
            max_new: tokens generated per run.
            device: device the ids are moved to; the model is expected to be
                there already.
            n_runs: timed repetitions per method.

        Returns:
            Dict with mean tokens/second under ``'generate_slow'`` and
            ``'generate_fast'`` plus their ratio under ``'speedup_x'``.
        """
        import time

        self.eval()
        ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Compare the device *type* so 'cuda:0' and friends are synchronised too.
        on_cuda = torch.device(device).type == 'cuda'
        results = {}

        for method_name, method in [('generate_slow', self.generate),
                                    ('generate_fast', self.generate_fast)]:
            timings = []
            for _ in range(n_runs):
                if on_cuda:
                    torch.cuda.synchronize()
                t0 = time.perf_counter()
                _ = method(ids.clone(), max_new=max_new, temperature=0.8, top_k=40)
                if on_cuda:
                    torch.cuda.synchronize()
                t1 = time.perf_counter()
                timings.append(max_new / (t1 - t0))
            results[method_name] = round(sum(timings) / n_runs, 1)
            print(f" {method_name}: {results[method_name]:.1f} tok/s")

        speedup = results['generate_fast'] / results['generate_slow']
        results['speedup_x'] = round(speedup, 2)
        print(f" Speedup: {speedup:.1f}×")
        return results

    def param_count(self) -> int:
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, then run one forward/backward pass.
    cfg = CDMConfigV2()
    model = CDMLanguageModelV2(cfg)
    total_params = model.param_count()
    print(f"CDM V2: {total_params:,} params ({total_params/1e6:.1f}M)")
    print(f" K={cfg.K}, d={cfg.d_model}, L={cfg.n_layers}, entropy_reg={cfg.entropy_reg}")

    tokens = torch.randint(0, cfg.vocab_size, (2, 64))
    model.train()  # training mode so the entropy regulariser is active
    logits, aux = model(tokens)

    # Next-token objective: position t predicts token t + 1.
    predictions = logits[:, :-1].reshape(-1, cfg.vocab_size)
    targets = tokens[:, 1:].reshape(-1)
    loss = F.cross_entropy(predictions, targets)
    (loss + aux).backward()

    grads_present = all(
        p.grad is not None for p in model.parameters() if p.requires_grad
    )
    print(f" Forward: {tokens.shape} → {logits.shape}")
    print(f" CE loss={loss.item():.4f} entropy_reg={aux.item():.4f}")
    print(f" Gradients OK: {grads_present}")
    print("OK")
|
|