CLIWorks commited on
Commit
c08afec
·
verified ·
1 Parent(s): dce8328

Upload agent-mythos-edit.py

Browse files
Files changed (1) hide show
  1. agent-mythos-edit.py +1113 -0
agent-mythos-edit.py ADDED
@@ -0,0 +1,1113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
5
+ - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
6
+ - Engram conditional memory at recurrent layers 1 and 4
7
+ - Dense FFN (all params active, MoE conversion in Phase 2)
8
+ - LTI Injection + ACT Halting + LoRA Adapter
9
+ - 32k context (extendable to 256k at inference via YaRN)
10
+ Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
11
+ Single GPU:
12
+ python mythos-fineweb-dense.py
13
+ Multi-GPU:
14
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
15
+ """
16
+ import os
17
+ import math
18
+ import time
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.distributed as dist
23
+ import sys
24
+
25
+ # Simple print-based logging — no file rotation, no hanging
26
+ def log(msg, level="INFO"):
27
+ ts = time.strftime("%Y-%m-%d %H:%M:%S")
28
+ print(f"{ts} | {level} | {msg}", flush=True)
29
+
30
+ # Speed up CUDA memory allocation
31
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512,expandable_segments:True"
32
+ from torch.distributed.fsdp import (
33
+ FullyShardedDataParallel as FSDP,
34
+ ShardingStrategy,
35
+ MixedPrecision,
36
+ FullStateDictConfig,
37
+ StateDictType,
38
+ )
39
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
40
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
41
+ from contextlib import nullcontext
42
+ from dataclasses import dataclass, field
43
+ from typing import Optional, Tuple, Dict, List
44
+ from torch.nn import CrossEntropyLoss
45
+ from datasets import load_dataset
46
+ from transformers import AutoTokenizer
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # SpiderPortal Model Architecture (Dense + MLA + Engram)
51
+ # ---------------------------------------------------------------------------
52
+
53
+ @dataclass
54
+ class SpiderPortalConfig:
55
+ vocab_size: int = 50257
56
+ hidden_size: int = 2048
57
+ num_hidden_layers: int = 6
58
+ num_attention_heads: int = 16
59
+ num_key_value_heads: int = 4
60
+ intermediate_size: int = 8192
61
+ hidden_act: str = "silu"
62
+ num_experts: int = 32
63
+ num_experts_per_tok: int = 2
64
+ num_shared_experts: int = 1
65
+ router_aux_loss_coef: float = 0.05
66
+ max_loop_iters: int = 4
67
+ act_threshold: float = 0.5
68
+ max_position_embeddings: int = 32768
69
+ rope_theta: float = 10000000.0
70
+ rope_scaling: dict = None
71
+ sliding_window: int = 4096
72
+ attention_dropout: float = 0.0
73
+ rms_norm_eps: float = 1e-6
74
+ initializer_range: float = 0.02
75
+ use_cache: bool = True
76
+ tie_word_embeddings: bool = True
77
+ prelude_layers: int = 2
78
+ coda_layers: int = 2
79
+ lora_rank: int = 128
80
+ loop_embed_dim: int = 128
81
+ vision_hidden_size: int = 2048
82
+ audio_hidden_size: int = 512
83
+ vision_num_frames: int = 60
84
+ vision_tokens_per_frame: int = 256
85
+ vision_temporal_tokens: int = 64
86
+ vision_temporal_layers: int = 2
87
+ model_type: str = "spiderportal"
88
+ torch_dtype: str = "bfloat16"
89
+
90
+ # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
91
+ kv_lora_rank: int = 128
92
+ q_lora_rank: int = 256
93
+ qk_rope_head_dim: int = 64
94
+ qk_nope_head_dim: int = 64
95
+ v_head_dim: int = 64
96
+
97
+ # Engram parameters (DeepSeek conditional memory)
98
+ engram_layers: List[int] = field(default_factory=lambda: [1, 4])
99
+ engram_ngram_orders: Tuple[int, ...] = (2, 3)
100
+ engram_hash_heads: int = 4
101
+ engram_table_size: int = 65537 # prime number for hash table
102
+ engram_conv_kernel: int = 4
103
+ engram_conv_dilation: int = 3
104
+ engram_dim: int = 128 # per-head embedding dimension
105
+
106
+
107
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
108
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
109
+ angles = loop_t * freqs
110
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
111
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
112
+ emb_full[:loop_dim] = emb
113
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
114
+
115
+
116
+ class SpiderPortalRMSNorm(nn.Module):
117
+ def __init__(self, hidden_size, eps=1e-6):
118
+ super().__init__()
119
+ self.weight = nn.Parameter(torch.ones(hidden_size))
120
+ self.variance_epsilon = eps
121
+ def forward(self, hidden_states):
122
+ input_dtype = hidden_states.dtype
123
+ hidden_states = hidden_states.to(torch.float32)
124
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
125
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
126
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
127
+
128
+
129
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
130
+ dim = head_dim
131
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
132
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
133
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
134
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
135
+ return orig_inv_freq * scale
136
+
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
140
+ # ---------------------------------------------------------------------------
141
+
142
+ class SpiderPortalMLA(nn.Module):
143
+ """Multi-Latent Attention with compressed KV cache and sliding window.
144
+ For hidden_size=2048, num_heads=16:
145
+ - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
146
+ - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
147
+ - v_head_dim=64 → value projection
148
+ - sliding_window=4096 → local attention range
149
+ """
150
+ def __init__(self, config):
151
+ super().__init__()
152
+ self.config = config
153
+ self.hidden_size = config.hidden_size
154
+ self.num_heads = config.num_attention_heads
155
+ self.kv_lora_rank = config.kv_lora_rank
156
+ self.q_lora_rank = config.q_lora_rank
157
+ self.qk_rope_head_dim = config.qk_rope_head_dim
158
+ self.qk_nope_head_dim = config.qk_nope_head_dim
159
+ self.v_head_dim = config.v_head_dim
160
+ self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
161
+ self.sliding_window = getattr(config, 'sliding_window', None)
162
+
163
+ # Q projection: optional low-rank → full Q
164
+ if self.q_lora_rank > 0:
165
+ self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
166
+ self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
167
+ self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
168
+ else:
169
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
170
+
171
+ # KV compression: hidden → kv_lora_rank (shared latent)
172
+ self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
173
+ self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
174
+ # Decompress: kv_lora_rank → nope heads + v heads
175
+ self.kv_b_proj = nn.Linear(
176
+ self.kv_lora_rank,
177
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
178
+ bias=False,
179
+ )
180
+ # Output projection
181
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
182
+
183
+ # RoPE frequencies
184
+ rope_scaling = getattr(config, 'rope_scaling', None)
185
+ if rope_scaling and rope_scaling.get("type") == "yarn":
186
+ factor = rope_scaling.get("factor", 1.0)
187
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
188
+ inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
189
+ else:
190
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
191
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
192
+
193
+ def _rotate_half(self, x):
194
+ x1 = x[..., :x.shape[-1] // 2]
195
+ x2 = x[..., x.shape[-1] // 2:]
196
+ return torch.cat((-x2, x1), dim=-1)
197
+
198
+ def _apply_rotary(self, x, cos, sin):
199
+ return (x * cos) + (self._rotate_half(x) * sin)
200
+
201
+ def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
202
+ """Create a sliding window causal mask."""
203
+ if self.sliding_window is None or self.sliding_window <= 0:
204
+ return None
205
+ mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, device=device, dtype=dtype)
206
+ for i in range(q_len):
207
+ start = max(0, i - self.sliding_window + 1)
208
+ mask[i, start:i + 1] = 0.0
209
+ return mask
210
+
211
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
212
+ bsz, q_len, _ = hidden_states.size()
213
+
214
+ # Q projection
215
+ if self.q_lora_rank > 0:
216
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
217
+ else:
218
+ q = self.q_proj(hidden_states)
219
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
220
+ q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
221
+
222
+ # KV: compress to latent, then decompress
223
+ kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
224
+ kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
225
+ kv_latent_norm = self.kv_a_layernorm(kv_latent)
226
+ kv_b_out = self.kv_b_proj(kv_latent_norm)
227
+ k_nope, v = torch.split(kv_b_out, [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim], dim=-1)
228
+
229
+ k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
230
+ v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
231
+ k_rope = k_rope.unsqueeze(1)
232
+
233
+ # RoPE on Q and K rope parts
234
+ if position_ids is None:
235
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
236
+ max_pos = position_ids.max().item() + 1
237
+ seq_len = max(max_pos, q_len)
238
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
239
+ freqs = torch.outer(t, self.inv_freq)
240
+ emb = torch.cat((freqs, freqs), dim=-1)
241
+ cos, sin = emb.cos(), emb.sin()
242
+ cos = cos[position_ids].unsqueeze(1)
243
+ sin = sin[position_ids].unsqueeze(1)
244
+
245
+ q_rope = self._apply_rotary(q_rope, cos, sin)
246
+ k_rope = self._apply_rotary(k_rope, cos, sin)
247
+
248
+ # Assemble full K
249
+ k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
250
+ k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
251
+ q_full = torch.cat([q_nope, q_rope], dim=-1)
252
+
253
+ # KV cache
254
+ if past_key_value is not None:
255
+ k_full = torch.cat([past_key_value[0], k_full], dim=2)
256
+ v = torch.cat([past_key_value[1], v], dim=2)
257
+ past_kv = (k_full, v) if use_cache else None
258
+
259
+ # Build attention mask: user mask + sliding window
260
+ final_mask = attention_mask
261
+ if self.sliding_window is not None and self.sliding_window > 0:
262
+ kv_len = k_full.size(2)
263
+ sw_mask = self._make_sliding_window_mask(q_len, kv_len, hidden_states.device, hidden_states.dtype)
264
+ if final_mask is not None:
265
+ final_mask = final_mask + sw_mask
266
+ else:
267
+ final_mask = sw_mask
268
+
269
+ # Attention with SDPA
270
+ attn_output = F.scaled_dot_product_attention(
271
+ q_full, k_full, v,
272
+ attn_mask=final_mask,
273
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
274
+ is_causal=(final_mask is None),
275
+ )
276
+ attn_output = attn_output.transpose(1, 2).contiguous()
277
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
278
+ return self.o_proj(attn_output), past_kv
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
283
+ # ---------------------------------------------------------------------------
284
+
285
+ def _tokenizer_compress(token_ids, vocab_size=50257):
286
+ """Simulate NFKC + lowercase canonical ID projection."""
287
+ return token_ids % (vocab_size * 77 // 100)
288
+
289
+
290
+ class SpiderPortalEngram(nn.Module):
291
+ """Conditional memory module via NN-gram lookup.
292
+ Applied only at specific recurrent layers (config.engram_layers).
293
+ """
294
+ def __init__(self, config):
295
+ super().__init__()
296
+ self.config = config
297
+ self.ngram_orders = config.engram_ngram_orders
298
+ self.num_heads = config.engram_hash_heads
299
+ self.table_size = config.engram_table_size
300
+ self.d_mem = config.engram_dim
301
+
302
+ self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem
303
+
304
+ self.embed_tables = nn.ParameterDict()
305
+ for n in self.ngram_orders:
306
+ for h in range(self.num_heads):
307
+ key = f"e_{n}_{h}"
308
+ self.embed_tables[key] = nn.Parameter(
309
+ torch.randn(self.table_size, self.d_mem) * 0.02
310
+ )
311
+
312
+ self.register_buffer("hash_seeds", torch.tensor([
313
+ (h + 1) * 2654435761
314
+ for _ in self.ngram_orders
315
+ for h in range(self.num_heads)
316
+ ], dtype=torch.int64))
317
+
318
+ self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
319
+ self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
320
+
321
+ self.conv = nn.Conv1d(
322
+ config.hidden_size, config.hidden_size,
323
+ kernel_size=config.engram_conv_kernel,
324
+ padding=config.engram_conv_kernel - 1,
325
+ groups=config.hidden_size,
326
+ )
327
+ self.conv_dilation = config.engram_conv_dilation
328
+
329
+ with torch.no_grad():
330
+ self.conv.weight.zero_()
331
+ if self.conv.bias is not None:
332
+ self.conv.bias.zero_()
333
+
334
+ self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
335
+ self.k_norm = SpiderPortalRMSNorm(config.hidden_size)
336
+
337
+ def _compute_indices(self, compressed_ids, n, head_idx):
338
+ """Vectorized NN-gram hash indices for a single (order, head)."""
339
+ bsz, seq_len = compressed_ids.shape
340
+ pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
341
+ padded = torch.cat([pad, compressed_ids], dim=1)
342
+
343
+ indices_list = []
344
+ for i in range(n):
345
+ indices_list.append(padded[:, i:i + seq_len])
346
+ ngrams = torch.stack(indices_list, dim=-1)
347
+
348
+ seed = int(self.hash_seeds[head_idx].item())
349
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
350
+ for i in range(n):
351
+ h_val = h_val * 31 + ngrams[:, :, i]
352
+ h_val = h_val % self.table_size
353
+ h_val = (h_val * seed) % self.table_size
354
+ return h_val
355
+
356
+ def _retrieve(self, token_ids):
357
+ """Retrieve memory vectors for a batch of token sequences."""
358
+ bsz, seq_len = token_ids.shape
359
+ compressed = _tokenizer_compress(token_ids)
360
+
361
+ all_parts = []
362
+ head_counter = 0
363
+ for n in self.ngram_orders:
364
+ for h in range(self.num_heads):
365
+ key = f"e_{n}_{h}"
366
+ table = self.embed_tables[key]
367
+ indices = self._compute_indices(compressed, n, head_counter)
368
+ emb = table[indices.view(-1)]
369
+ all_parts.append(emb.view(bsz, seq_len, self.d_mem))
370
+ head_counter += 1
371
+
372
+ memory = torch.cat(all_parts, dim=-1)
373
+ return memory
374
+
375
+ def forward(self, hidden_states, token_ids):
376
+ mem = self._retrieve(token_ids)
377
+
378
+ q = hidden_states
379
+ k = self.W_k(mem)
380
+ v = self.W_v(mem)
381
+
382
+ q_norm = self.q_norm(q)
383
+ k_norm = self.k_norm(k)
384
+
385
+ alpha = torch.sigmoid(
386
+ (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
387
+ )
388
+
389
+ v_gated = alpha * v
390
+
391
+ v_gated_t = v_gated.transpose(1, 2)
392
+ conv_out = self.conv(v_gated_t)
393
+ conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
394
+ conv_out = conv_out.transpose(1, 2)
395
+
396
+ y = F.silu(conv_out) + v_gated
397
+
398
+ return y
399
+
400
+
401
+ # ---------------------------------------------------------------------------
402
+ # FFN Expert (dense)
403
+ # ---------------------------------------------------------------------------
404
+
405
+ class SpiderPortalExpert(nn.Module):
406
+ def __init__(self, config, intermediate_size=None):
407
+ super().__init__()
408
+ inter_size = intermediate_size or config.intermediate_size
409
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
410
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
411
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
412
+ self.act_fn = nn.SiLU()
413
+ def forward(self, hidden_states):
414
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
415
+
416
+
417
+ # ---------------------------------------------------------------------------
418
+ # Prelude/Coda Dense Layer (uses MLA)
419
+ # ---------------------------------------------------------------------------
420
+
421
+ class SpiderPortalDenseLayer(nn.Module):
422
+ """Prelude/coda dense layer with MLA attention."""
423
+ def __init__(self, config):
424
+ super().__init__()
425
+ self.self_attn = SpiderPortalMLA(config)
426
+ dense_intermediate = config.hidden_size * 4 // 3
427
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
428
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
429
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
430
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
431
+ attn_input = self.input_layernorm(hidden_states)
432
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
433
+ hidden_states = hidden_states + attn_output
434
+ ffn_input = self.post_attention_layernorm(hidden_states)
435
+ ffn_output = self.ffn(ffn_input)
436
+ hidden_states = hidden_states + ffn_output
437
+ return hidden_states, past_kv
438
+
439
+
440
+ # ---------------------------------------------------------------------------
441
+ # Recurrent Dense Layer (uses MLA + optional Engram)
442
+ # ---------------------------------------------------------------------------
443
+
444
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
445
+ """Recurrent layer with MLA attention and optional Engram memory."""
446
+ def __init__(self, config, layer_idx, has_engram=False):
447
+ super().__init__()
448
+ self.layer_idx = layer_idx
449
+ self.has_engram = has_engram
450
+ self.self_attn = SpiderPortalMLA(config)
451
+ if has_engram:
452
+ self.engram = SpiderPortalEngram(config)
453
+ self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
454
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
455
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
456
+ self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None
457
+ def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
458
+ attn_input = self.input_layernorm(hidden_states)
459
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
460
+ hidden_states = hidden_states + attn_output
461
+
462
+ if self.has_engram and token_ids is not None:
463
+ engram_out = self.engram(hidden_states, token_ids)
464
+ hidden_states = hidden_states + engram_out
465
+ if self.post_engram_layernorm is not None:
466
+ hidden_states = self.post_engram_layernorm(hidden_states)
467
+
468
+ ffn_input = self.post_attention_layernorm(hidden_states)
469
+ ffn_output = self.ffn(ffn_input)
470
+ hidden_states = hidden_states + ffn_output
471
+ return hidden_states, 0.0, past_kv
472
+
473
+
474
+ # ---------------------------------------------------------------------------
475
+ # LTI Injection, ACT Halting, LoRA Adapter
476
+ # ---------------------------------------------------------------------------
477
+
478
+ class LTIInjection(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.hidden_size = config.hidden_size
482
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
483
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
484
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
485
+ with torch.no_grad():
486
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
487
+ def get_A(self):
488
+ return -torch.exp(self.log_A)
489
+ def forward(self, h_t, e):
490
+ A = self.get_A()
491
+ return A * h_t + self.B(e)
492
+
493
+
494
+ class ACTHalting(nn.Module):
495
+ def __init__(self, config):
496
+ super().__init__()
497
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
498
+ self.threshold = config.act_threshold
499
+ def forward(self, hidden_states):
500
+ return torch.sigmoid(self.halt_predictor(hidden_states))
501
+
502
+
503
+ class LoRAAdapter(nn.Module):
504
+ def __init__(self, config):
505
+ super().__init__()
506
+ rank = config.lora_rank
507
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
508
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
509
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
510
+ with torch.no_grad():
511
+ self.scale.weight.data.zero_()
512
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
513
+ def forward(self, x, loop_t):
514
+ max_t = self.scale.num_embeddings - 1
515
+ t_idx = min(loop_t, max_t)
516
+ s = self.scale(torch.tensor(t_idx, device=x.device))
517
+ down = self.down(x) * s
518
+ return down @ self.B
519
+
520
+
521
+ def checkpoint(func, *args, **kwargs):
522
+ """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
523
+ if torch.is_grad_enabled():
524
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
525
+ return func(*args, **kwargs)
526
+
527
+
528
+ # ---------------------------------------------------------------------------
529
+ # Full Model
530
+ # ---------------------------------------------------------------------------
531
+
532
+ class SpiderPortalDenseModel(nn.Module):
533
+ """Full RDT model with MLA attention + Engram memory at layers 1,4.
534
+ Architecture:
535
+ 2x Prelude (MLA + dense FFN)
536
+ 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing
537
+ 2x Coda (MLA + dense FFN)
538
+ LTI Injection + ACT Halting + LoRA Adapter
539
+ """
540
+ def __init__(self, config):
541
+ super().__init__()
542
+ self.config = config
543
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
544
+ self.recurrent_layers = nn.ModuleList([
545
+ SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
546
+ for i in range(config.num_hidden_layers)
547
+ ])
548
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
549
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
550
+ self.injection = LTIInjection(config)
551
+ self.act_halting = ACTHalting(config)
552
+ self.lora_adapter = LoRAAdapter(config)
553
+ self.loop_embed_dim = config.loop_embed_dim
554
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
555
+ n_loops = n_loops or self.config.max_loop_iters
556
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
557
+ for layer in self.prelude_layers:
558
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
559
+ e = hidden_states.clone()
560
+ B, T_seq, D = hidden_states.shape
561
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
562
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
563
+ h_out = torch.zeros_like(hidden_states)
564
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
565
+ for t in range(n_loops):
566
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
567
+ if t > 0:
568
+ injection = self.injection(hidden_states, input_embedding)
569
+ hidden_states = hidden_states + injection
570
+ new_past_key_values = []
571
+ for i, layer in enumerate(self.recurrent_layers):
572
+ hidden_states, aux_loss, past_kv = checkpoint(
573
+ layer, hidden_states,
574
+ token_ids=token_ids,
575
+ attention_mask=attention_mask,
576
+ position_ids=position_ids,
577
+ past_key_value=past_key_values[i] if t == 0 else None,
578
+ use_cache=use_cache
579
+ )
580
+ new_past_key_values.append(past_kv)
581
+ lora_delta = self.lora_adapter(hidden_states, t)
582
+ hidden_states = hidden_states + lora_delta
583
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
584
+ still_running = ~halted
585
+ remainder = (1.0 - cumulative_p).clamp(min=0)
586
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
587
+ weight = weight * still_running.to(hidden_states.dtype)
588
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
589
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
590
+ halted = halted | (cumulative_p >= self.config.act_threshold)
591
+ if halted.all() and not self.training:
592
+ break
593
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
594
+ hidden_states = h_out + never_halted * hidden_states
595
+ for layer in self.coda_layers:
596
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
597
+ hidden_states = self.norm(hidden_states)
598
+ return hidden_states, 0.0, new_past_key_values
599
+
600
+
601
+ class SpiderPortalForConditionalGeneration(nn.Module):
602
+ def __init__(self, config):
603
+ super().__init__()
604
+ self.config = config
605
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
606
+ self.model = SpiderPortalDenseModel(config)
607
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
608
+ if config.tie_word_embeddings:
609
+ self.lm_head.weight = self.embed_tokens.weight
610
+ self.apply(self._init_weights)
611
+ def _init_weights(self, module):
612
+ if isinstance(module, nn.Linear):
613
+ if hasattr(self, 'model') and module is self.model.injection.B:
614
+ return
615
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
616
+ if module.bias is not None:
617
+ module.bias.data.zero_()
618
+ elif isinstance(module, nn.Embedding):
619
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
620
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
621
+ hidden_states = self.embed_tokens(input_ids)
622
+ model_dtype = next(self.model.parameters()).dtype
623
+ hidden_states = hidden_states.to(model_dtype)
624
+ input_embedding = hidden_states.clone()
625
+ if attention_mask is None:
626
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
627
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
628
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
629
+ causal_mask = causal_mask.triu(1)
630
+ hidden_states, aux_loss, past_kv = self.model(
631
+ hidden_states, input_embedding=input_embedding,
632
+ attention_mask=causal_mask, position_ids=position_ids,
633
+ use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
634
+ )
635
+ logits = self.lm_head(hidden_states)
636
+ loss = None
637
+ if labels is not None:
638
+ shift_logits = logits[..., :-1, :].contiguous()
639
+ shift_labels = labels[..., 1:].contiguous()
640
+ loss_fct = CrossEntropyLoss()
641
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
642
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
643
+ def get_num_params(self):
644
+ total = sum(p.numel() for p in self.parameters())
645
+ return {"total": total, "trainable": total}
646
+
647
+
648
+ # ---------------------------------------------------------------------------
649
+ # Dataset
650
+ # ---------------------------------------------------------------------------
651
+
652
+ class FineWebEduDataset(IterableDataset):
653
+ def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int, local_token_file=None):
654
+ self.tokenizer = tokenizer
655
+ self.seq_len = seq_len
656
+ self.subset = subset
657
+ self.rank = rank
658
+ self.world_size = world_size
659
+
660
+ # Local tokenized data - USE mmapped binary for speed
661
+ if local_token_file and os.path.exists(local_token_file):
662
+ import numpy as np
663
+ self.use_local = True
664
+ self.local_file = local_token_file
665
+ self.mm = np.memmap(local_token_file, dtype='<u4', mode='r')
666
+ self.num_tokens = len(self.mm)
667
+ self.num_samples = self.num_tokens // seq_len
668
+ log(f"Using pre-tokenized binary: {local_token_file} ({self.num_tokens:,} tokens)")
669
+ else:
670
+ self.use_local = False
671
+ log("WARNING: No pre-tokenized binary found. Using streaming tokenizer (SLOW).")
672
+ log("Run pretokenize_fineweb.py first for 50-100x speedup.")
673
+
674
+ def __iter__(self):
675
+ if self.use_local:
676
+ # Fast: use memory-mapped array
677
+ worker = get_worker_info()
678
+ num_workers = worker.num_workers if worker else 1
679
+ worker_id = worker.id if worker else 0
680
+
681
+ samples_per_worker = self.num_samples // (self.world_size * num_workers)
682
+ start_sample = (self.rank * num_workers + worker_id) * samples_per_worker
683
+ end_sample = start_sample + samples_per_worker
684
+
685
+ # Batch read tokens - convert to numpy array slice then tensor
686
+ import numpy as np
687
+ for i in range(start_sample, end_sample):
688
+ start_idx = i * self.seq_len
689
+ # Direct slice from memory-mapped array
690
+ tokens = self.mm[start_idx:start_idx + self.seq_len + 1].copy()
691
+
692
+ yield (
693
+ torch.from_numpy(tokens[:-1].astype('int64')),
694
+ torch.from_numpy(tokens[1:].astype('int64')),
695
+ )
696
+ else:
697
+ # Fallback to HuggingFace
698
+ worker = get_worker_info()
699
+ num_workers = worker.num_workers if worker else 1
700
+ worker_id = worker.id if worker else 0
701
+ total_shards = self.world_size * num_workers
702
+ shard_index = self.rank * num_workers + worker_id
703
+ ds = load_dataset(
704
+ "HuggingFaceFW/fineweb-edu",
705
+ name=self.subset,
706
+ split="train",
707
+ streaming=True,
708
+ ).shard(num_shards=total_shards, index=shard_index)
709
+ buf = []
710
+ for sample in ds:
711
+ buf.extend(self.tokenizer.encode(sample["text"]))
712
+ while len(buf) >= self.seq_len + 1:
713
+ chunk = buf[: self.seq_len + 1]
714
+ buf = buf[self.seq_len + 1 :]
715
+ yield (
716
+ torch.tensor(chunk[:-1], dtype=torch.long),
717
+ torch.tensor(chunk[1:], dtype=torch.long),
718
+ )
719
+
720
+
721
+ # ---------------------------------------------------------------------------
722
+ # LR schedule
723
+ # ---------------------------------------------------------------------------
724
+
725
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
726
+ if step < warmup:
727
+ return max_lr * step / warmup
728
+ if step >= total:
729
+ return min_lr
730
+ decay = (step - warmup) / (total - warmup)
731
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
732
+
733
+
734
+ # ---------------------------------------------------------------------------
735
+ # Checkpointing
736
+ # ---------------------------------------------------------------------------
737
+
738
+ def save_weights_only(model, step, epoch, ckpt_dir, ddp):
739
+ if ddp:
740
+ with FSDP.state_dict_type(
741
+ model,
742
+ StateDictType.FULL_STATE_DICT,
743
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
744
+ ):
745
+ model_state = model.state_dict()
746
+ else:
747
+ model_state = model.state_dict()
748
+ ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
749
+ tmp_path = ckpt_path + ".tmp"
750
+ torch.save(model_state, tmp_path)
751
+ os.replace(tmp_path, ckpt_path)
752
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
753
+ return ckpt_path, size_mb
754
+
755
+
756
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
757
+ if ddp:
758
+ with FSDP.state_dict_type(
759
+ model,
760
+ StateDictType.FULL_STATE_DICT,
761
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
762
+ ):
763
+ model_state = model.state_dict()
764
+ optim_state = FSDP.optim_state_dict(model, optimizer)
765
+ else:
766
+ model_state = model.state_dict()
767
+ optim_state = optimizer.state_dict()
768
+ if not master:
769
+ return None, 0
770
+ os.makedirs(ckpt_dir, exist_ok=True)
771
+ final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
772
+ tmp_path = final_path + ".tmp"
773
+ torch.save(
774
+ {
775
+ "step": step,
776
+ "epoch": epoch,
777
+ "model_state_dict": model_state,
778
+ "optimizer_state_dict": optim_state,
779
+ "cfg": cfg,
780
+ "vocab_size": vocab_size,
781
+ },
782
+ tmp_path,
783
+ )
784
+ os.replace(tmp_path, final_path)
785
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
786
+ return final_path, size_mb
787
+
788
+
789
+ def load_checkpoint(model, optimizer, path, ddp):
790
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
791
+ if ddp:
792
+ with FSDP.state_dict_type(
793
+ model,
794
+ StateDictType.FULL_STATE_DICT,
795
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
796
+ ):
797
+ model.load_state_dict(ckpt["model_state_dict"])
798
+ optim_state = FSDP.optim_state_dict_to_load(
799
+ model=model,
800
+ optim=optimizer,
801
+ optim_state_dict=ckpt["optimizer_state_dict"],
802
+ )
803
+ optimizer.load_state_dict(optim_state)
804
+ else:
805
+ model.load_state_dict(ckpt["model_state_dict"])
806
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
807
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0))
808
+
809
+
810
+ # ---------------------------------------------------------------------------
811
+ # Main
812
+ # ---------------------------------------------------------------------------
813
+
814
+ def main():
815
+ # ------------------------------------------------------------------
816
+ # Distributed init
817
+ # ------------------------------------------------------------------
818
+ ddp = int(os.environ.get("RANK", -1)) != -1
819
+ if ddp:
820
+ dist.init_process_group("nccl")
821
+ rank = int(os.environ["RANK"])
822
+ local_rank = int(os.environ["LOCAL_RANK"])
823
+ world_size = int(os.environ["WORLD_SIZE"])
824
+ device = f"cuda:{local_rank}"
825
+ torch.cuda.set_device(device)
826
+ else:
827
+ rank = local_rank = 0
828
+ world_size = 1
829
+ device = "cuda" if torch.cuda.is_available() else "cpu"
830
+ master = rank == 0
831
+ if master:
832
+ log(
833
+ f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
834
+ )
835
+
836
+ # ------------------------------------------------------------------
837
+ # Tokenizer
838
+ # ------------------------------------------------------------------
839
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
840
+ tokenizer.pad_token = tokenizer.eos_token
841
+ vocab_size = tokenizer.vocab_size
842
+ if master:
843
+ log(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
844
+
845
+ # ------------------------------------------------------------------
846
+ # Hyperparameters
847
+ # ------------------------------------------------------------------
848
+ seq_len = 2048
849
+ micro_batch = 32 # Increased — 96GB VRAM can handle this
850
+ target_tokens = 20_000_000_000
851
+ grad_accum = 2
852
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
853
+ total_steps = target_tokens // global_batch_tok
854
+ warmup_steps = 200
855
+ lr = 3e-4
856
+ wd = 0.1
857
+ log_every = 10
858
+ ckpt_every = 500
859
+ ckpt_dir = "checkpoints-dense"
860
+ dataset_subset = "sample-10BT"
861
+
862
+ if master:
863
+ log(
864
+ f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
865
+ f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
866
+ )
867
+ log(
868
+ f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | "
869
+ f"Engram: layers [1,4] | Context: 32k | "
870
+ f"Gradient checkpointing: enabled"
871
+ )
872
+
873
+ # ------------------------------------------------------------------
874
+ # Model
875
+ # ------------------------------------------------------------------
876
+ cfg = SpiderPortalConfig(
877
+ hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
878
+ num_key_value_heads=4, intermediate_size=8192,
879
+ num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
880
+ router_aux_loss_coef=0.05, max_loop_iters=4,
881
+ prelude_layers=2, coda_layers=2, lora_rank=128,
882
+ rope_theta=10000000.0,
883
+ rope_scaling=None,
884
+ max_position_embeddings=32768, sliding_window=4096,
885
+ tie_word_embeddings=True,
886
+ kv_lora_rank=128, q_lora_rank=256,
887
+ qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
888
+ engram_layers=[1, 4],
889
+ engram_ngram_orders=(2, 3),
890
+ engram_hash_heads=4,
891
+ engram_table_size=65537,
892
+ engram_dim=128,
893
+ )
894
+ cfg.vocab_size = vocab_size
895
+
896
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
897
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
898
+
899
+ model = SpiderPortalForConditionalGeneration(cfg)
900
+
901
+ if ddp:
902
+ mp_policy = MixedPrecision(
903
+ param_dtype=amp_dtype,
904
+ reduce_dtype=amp_dtype,
905
+ buffer_dtype=amp_dtype,
906
+ )
907
+ wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
908
+ model = FSDP(
909
+ model,
910
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
911
+ mixed_precision=mp_policy,
912
+ auto_wrap_policy=wrap_policy,
913
+ device_id=local_rank,
914
+ )
915
+ amp_ctx = nullcontext()
916
+ else:
917
+ model = model.to(device)
918
+ amp_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype) if torch.cuda.is_available() else nullcontext()
919
+ # Enable torch.compile for 20-30% speedup
920
+ try:
921
+ model = torch.compile(model, mode="reduce-overhead")
922
+ if master:
923
+ log("torch.compile: enabled (reduce-overhead)")
924
+ except Exception as e:
925
+ if master:
926
+ log(f"torch.compile failed ({e}), using eager mode")
927
+
928
+ if master:
929
+ n_params = sum(p.numel() for p in model.parameters())
930
+ engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
931
+ mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
932
+ embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
933
+ ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
934
+ other_params = n_params - engram_params - mla_params - embed_params - ffn_params
935
+ log(
936
+ f"Parameters: {n_params:,} (all active) | "
937
+ f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
938
+ f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
939
+ f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
940
+ )
941
+
942
+ # ------------------------------------------------------------------
943
+ # Optimizer — dual optimizer for Engram embeddings
944
+ # ------------------------------------------------------------------
945
+ engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed_tables' in n]
946
+ backbone_params = [p for n, p in model.named_parameters() if 'engram' not in n or 'embed_tables' not in n]
947
+
948
+ optimizer = torch.optim.AdamW(
949
+ backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
950
+ )
951
+ if engram_params_list:
952
+ engram_optimizer = torch.optim.Adam(
953
+ engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
954
+ )
955
+ else:
956
+ engram_optimizer = None
957
+
958
+ # ------------------------------------------------------------------
959
+ # Resume from latest checkpoint
960
+ # ------------------------------------------------------------------
961
+ start_step = 0
962
+ start_epoch = 1
963
+ best_loss = float("inf")
964
+ existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
965
+ if existing_ckpts:
966
+ latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
967
+ if master:
968
+ log(f"Resuming from checkpoint: {latest}")
969
+ start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
970
+ if master:
971
+ log(f"Resumed at step {start_step}, epoch {start_epoch}")
972
+
973
+ # ------------------------------------------------------------------
974
+ # Dataset + DataLoader
975
+ # ------------------------------------------------------------------
976
+ # Check for pre-tokenized binary file
977
+ local_token_file = os.environ.get("TOKEN_FILE", "data/fineweb-edu-sample-10BT.bin")
978
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file)
979
+ num_workers = 16 if dataset.use_local else 4
980
+ prefetch = 8 if dataset.use_local else 2
981
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch)
982
+ if master:
983
+ log(f"DataLoader: num_workers={num_workers}, prefetch={prefetch}, use_local={dataset.use_local}")
984
+
985
+ # ------------------------------------------------------------------
986
+ # Training loop
987
+ # ------------------------------------------------------------------
988
+ if master:
989
+ os.makedirs(ckpt_dir, exist_ok=True)
990
+
991
+ model.train()
992
+ data_iter = iter(loader)
993
+ t0 = time.perf_counter()
994
+ step = start_step
995
+ epoch = start_epoch
996
+ step_ckpt_files = []
997
+ tokens_in_epoch = 0
998
+ tokens_per_epoch = target_tokens
999
+
1000
+ while step < total_steps:
1001
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
1002
+ for g in optimizer.param_groups:
1003
+ g["lr"] = cur_lr
1004
+ if engram_optimizer:
1005
+ for g in engram_optimizer.param_groups:
1006
+ g["lr"] = cur_lr * 5
1007
+
1008
+ optimizer.zero_grad()
1009
+ if engram_optimizer:
1010
+ engram_optimizer.zero_grad()
1011
+ loss_accum = 0.0
1012
+
1013
+ for micro_step in range(grad_accum):
1014
+ try:
1015
+ x, y = next(data_iter)
1016
+ except StopIteration:
1017
+ # Dataset exhausted — reshuffle and restart
1018
+ if master:
1019
+ log(f"Dataset exhausted at step {step}, restarting DataLoader")
1020
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file)
1021
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch)
1022
+ data_iter = iter(loader)
1023
+ x, y = next(data_iter)
1024
+
1025
+ x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1026
+ y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1027
+
1028
+ sync = (
1029
+ nullcontext()
1030
+ if (not ddp or micro_step == grad_accum - 1)
1031
+ else model.no_sync()
1032
+ )
1033
+ with sync, amp_ctx:
1034
+ output = model(x)
1035
+ if isinstance(output, dict):
1036
+ logits = output["logits"]
1037
+ else:
1038
+ logits = output
1039
+ loss = nn.functional.cross_entropy(
1040
+ logits.view(-1, vocab_size), y.view(-1)
1041
+ )
1042
+ loss = loss / grad_accum
1043
+
1044
+ loss.backward()
1045
+ loss_accum += loss.item()
1046
+
1047
+ if ddp:
1048
+ grad_norm = model.clip_grad_norm_(1.0)
1049
+ else:
1050
+ grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
1051
+ optimizer.step()
1052
+ if engram_optimizer:
1053
+ engram_optimizer.step()
1054
+ step += 1
1055
+ tokens_in_epoch += global_batch_tok
1056
+
1057
+ if master and step % log_every == 0:
1058
+ dt = time.perf_counter() - t0
1059
+ tok_per_sec = global_batch_tok * log_every / dt
1060
+ tokens_seen = step * global_batch_tok
1061
+ log(
1062
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
1063
+ f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
1064
+ f"| {tok_per_sec / 1e6:.2f}M tok/s "
1065
+ f"| {tokens_seen / 1e9:.2f}B tokens seen"
1066
+ )
1067
+ t0 = time.perf_counter()
1068
+
1069
+ if step % ckpt_every == 0 and master:
1070
+ ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
1071
+ step_ckpt_files.append(ckpt_path)
1072
+ log(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1073
+
1074
+ if tokens_in_epoch >= tokens_per_epoch:
1075
+ epoch_loss = loss_accum
1076
+ if master:
1077
+ epoch_time = (time.perf_counter() - t0) / 60
1078
+ log(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
1079
+
1080
+ for f in step_ckpt_files:
1081
+ if os.path.exists(f):
1082
+ os.remove(f)
1083
+ log(f" Deleted step checkpoint: {os.path.basename(f)}")
1084
+ step_ckpt_files.clear()
1085
+
1086
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
1087
+ if ckpt_path:
1088
+ log(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1089
+
1090
+ if epoch_loss < best_loss:
1091
+ best_loss = epoch_loss
1092
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
1093
+ if ckpt_path:
1094
+ log(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1095
+
1096
+ epoch += 1
1097
+ tokens_in_epoch = 0
1098
+
1099
+ if step > start_step and master:
1100
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
1101
+ if ckpt_path:
1102
+ log(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1103
+
1104
+ if ddp:
1105
+ dist.barrier()
1106
+ dist.destroy_process_group()
1107
+
1108
+ if master:
1109
+ log("Training complete.")
1110
+
1111
+
1112
+ if __name__ == "__main__":
1113
+ main()