CLIWorks commited on
Commit
e86c5bf
·
verified ·
1 Parent(s): bf84936

Upload mythos-fineweb-dense.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mythos-fineweb-dense.py +1237 -0
mythos-fineweb-dense.py ADDED
@@ -0,0 +1,1237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
+
5
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
6
+ - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
7
+ - Engram conditional memory at recurrent layers 1 and 4
8
+ - Dense FFN (all params active, MoE conversion in Phase 2)
9
+ - LTI Injection + ACT Halting + LoRA Adapter
10
+ - 32k context (extendable to 256k at inference via YaRN)
11
+
12
+ Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
13
+
14
+ Single GPU:
15
+ python mythos-fineweb-dense.py
16
+
17
+ Multi-GPU:
18
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
19
+ """
20
+ import os
21
+ import math
22
+ import time
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import torch.distributed as dist
27
+ from loguru import logger
28
+ import sys
29
+
30
+ # Configure loguru to file + stderr
31
+ LOG_FILE = "train_spiderportal.log"
32
+ logger.remove()
33
+ logger.add(sys.stderr, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
34
+ logger.add(LOG_FILE, rotation="100 MB", retention="10 days", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
35
+
36
+ # Speed up CUDA memory allocation
37
+ import torch
38
+ torch.cuda.empty_cache()
39
+
40
+ # Numba CPU fallback
41
+ from triton_kernels import (
42
+ numba_dispatch,
43
+ NUMBA_AVAILABLE as _NUMBA_OK,
44
+ )
45
+ from numba_cuda_kernels import (
46
+ cuda_engram_hash,
47
+ cuda_engram_gate,
48
+ cuda_act_halting,
49
+ cuda_engram_conv1d,
50
+ cuda_available as _CUDA_OK,
51
+ )
52
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
53
+ from torch.distributed.fsdp import (
54
+ FullyShardedDataParallel as FSDP,
55
+ ShardingStrategy,
56
+ MixedPrecision,
57
+ FullStateDictConfig,
58
+ StateDictType,
59
+ )
60
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
61
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
62
+ from contextlib import nullcontext
63
+ from dataclasses import dataclass, field
64
+ from typing import Optional, Tuple, Dict, List
65
+ from torch.nn import CrossEntropyLoss
66
+ from datasets import load_dataset
67
+ from transformers import AutoTokenizer
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # SpiderPortal Model Architecture (Dense + MLA + Engram)
72
+ # ---------------------------------------------------------------------------
73
+
74
+ @dataclass
75
+ class SpiderPortalConfig:
76
+ vocab_size: int = 50257
77
+ hidden_size: int = 2048
78
+ num_hidden_layers: int = 6
79
+ num_attention_heads: int = 16
80
+ num_key_value_heads: int = 4
81
+ intermediate_size: int = 8192
82
+ hidden_act: str = "silu"
83
+ num_experts: int = 32
84
+ num_experts_per_tok: int = 2
85
+ num_shared_experts: int = 1
86
+ router_aux_loss_coef: float = 0.05
87
+ max_loop_iters: int = 2
88
+ act_threshold: float = 0.5
89
+ max_position_embeddings: int = 32768
90
+ rope_theta: float = 10000000.0
91
+ rope_scaling: dict = None
92
+ sliding_window: int = 4096
93
+ attention_dropout: float = 0.0
94
+ rms_norm_eps: float = 1e-6
95
+ initializer_range: float = 0.02
96
+ use_cache: bool = True
97
+ tie_word_embeddings: bool = True
98
+ prelude_layers: int = 2
99
+ coda_layers: int = 2
100
+ lora_rank: int = 128
101
+ loop_embed_dim: int = 128
102
+ vision_hidden_size: int = 2048
103
+ audio_hidden_size: int = 512
104
+ vision_num_frames: int = 60
105
+ vision_tokens_per_frame: int = 256
106
+ vision_temporal_tokens: int = 64
107
+ vision_temporal_layers: int = 2
108
+ model_type: str = "spiderportal"
109
+ torch_dtype: str = "bfloat16"
110
+
111
+ # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
112
+ kv_lora_rank: int = 128
113
+ q_lora_rank: int = 256
114
+ qk_rope_head_dim: int = 64
115
+ qk_nope_head_dim: int = 64
116
+ v_head_dim: int = 64
117
+
118
+ # Engram parameters (DeepSeek conditional memory)
119
+ engram_layers: List[int] = field(default_factory=lambda: [1, 4])
120
+ engram_ngram_orders: Tuple[int, ...] = (2, 3)
121
+ engram_hash_heads: int = 4
122
+ engram_table_size: int = 65537 # prime number for hash table
123
+ engram_conv_kernel: int = 4
124
+ engram_conv_dilation: int = 3
125
+ engram_dim: int = 128 # per-head embedding dimension
126
+
127
+
128
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
129
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
130
+ angles = loop_t * freqs
131
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
132
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
133
+ emb_full[:loop_dim] = emb
134
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
135
+
136
+
137
+ class SpiderPortalRMSNorm(nn.Module):
138
+ def __init__(self, hidden_size, eps=1e-6):
139
+ super().__init__()
140
+ self.weight = nn.Parameter(torch.ones(hidden_size))
141
+ self.variance_epsilon = eps
142
+ def forward(self, hidden_states):
143
+ # bf16-only RMSNorm: no dtype conversions inside forward.
144
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
145
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
146
+ return self.weight * hidden_states
147
+
148
+
149
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
150
+ dim = head_dim
151
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
152
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
153
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
154
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
155
+ return orig_inv_freq * scale
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
160
+ # ---------------------------------------------------------------------------
161
+
162
+ class SpiderPortalMLA(nn.Module):
163
+ """Multi-Latent Attention with compressed KV cache and sliding window.
164
+
165
+ For hidden_size=2048, num_heads=16:
166
+ - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
167
+ - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
168
+ - v_head_dim=64 → value projection
169
+ - sliding_window=4096 → local attention range
170
+ """
171
+ def __init__(self, config):
172
+ super().__init__()
173
+ self.config = config
174
+ self.hidden_size = config.hidden_size
175
+ self.num_heads = config.num_attention_heads
176
+ self.kv_lora_rank = config.kv_lora_rank
177
+ self.q_lora_rank = config.q_lora_rank
178
+ self.qk_rope_head_dim = config.qk_rope_head_dim
179
+ self.qk_nope_head_dim = config.qk_nope_head_dim
180
+ self.v_head_dim = config.v_head_dim
181
+ self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
182
+ self.sliding_window = getattr(config, 'sliding_window', None)
183
+
184
+ # Q projection: optional low-rank → full Q
185
+ if self.q_lora_rank > 0:
186
+ self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
187
+ self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
188
+ self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
189
+ else:
190
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
191
+
192
+ # KV compression: hidden → kv_lora_rank (shared latent)
193
+ self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
194
+ self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
195
+ # Decompress: kv_lora_rank → nope heads + v heads
196
+ self.kv_b_proj = nn.Linear(
197
+ self.kv_lora_rank,
198
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
199
+ bias=False,
200
+ )
201
+ # Output projection
202
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
203
+
204
+ # RoPE frequencies
205
+ rope_scaling = getattr(config, 'rope_scaling', None)
206
+ if rope_scaling and rope_scaling.get("type") == "yarn":
207
+ factor = rope_scaling.get("factor", 1.0)
208
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
209
+ inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
210
+ else:
211
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
212
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
213
+
214
+ def _rotate_half(self, x):
215
+ x1 = x[..., :x.shape[-1] // 2]
216
+ x2 = x[..., x.shape[-1] // 2:]
217
+ return torch.cat((-x2, x1), dim=-1)
218
+
219
+ def _apply_rotary(self, x, cos, sin):
220
+ return (x * cos) + (self._rotate_half(x) * sin)
221
+
222
+ def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
223
+ """Unused: sliding_window disabled."""
224
+ return None
225
+
226
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
227
+ bsz, q_len, _ = hidden_states.size()
228
+ # Q projection
229
+ if self.q_lora_rank > 0:
230
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
231
+ else:
232
+ q = self.q_proj(hidden_states)
233
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
234
+ q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
235
+
236
+ # KV: compress to latent, then decompress
237
+ kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
238
+ kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
239
+ kv_latent_norm = self.kv_a_layernorm(kv_latent)
240
+ kv_b_out = self.kv_b_proj(kv_latent_norm)
241
+ k_nope, v = torch.split(kv_b_out, [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim], dim=-1)
242
+
243
+ k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
244
+ v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
245
+ k_rope = k_rope.unsqueeze(1)
246
+
247
+ # RoPE on Q and K rope parts
248
+ if position_ids is None:
249
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
250
+ max_pos = position_ids.max().item() + 1
251
+ seq_len = max(max_pos, q_len)
252
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
253
+ freqs = torch.outer(t, self.inv_freq)
254
+ emb = torch.cat((freqs, freqs), dim=-1)
255
+ cos, sin = emb.cos(), emb.sin()
256
+ cos_full = cos[position_ids].unsqueeze(1)
257
+ sin_full = sin[position_ids].unsqueeze(1)
258
+
259
+ q_rope = self._apply_rotary(q_rope, cos_full, sin_full)
260
+ k_rope = self._apply_rotary(k_rope, cos_full, sin_full)
261
+
262
+ # Assemble full K
263
+ k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
264
+ k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
265
+ q_full = torch.cat([q_nope, q_rope], dim=-1)
266
+
267
+ # KV cache
268
+ if past_key_value is not None:
269
+ k_full = torch.cat([past_key_value[0], k_full], dim=2)
270
+ v = torch.cat([past_key_value[1], v], dim=2)
271
+ past_kv = (k_full, v) if use_cache else None
272
+
273
+ # Attention with SDPA — is_causal=True for flash-attention fast path
274
+ # No 4D causal mask needed; sliding window disabled, so pure causal.
275
+ attn_output = F.scaled_dot_product_attention(
276
+ q_full, k_full, v,
277
+ attn_mask=None,
278
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
279
+ is_causal=True,
280
+ )
281
+ attn_output = attn_output.transpose(1, 2).contiguous()
282
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
283
+ return self.o_proj(attn_output), past_kv
284
+
285
+
286
+ # ---------------------------------------------------------------------------
287
+ # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
288
+ # ---------------------------------------------------------------------------
289
+
290
+ def _tokenizer_compress(token_ids, vocab_size=50257):
291
+ """Simulate NFKC + lowercase canonical ID projection."""
292
+ return token_ids % (vocab_size * 77 // 100)
293
+
294
+
295
+ class SpiderPortalEngram(nn.Module):
296
+ """Conditional memory module via NN-gram lookup.
297
+
298
+ Applied only at specific recurrent layers (config.engram_layers).
299
+ """
300
+ def __init__(self, config):
301
+ super().__init__()
302
+ self.config = config
303
+ self.ngram_orders = list(config.engram_ngram_orders)
304
+ self.num_heads_per_order = config.engram_hash_heads
305
+ self.table_size = config.engram_table_size
306
+ self.d_mem = config.engram_dim
307
+
308
+ self.total_mem_dim = len(self.ngram_orders) * self.num_heads_per_order * self.d_mem
309
+
310
+ # Stacked embedding table with offsets: [orders, heads, table_size, d_mem]
311
+ # This matches the deepseek MultiHeadEmbedding principle.
312
+ self.embed = nn.Parameter(
313
+ torch.randn(len(self.ngram_orders), self.num_heads_per_order, self.table_size, self.d_mem) * 0.02
314
+ )
315
+
316
+ # Seeds per (order, head) in a stable head_counter ordering.
317
+ seeds = []
318
+ for _order in self.ngram_orders:
319
+ for h in range(self.num_heads_per_order):
320
+ seeds.append((h + 1) * 2654435761)
321
+ self.register_buffer("hash_seeds", torch.tensor(seeds, dtype=torch.int64), persistent=False)
322
+
323
+ self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
324
+ self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
325
+
326
+ self.conv = nn.Conv1d(
327
+ config.hidden_size, config.hidden_size,
328
+ kernel_size=config.engram_conv_kernel,
329
+ padding=config.engram_conv_kernel - 1,
330
+ groups=config.hidden_size,
331
+ )
332
+ self.conv_dilation = config.engram_conv_dilation
333
+
334
+ with torch.no_grad():
335
+ self.conv.weight.zero_()
336
+ if self.conv.bias is not None:
337
+ self.conv.bias.zero_()
338
+
339
+ self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
340
+ self.k_norm = SpiderPortalRMSNorm(config.hidden_size)
341
+
342
+ # No caching: required for gradient checkpoint recomputation stability.
343
+ self._fwd_cache = None
344
+
345
+ def _compute_indices(self, compressed_ids, n, head_idx):
346
+ """Vectorized NN-gram hash indices for a single (order, head)."""
347
+ # Kept for backward compatibility; not used in the stacked embedding path.
348
+ bsz, seq_len = compressed_ids.shape
349
+ pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
350
+ padded = torch.cat([pad, compressed_ids], dim=1)
351
+
352
+ indices_list = []
353
+ for i in range(n):
354
+ indices_list.append(padded[:, i:i + seq_len])
355
+ ngrams = torch.stack(indices_list, dim=-1)
356
+
357
+ seed = int(self.hash_seeds[head_idx].item())
358
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
359
+ for i in range(n):
360
+ h_val = h_val * 31 + ngrams[:, :, i]
361
+ h_val = h_val % self.table_size
362
+ h_val = (h_val * seed) % self.table_size
363
+ return h_val
364
+
365
+ def _compute_hash(self, compressed, n, head_counter, bsz, seq_len):
366
+ """Compute n-gram hash indices, with Numba CPU fallback."""
367
+ if not compressed.is_cuda and NUMBA_AVAILABLE:
368
+ import numpy as np
369
+ h_val_np = numba_dispatch(
370
+ "hash_indices",
371
+ compressed.cpu().numpy().astype(np.int64),
372
+ n, self.table_size,
373
+ int(self.hash_seeds[head_counter].item()),
374
+ )
375
+ if h_val_np is not None:
376
+ return torch.from_numpy(h_val_np).to(compressed.device)
377
+
378
+ pad = torch.zeros(bsz, n - 1, dtype=compressed.dtype, device=compressed.device)
379
+ padded = torch.cat([pad, compressed], dim=1)
380
+ ngrams = torch.stack([padded[:, i : i + seq_len] for i in range(n)], dim=-1)
381
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed.device)
382
+ for i in range(n):
383
+ h_val = h_val * 31 + ngrams[:, :, i].to(torch.int64)
384
+ h_val = h_val % self.table_size
385
+ return h_val
386
+
387
+ def _retrieve(self, token_ids):
388
+ """Retrieve memory vectors for a batch of token sequences."""
389
+ bsz, seq_len = token_ids.shape
390
+ compressed = _tokenizer_compress(token_ids)
391
+
392
+ # Use Numba CUDA hash if faster (PyTorch path is default, ~0.2ms per call)
393
+ indices = cuda_engram_hash(
394
+ compressed, self.hash_seeds,
395
+ self.ngram_orders, self.num_heads_per_order, self.table_size,
396
+ )
397
+ if indices is not None:
398
+ all_parts = []
399
+ head_counter = 0
400
+ for order_idx, n in enumerate(self.ngram_orders):
401
+ head_indices = indices[:, :, head_counter:head_counter + self.num_heads_per_order]
402
+ emb_table = self.embed[order_idx]
403
+ idx = head_indices.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.d_mem)
404
+ mem = torch.gather(emb_table.unsqueeze(0).expand(bsz, -1, -1, -1), dim=2, index=idx)
405
+ mem = mem.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads_per_order * self.d_mem)
406
+ all_parts.append(mem)
407
+ head_counter += self.num_heads_per_order
408
+ return torch.cat(all_parts, dim=-1)
409
+
410
+ # PyTorch fallback (CPU or if CUDA kernel unavailable)
411
+ all_parts = []
412
+ head_counter = 0
413
+ for order_idx, n in enumerate(self.ngram_orders):
414
+ h_val = self._compute_hash(compressed, n, head_counter, bsz, seq_len)
415
+ seeds_slice = self.hash_seeds[head_counter : head_counter + self.num_heads_per_order]
416
+ indices_pt = (h_val.unsqueeze(-1) * seeds_slice.view(1, 1, -1)) % self.table_size
417
+ emb_table = self.embed[order_idx]
418
+ idx = indices_pt.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.d_mem)
419
+ mem = torch.gather(emb_table.unsqueeze(0).expand(bsz, -1, -1, -1), dim=2, index=idx)
420
+ mem = mem.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads_per_order * self.d_mem)
421
+ all_parts.append(mem)
422
+ head_counter += self.num_heads_per_order
423
+ return torch.cat(all_parts, dim=-1)
424
+
425
+ def forward(self, hidden_states, token_ids, layer_id: int):
426
+ mem = self._retrieve(token_ids)
427
+
428
+ q = hidden_states
429
+ k = self.W_k(mem)
430
+ v = self.W_v(mem)
431
+ q_norm = self.q_norm(q)
432
+ k_norm = self.k_norm(k)
433
+ alpha = torch.sigmoid(
434
+ (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
435
+ )
436
+ v_gated = alpha * v
437
+ v_gated_t = v_gated.transpose(1, 2)
438
+ conv_out = self.conv(v_gated_t)
439
+ conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
440
+ conv_out = conv_out.transpose(1, 2)
441
+
442
+ y = F.silu(conv_out) + v_gated
443
+ return y
444
+
445
+
446
+ # ---------------------------------------------------------------------------
447
+ # FFN Expert (dense)
448
+ # ---------------------------------------------------------------------------
449
+
450
+ class SpiderPortalExpert(nn.Module):
451
+ def __init__(self, config, intermediate_size=None):
452
+ super().__init__()
453
+ inter_size = intermediate_size or config.intermediate_size
454
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
455
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
456
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
457
+ self.act_fn = nn.SiLU()
458
+ def forward(self, hidden_states):
459
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
460
+
461
+
462
+ # ---------------------------------------------------------------------------
463
+ # Prelude/Coda Dense Layer (uses MLA)
464
+ # ---------------------------------------------------------------------------
465
+
466
+ class SpiderPortalDenseLayer(nn.Module):
467
+ """Prelude/coda dense layer with MLA attention."""
468
+ def __init__(self, config):
469
+ super().__init__()
470
+ self.self_attn = SpiderPortalMLA(config)
471
+ dense_intermediate = config.hidden_size * 4 // 3
472
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
473
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
474
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
475
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
476
+ attn_input = self.input_layernorm(hidden_states)
477
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
478
+ hidden_states = hidden_states + attn_output
479
+ ffn_input = self.post_attention_layernorm(hidden_states)
480
+ ffn_output = self.ffn(ffn_input)
481
+ hidden_states = hidden_states + ffn_output
482
+ return hidden_states, past_kv
483
+
484
+
485
+ # ---------------------------------------------------------------------------
486
+ # Recurrent Dense Layer (uses MLA + optional Engram)
487
+ # ---------------------------------------------------------------------------
488
+
489
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
490
+ """Recurrent layer with MLA attention and optional Engram memory."""
491
+ def __init__(self, config, layer_idx, has_engram=False):
492
+ super().__init__()
493
+ self.layer_idx = layer_idx
494
+ self.has_engram = has_engram
495
+ self.self_attn = SpiderPortalMLA(config)
496
+ if has_engram:
497
+ self.engram = SpiderPortalEngram(config)
498
+ self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
499
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
500
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
501
+ self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None
502
+ def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
503
+ attn_input = self.input_layernorm(hidden_states)
504
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
505
+ hidden_states = hidden_states + attn_output
506
+
507
+ if self.has_engram and token_ids is not None:
508
+ engram_out = self.engram(hidden_states, token_ids, layer_id=self.layer_idx)
509
+ hidden_states = hidden_states + engram_out
510
+ if self.post_engram_layernorm is not None:
511
+ hidden_states = self.post_engram_layernorm(hidden_states)
512
+
513
+ ffn_input = self.post_attention_layernorm(hidden_states)
514
+ ffn_output = self.ffn(ffn_input)
515
+ hidden_states = hidden_states + ffn_output
516
+ return hidden_states, 0.0, past_kv
517
+
518
+
519
+ # ---------------------------------------------------------------------------
520
+ # LTI Injection, ACT Halting, LoRA Adapter
521
+ # ---------------------------------------------------------------------------
522
+
523
+ class LTIInjection(nn.Module):
524
+ def __init__(self, config):
525
+ super().__init__()
526
+ self.hidden_size = config.hidden_size
527
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
528
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
529
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
530
+ with torch.no_grad():
531
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
532
+ def get_A(self):
533
+ return -torch.exp(self.log_A)
534
+ def forward(self, h_t, e):
535
+ A = self.get_A()
536
+ return A * h_t + self.B(e)
537
+
538
+
539
+ class ACTHalting(nn.Module):
540
+ def __init__(self, config):
541
+ super().__init__()
542
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
543
+ self.threshold = config.act_threshold
544
+ def forward(self, hidden_states):
545
+ return torch.sigmoid(self.halt_predictor(hidden_states))
546
+
547
+
548
+ class LoRAAdapter(nn.Module):
549
+ def __init__(self, config):
550
+ super().__init__()
551
+ rank = config.lora_rank
552
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
553
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
554
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
555
+ with torch.no_grad():
556
+ self.scale.weight.data.zero_()
557
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
558
+ def forward(self, x, loop_t):
559
+ max_t = self.scale.num_embeddings - 1
560
+ t_idx = min(loop_t, max_t)
561
+ s = self.scale(torch.tensor(t_idx, device=x.device))
562
+ down = self.down(x) * s
563
+ return down @ self.B
564
+
565
+
566
+ def checkpoint(func, *args, **kwargs):
567
+ """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
568
+ if torch.is_grad_enabled():
569
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
570
+ return func(*args, **kwargs)
571
+
572
+
573
+ # ---------------------------------------------------------------------------
574
+ # Full Model
575
+ # ---------------------------------------------------------------------------
576
+
577
+ class SpiderPortalDenseModel(nn.Module):
578
+ """Full RDT model with MLA attention + Engram memory at layers 1,4.
579
+
580
+ Architecture:
581
+ 2x Prelude (MLA + dense FFN)
582
+ 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing
583
+ 2x Coda (MLA + dense FFN)
584
+ LTI Injection + ACT Halting + LoRA Adapter
585
+ """
586
+ def __init__(self, config):
587
+ super().__init__()
588
+ self.config = config
589
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
590
+ self.recurrent_layers = nn.ModuleList([
591
+ SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
592
+ for i in range(config.num_hidden_layers)
593
+ ])
594
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
595
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
596
+ self.injection = LTIInjection(config)
597
+ self.act_halting = ACTHalting(config)
598
+ self.lora_adapter = LoRAAdapter(config)
599
+ self.loop_embed_dim = config.loop_embed_dim
600
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
601
+ n_loops = n_loops or 1
602
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
603
+ for layer in self.prelude_layers:
604
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
605
+ e = hidden_states.clone()
606
+ B, T_seq, D = hidden_states.shape
607
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
608
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
609
+ h_out = torch.zeros_like(hidden_states)
610
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
611
+ for t in range(n_loops):
612
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
613
+ if t > 0:
614
+ injection = self.injection(hidden_states, input_embedding)
615
+ hidden_states = hidden_states + injection
616
+ new_past_key_values = []
617
+ for i, layer in enumerate(self.recurrent_layers):
618
+ hidden_states, aux_loss, past_kv = checkpoint(
619
+ layer, hidden_states,
620
+ token_ids=token_ids,
621
+ attention_mask=attention_mask,
622
+ position_ids=position_ids,
623
+ past_key_value=past_key_values[i] if t == 0 else None,
624
+ use_cache=use_cache
625
+ )
626
+ new_past_key_values.append(past_kv)
627
+ lora_delta = self.lora_adapter(hidden_states, t)
628
+ hidden_states = hidden_states + lora_delta
629
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
630
+ still_running = ~halted
631
+ remainder = (1.0 - cumulative_p).clamp(min=0)
632
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
633
+ weight = weight * still_running.to(hidden_states.dtype)
634
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
635
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
636
+ halted = halted | (cumulative_p >= self.config.act_threshold)
637
+ if halted.all() and not self.training:
638
+ break
639
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
640
+ hidden_states = h_out + never_halted * hidden_states
641
+ for layer in self.coda_layers:
642
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
643
+ hidden_states = self.norm(hidden_states)
644
+ return hidden_states, 0.0, new_past_key_values
645
+
646
+
647
+ class SpiderPortalForConditionalGeneration(nn.Module):
648
+ def __init__(self, config):
649
+ super().__init__()
650
+ self.config = config
651
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
652
+ self.model = SpiderPortalDenseModel(config)
653
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
654
+ if config.tie_word_embeddings:
655
+ self.lm_head.weight = self.embed_tokens.weight
656
+ self.apply(self._init_weights)
657
+ def _init_weights(self, module):
658
+ if isinstance(module, nn.Linear):
659
+ if hasattr(self, 'model') and module is self.model.injection.B:
660
+ return
661
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
662
+ if module.bias is not None:
663
+ module.bias.data.zero_()
664
+ elif isinstance(module, nn.Embedding):
665
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
666
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
667
+ hidden_states = self.embed_tokens(input_ids)
668
+ model_dtype = next(self.model.parameters()).dtype
669
+ hidden_states = hidden_states.to(model_dtype)
670
+ input_embedding = hidden_states.clone()
671
+ hidden_states, aux_loss, past_kv = self.model(
672
+ hidden_states, input_embedding=input_embedding,
673
+ attention_mask=None, position_ids=position_ids,
674
+ use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
675
+ )
676
+ logits = self.lm_head(hidden_states)
677
+ loss = None
678
+ if labels is not None:
679
+ shift_logits = logits[..., :-1, :].contiguous()
680
+ shift_labels = labels[..., 1:].contiguous()
681
+ loss_fct = CrossEntropyLoss()
682
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
683
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
684
+ def get_num_params(self):
685
+ total = sum(p.numel() for p in self.parameters())
686
+ return {"total": total, "trainable": total}
687
+
688
+
689
+ # ---------------------------------------------------------------------------
690
+ # Dataset
691
+ # ---------------------------------------------------------------------------
692
+
693
+ class FineWebEduDataset(IterableDataset):
694
+ def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
695
+ self.tokenizer = tokenizer
696
+ self.seq_len = seq_len
697
+ self.subset = subset
698
+ self.rank = rank
699
+ self.world_size = world_size
700
+
701
+ # Local tokenized data - USE mmapped binary for speed
702
+ LOCAL_TOKEN_FILE = "/data/fineweb_tokenized/train_tokens.bin"
703
+
704
+ if os.path.exists(LOCAL_TOKEN_FILE):
705
+ # Use memory-mapped file for fast I/O
706
+ import numpy as np
707
+ self.use_local = True
708
+ self.local_file = LOCAL_TOKEN_FILE
709
+ # Memory map for zero-copy reading
710
+ self.mm = np.memmap(LOCAL_TOKEN_FILE, dtype='<u4', mode='r')
711
+ self.num_tokens = len(self.mm)
712
+ self.num_samples = self.num_tokens // seq_len
713
+ else:
714
+ self.use_local = False
715
+
716
+ def __iter__(self):
717
+ if self.use_local:
718
+ # Fast: use memory-mapped array
719
+ worker = get_worker_info()
720
+ num_workers = worker.num_workers if worker else 1
721
+ worker_id = worker.id if worker else 0
722
+
723
+ samples_per_worker = self.num_samples // (self.world_size * num_workers)
724
+ start_sample = (self.rank * num_workers + worker_id) * samples_per_worker
725
+ end_sample = start_sample + samples_per_worker
726
+
727
+ # Batch read tokens - convert to numpy array slice then tensor
728
+ import numpy as np
729
+ for i in range(start_sample, end_sample):
730
+ start_idx = i * self.seq_len
731
+ # Direct slice from memory-mapped array (avoid extra copies when possible)
732
+ tokens = self.mm[start_idx : start_idx + self.seq_len + 1]
733
+ x_np = tokens[:-1].astype("int64", copy=False)
734
+ y_np = tokens[1:].astype("int64", copy=False)
735
+ yield torch.from_numpy(x_np), torch.from_numpy(y_np)
736
+ else:
737
+ # Fallback to HuggingFace
738
+ worker = get_worker_info()
739
+ num_workers = worker.num_workers if worker else 1
740
+ worker_id = worker.id if worker else 0
741
+ total_shards = self.world_size * num_workers
742
+ shard_index = self.rank * num_workers + worker_id
743
+ ds = load_dataset(
744
+ "HuggingFaceFW/fineweb-edu",
745
+ name=self.subset,
746
+ split="train",
747
+ streaming=True,
748
+ ).shard(num_shards=total_shards, index=shard_index)
749
+ buf = []
750
+ for sample in ds:
751
+ buf.extend(self.tokenizer.encode(sample["text"]))
752
+ while len(buf) >= self.seq_len + 1:
753
+ chunk = buf[: self.seq_len + 1]
754
+ buf = buf[self.seq_len + 1 :]
755
+ yield (
756
+ torch.tensor(chunk[:-1], dtype=torch.long),
757
+ torch.tensor(chunk[1:], dtype=torch.long),
758
+ )
759
+
760
+
761
+ # ---------------------------------------------------------------------------
762
+ # LR schedule
763
+ # ---------------------------------------------------------------------------
764
+
765
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
766
+ if step < warmup:
767
+ return max_lr * step / warmup
768
+ if step >= total:
769
+ return min_lr
770
+ decay = (step - warmup) / (total - warmup)
771
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
772
+
773
+
774
+ # ---------------------------------------------------------------------------
775
+ # Checkpointing
776
+ # ---------------------------------------------------------------------------
777
+
778
+ def save_weights_only(model, step, epoch, ckpt_dir, ddp):
779
+ if ddp:
780
+ with FSDP.state_dict_type(
781
+ model,
782
+ StateDictType.FULL_STATE_DICT,
783
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
784
+ ):
785
+ model_state = model.state_dict()
786
+ else:
787
+ model_state = model.state_dict()
788
+ ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
789
+ tmp_path = ckpt_path + ".tmp"
790
+ torch.save(model_state, tmp_path)
791
+ os.replace(tmp_path, ckpt_path)
792
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
793
+ return ckpt_path, size_mb
794
+
795
+
796
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
797
+ if ddp:
798
+ with FSDP.state_dict_type(
799
+ model,
800
+ StateDictType.FULL_STATE_DICT,
801
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
802
+ ):
803
+ model_state = model.state_dict()
804
+ optim_state = FSDP.optim_state_dict(model, optimizer)
805
+ else:
806
+ model_state = model.state_dict()
807
+ optim_state = optimizer.state_dict()
808
+ if not master:
809
+ return None, 0
810
+ os.makedirs(ckpt_dir, exist_ok=True)
811
+ final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
812
+ tmp_path = final_path + ".tmp"
813
+ torch.save(
814
+ {
815
+ "step": step,
816
+ "epoch": epoch,
817
+ "model_state_dict": model_state,
818
+ "optimizer_state_dict": optim_state,
819
+ "cfg": cfg,
820
+ "vocab_size": vocab_size,
821
+ },
822
+ tmp_path,
823
+ )
824
+ os.replace(tmp_path, final_path)
825
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
826
+ return final_path, size_mb
827
+
828
+
829
+ def load_checkpoint(model, optimizer, path, ddp):
830
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
831
+ if ddp:
832
+ with FSDP.state_dict_type(
833
+ model,
834
+ StateDictType.FULL_STATE_DICT,
835
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
836
+ ):
837
+ model.load_state_dict(ckpt["model_state_dict"])
838
+ optim_state = FSDP.optim_state_dict_to_load(
839
+ model=model,
840
+ optim=optimizer,
841
+ optim_state_dict=ckpt["optimizer_state_dict"],
842
+ )
843
+ optimizer.load_state_dict(optim_state)
844
+ else:
845
+ model.load_state_dict(ckpt["model_state_dict"])
846
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
847
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0))
848
+
849
+
850
+ # ---------------------------------------------------------------------------
851
+ # Main
852
+ # ---------------------------------------------------------------------------
853
+
854
+ def main():
855
+ # ------------------------------------------------------------------
856
+ # Distributed init
857
+ # ------------------------------------------------------------------
858
+ ddp = int(os.environ.get("RANK", -1)) != -1
859
+ if ddp:
860
+ dist.init_process_group("nccl")
861
+ rank = int(os.environ["RANK"])
862
+ local_rank = int(os.environ["LOCAL_RANK"])
863
+ world_size = int(os.environ["WORLD_SIZE"])
864
+ device = f"cuda:{local_rank}"
865
+ torch.cuda.set_device(device)
866
+ else:
867
+ rank = local_rank = 0
868
+ world_size = 1
869
+ device = "cuda" if torch.cuda.is_available() else "cpu"
870
+ master = rank == 0
871
+
872
+ # ------------------------------------------------------------------
873
+ # Tokenizer
874
+ # ------------------------------------------------------------------
875
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
876
+ tokenizer.pad_token = tokenizer.eos_token
877
+ vocab_size = tokenizer.vocab_size
878
+ if master:
879
+ logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
880
+
881
+ # ------------------------------------------------------------------
882
+ # Hyperparameters
883
+ # ------------------------------------------------------------------
884
+ seq_len = int(os.environ.get("SEQ_LEN", "2048"))
885
+ micro_batch = int(os.environ.get("MICRO_BATCH", "32"))
886
+ target_tokens = int(os.environ.get("TARGET_TOKENS", "50_000_000"))
887
+ grad_accum = int(os.environ.get("GRAD_ACCUM", "1"))
888
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
889
+ total_steps = target_tokens // global_batch_tok
890
+ warmup_steps = 200
891
+ lr = 3e-4
892
+ wd = 0.1
893
+ log_every = 10
894
+ ckpt_every = int(os.environ.get("CKPT_EVERY", "500"))
895
+ ckpt_dir = "checkpoints-dense"
896
+ dataset_subset = "sample-10BT"
897
+
898
+ if master:
899
+ logger.info(
900
+ f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
901
+ f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
902
+ )
903
+ logger.info(
904
+ "Attention: MLA (sliding_window disabled) | "
905
+ "Engram: layers [1,4] | Context: 32k | "
906
+ "Gradient checkpointing: enabled"
907
+ )
908
+
909
+ # ------------------------------------------------------------------
910
+ # Model
911
+ # ------------------------------------------------------------------
912
+ cfg = SpiderPortalConfig(
913
+ hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
914
+ num_key_value_heads=4, intermediate_size=4096,
915
+ num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
916
+ router_aux_loss_coef=0.05, max_loop_iters=2,
917
+ prelude_layers=2, coda_layers=2, lora_rank=128,
918
+ rope_theta=10000000.0,
919
+ rope_scaling=None,
920
+ max_position_embeddings=32768,
921
+ sliding_window=0,
922
+ tie_word_embeddings=True,
923
+ kv_lora_rank=128, q_lora_rank=256,
924
+ qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
925
+ engram_layers=[1, 4],
926
+ engram_ngram_orders=(2, 3),
927
+ engram_hash_heads=4,
928
+ engram_table_size=65537,
929
+ engram_dim=128,
930
+ )
931
+ cfg.vocab_size = vocab_size
932
+
933
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
934
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
935
+
936
+ model = SpiderPortalForConditionalGeneration(cfg).to(torch.bfloat16)
937
+
938
+ if ddp:
939
+ mp_policy = MixedPrecision(
940
+ param_dtype=amp_dtype,
941
+ reduce_dtype=amp_dtype,
942
+ buffer_dtype=amp_dtype,
943
+ )
944
+ wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
945
+ model = FSDP(
946
+ model,
947
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
948
+ mixed_precision=mp_policy,
949
+ auto_wrap_policy=wrap_policy,
950
+ device_id=local_rank,
951
+ )
952
+ else:
953
+ model = model.to(device)
954
+ # torch.compile disabled - causes issues with custom model
955
+ # model = torch.compile(model, mode="reduce-overhead")
956
+
957
+ if os.environ.get("SKIP_MXFP8", "0") == "1":
958
+ if master:
959
+ logger.info("MXFP8 disabled via SKIP_MXFP8=1 — using native bf16")
960
+ else:
961
+ # Apply MXFP8 (emulated) via _to_mxfp8_then_scaled_mm autograd Function.
962
+ # Uses single Function node (mem-efficient) with EMULATED kernels for compute 12.0 compat.
963
+ # Weights stay bf16; quantize/dequantize wraps mm for fp8 bandwidth savings.
964
+ try:
965
+ from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_then_scaled_mm
966
+ from torchao.prototype.mx_formats.config import ScaleCalculationMode
967
+ from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
968
+
969
+ # MXFP8: try AUTO (native cuBLAS) first, fall back to EMULATED
970
+ _mxfp8_kernel = KernelPreference.EMULATED
971
+ try:
972
+ _t_a = torch.randn(32, 2048, device='cuda', dtype=torch.bfloat16)
973
+ _t_b = torch.randn(32, 2048, device='cuda', dtype=torch.bfloat16)
974
+ _t_a_fp8 = _t_a.to(torch.float8_e4m3fn)
975
+ _t_b_fp8 = _t_b.to(torch.float8_e4m3fn)
976
+ _sa = torch.ones(32, 64, dtype=torch.float8_e8m0fnu, device='cuda')
977
+ _sb = torch.ones(32, 64, dtype=torch.float8_e8m0fnu, device='cuda')
978
+ from torchao.prototype.mx_formats.utils import to_blocked
979
+ _sa_b = to_blocked(_sa)
980
+ _sb_b = to_blocked(_sb)
981
+ _t_out = torch._scaled_mm(_t_a_fp8, _t_b_fp8.t(), scale_a=_sa_b, scale_b=_sb_b, out_dtype=torch.bfloat16)
982
+ _t_ref = _t_a @ _t_b.t()
983
+ if (_t_out.float() - _t_ref.float()).abs().max().item() < 10.0:
984
+ _mxfp8_kernel = KernelPreference.AUTO
985
+ if master:
986
+ logger.info("Native MXFP8 _scaled_mm available!")
987
+ except Exception:
988
+ pass
989
+
990
+ class MXFP8Linear(nn.Module):
991
+ def __init__(self, linear: nn.Linear):
992
+ super().__init__()
993
+ self.weight = nn.Parameter(linear.weight.data.clone().to(torch.bfloat16))
994
+ self.bias = nn.Parameter(linear.bias.data.clone().to(torch.bfloat16)) if linear.bias is not None else None
995
+ self.in_features = linear.in_features
996
+ self.out_features = linear.out_features
997
+
998
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
999
+ x = x.to(torch.bfloat16)
1000
+ in_dim = x.shape[-1]
1001
+ pad = (32 - in_dim % 32) % 32
1002
+ if pad:
1003
+ x = F.pad(x, (0, pad))
1004
+ w = F.pad(self.weight, (0, pad))
1005
+ else:
1006
+ w = self.weight
1007
+ orig = x.shape
1008
+ x_2d = x.reshape(-1, x.shape[-1])
1009
+ out = _to_mxfp8_then_scaled_mm(
1010
+ x_2d, w,
1011
+ kernel_preference=_mxfp8_kernel,
1012
+ scale_calculation_mode=ScaleCalculationMode.RCEIL,
1013
+ wgrad_with_hp=False,
1014
+ )
1015
+ out = out.reshape(*orig[:-1], out.shape[-1])
1016
+ if self.bias is not None:
1017
+ out = out + self.bias
1018
+ return out
1019
+
1020
+ total = sum(1 for _ in model.modules() if isinstance(_, nn.Linear))
1021
+ count = 0
1022
+ for name, mod in list(model.named_modules()):
1023
+ if isinstance(mod, nn.Linear):
1024
+ if mod.in_features % 32 != 0 or mod.out_features % 32 != 0:
1025
+ continue
1026
+ parent_name, _, child_name = name.rpartition(".")
1027
+ parent = model.get_submodule(parent_name) if parent_name else model
1028
+ setattr(parent, child_name, MXFP8Linear(mod))
1029
+ count += 1
1030
+ if master:
1031
+ _mode = "AUTO" if _mxfp8_kernel == KernelPreference.AUTO else "EMULATED"
1032
+ logger.info(f"MXFP8 ({_mode}) quantized {count}/{total} Linear layers")
1033
+ except Exception as e:
1034
+ if master:
1035
+ logger.info(f"MXFP8 not available, bf16 fallback: {e}")
1036
+ import traceback; traceback.print_exc()
1037
+
1038
+ amp_ctx = (
1039
+ torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
1040
+ if "cuda" in device
1041
+ else nullcontext()
1042
+ )
1043
+
1044
+ amp_ctx = nullcontext() if ddp else amp_ctx
1045
+
1046
+ # Enable SDPA best kernels when available.
1047
+ try:
1048
+ from torch.nn.attention import sdpa_kernel
1049
+
1050
+ sdpa_ctx = sdpa_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True)
1051
+ except Exception:
1052
+ sdpa_ctx = nullcontext()
1053
+
1054
+ if master:
1055
+ n_params = sum(p.numel() for p in model.parameters())
1056
+ engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
1057
+ mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
1058
+ embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
1059
+ ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
1060
+ other_params = n_params - engram_params - mla_params - embed_params - ffn_params
1061
+ logger.info(
1062
+ f"Parameters: {n_params:,} (all active) | "
1063
+ f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
1064
+ f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
1065
+ f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
1066
+ )
1067
+
1068
+ # ------------------------------------------------------------------
1069
+ # Optimizer — dual optimizer for Engram embeddings
1070
+ # ------------------------------------------------------------------
1071
+ engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed' in n and 'proj' not in n]
1072
+ backbone_params = [p for n, p in model.named_parameters() if not ('engram' in n and 'embed' in n and 'proj' not in n)]
1073
+
1074
+ optimizer = torch.optim.AdamW(
1075
+ backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=False, foreach=True, eps=1e-8
1076
+ )
1077
+ if engram_params_list:
1078
+ engram_optimizer = torch.optim.Adam(
1079
+ engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
1080
+ )
1081
+ else:
1082
+ engram_optimizer = None
1083
+
1084
+ # ------------------------------------------------------------------
1085
+ # Resume from latest checkpoint
1086
+ # ------------------------------------------------------------------
1087
+ start_step = 0
1088
+ start_epoch = 1
1089
+ best_loss = float("inf")
1090
+ existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
1091
+ if existing_ckpts:
1092
+ latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
1093
+ if master:
1094
+ logger.info(f"Resuming from checkpoint: {latest}")
1095
+ start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
1096
+ if master:
1097
+ logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
1098
+
1099
+ # ------------------------------------------------------------------
1100
+ # Dataset + DataLoader
1101
+ # ------------------------------------------------------------------
1102
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
1103
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True, prefetch_factor=1)
1104
+
1105
+ # ------------------------------------------------------------------
1106
+ # Training loop
1107
+ # ------------------------------------------------------------------
1108
+ if master:
1109
+ os.makedirs(ckpt_dir, exist_ok=True)
1110
+
1111
+ model.train()
1112
+ data_iter = iter(loader)
1113
+ t0 = time.perf_counter()
1114
+ step = start_step
1115
+ epoch = start_epoch
1116
+ step_ckpt_files = []
1117
+ tokens_in_epoch = 0
1118
+ tokens_per_epoch = target_tokens
1119
+
1120
+ # Allow env override for quick debugging.
1121
+ max_steps_override = os.environ.get("MAX_STEPS", None)
1122
+ while step < total_steps:
1123
+ if max_steps_override is not None and step >= int(max_steps_override):
1124
+ break
1125
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
1126
+ for g in optimizer.param_groups:
1127
+ g["lr"] = cur_lr
1128
+ if engram_optimizer:
1129
+ for g in engram_optimizer.param_groups:
1130
+ g["lr"] = cur_lr * 5
1131
+
1132
+ optimizer.zero_grad()
1133
+ if engram_optimizer:
1134
+ engram_optimizer.zero_grad()
1135
+ loss_accum = 0.0
1136
+
1137
+ for micro_step in range(grad_accum):
1138
+ try:
1139
+ x, y = next(data_iter)
1140
+ except StopIteration:
1141
+ data_iter = iter(loader)
1142
+ x, y = next(data_iter)
1143
+
1144
+ x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1145
+ y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1146
+
1147
+ sync = (
1148
+ nullcontext()
1149
+ if (not ddp or micro_step == grad_accum - 1)
1150
+ else model.no_sync()
1151
+ )
1152
+ with sync, amp_ctx, sdpa_ctx:
1153
+ output = model(x)
1154
+ if master and step == start_step and micro_step == 0:
1155
+ peak_vram = torch.cuda.max_memory_allocated() / 1024**3
1156
+ logger.info(f"Reached first model forward | Peak VRAM: {peak_vram:.1f}GB")
1157
+ if isinstance(output, dict):
1158
+ logits = output["logits"]
1159
+ else:
1160
+ logits = output
1161
+ loss = nn.functional.cross_entropy(
1162
+ logits.view(-1, vocab_size), y.view(-1)
1163
+ )
1164
+ loss = loss / grad_accum
1165
+
1166
+ loss.backward()
1167
+ if master and step == start_step and micro_step == 0:
1168
+ logger.info("Reached first backward")
1169
+ loss_accum += loss.item()
1170
+
1171
+ if ddp:
1172
+ grad_norm = model.clip_grad_norm_(1.0)
1173
+ else:
1174
+ grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
1175
+ optimizer.step()
1176
+ if engram_optimizer:
1177
+ engram_optimizer.step()
1178
+ step += 1
1179
+ tokens_in_epoch += global_batch_tok
1180
+
1181
+ if master and step % log_every == 0:
1182
+ dt = time.perf_counter() - t0
1183
+ tok_per_sec = global_batch_tok * log_every / dt
1184
+ tokens_seen = step * global_batch_tok
1185
+ logger.info(
1186
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
1187
+ f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
1188
+ f"| {tok_per_sec / 1e6:.2f}M tok/s "
1189
+ f"| {tokens_seen / 1e9:.2f}B tokens seen"
1190
+ )
1191
+ t0 = time.perf_counter()
1192
+
1193
+ if step % ckpt_every == 0 and master:
1194
+ ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
1195
+ step_ckpt_files.append(ckpt_path)
1196
+ logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1197
+
1198
+ if tokens_in_epoch >= tokens_per_epoch:
1199
+ epoch_loss = loss_accum
1200
+ if master:
1201
+ epoch_time = (time.perf_counter() - t0) / 60
1202
+ logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
1203
+
1204
+ for f in step_ckpt_files:
1205
+ if os.path.exists(f):
1206
+ os.remove(f)
1207
+ logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
1208
+ step_ckpt_files.clear()
1209
+
1210
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
1211
+ if ckpt_path:
1212
+ logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1213
+
1214
+ if epoch_loss < best_loss:
1215
+ best_loss = epoch_loss
1216
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
1217
+ if ckpt_path:
1218
+ logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1219
+
1220
+ epoch += 1
1221
+ tokens_in_epoch = 0
1222
+
1223
+ if step > start_step and master:
1224
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
1225
+ if ckpt_path:
1226
+ logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1227
+
1228
+ if ddp:
1229
+ dist.barrier()
1230
+ dist.destroy_process_group()
1231
+
1232
+ if master:
1233
+ logger.success("Training complete.")
1234
+
1235
+
1236
+ if __name__ == "__main__":
1237
+ main()