CLIWorks commited on
Commit
dce8328
·
verified ·
1 Parent(s): 63d4191

Delete mythos-fineweb-dense.py

Browse files
Files changed (1) hide show
  1. mythos-fineweb-dense.py +0 -1060
mythos-fineweb-dense.py DELETED
@@ -1,1060 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
-
5
- Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
6
- - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
7
- - Engram conditional memory at recurrent layers 1 and 4
8
- - Dense FFN (all params active, MoE conversion in Phase 2)
9
- - LTI Injection + ACT Halting + LoRA Adapter
10
- - 32k context (extendable to 256k at inference via YaRN)
11
-
12
- Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
13
-
14
- Single GPU:
15
- python mythos-fineweb-dense.py
16
-
17
- Multi-GPU:
18
- torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
19
- """
20
- import os
21
- import math
22
- import time
23
- import torch
24
- import torch.nn as nn
25
- import torch.nn.functional as F
26
- import torch.distributed as dist
27
- from loguru import logger
28
- from torch.distributed.fsdp import (
29
- FullyShardedDataParallel as FSDP,
30
- ShardingStrategy,
31
- MixedPrecision,
32
- FullStateDictConfig,
33
- StateDictType,
34
- )
35
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
36
- from torch.utils.data import IterableDataset, DataLoader, get_worker_info
37
- from contextlib import nullcontext
38
- from dataclasses import dataclass, field
39
- from typing import Optional, Tuple, Dict, List
40
- from torch.nn import CrossEntropyLoss
41
- from datasets import load_dataset
42
- from transformers import AutoTokenizer
43
-
44
-
45
- # ---------------------------------------------------------------------------
46
- # SpiderPortal Model Architecture (Dense + MLA + Engram)
47
- # ---------------------------------------------------------------------------
48
-
49
- @dataclass
50
- class SpiderPortalConfig:
51
- vocab_size: int = 50257
52
- hidden_size: int = 2048
53
- num_hidden_layers: int = 6
54
- num_attention_heads: int = 16
55
- num_key_value_heads: int = 4
56
- intermediate_size: int = 8192
57
- hidden_act: str = "silu"
58
- num_experts: int = 32
59
- num_experts_per_tok: int = 2
60
- num_shared_experts: int = 1
61
- router_aux_loss_coef: float = 0.05
62
- max_loop_iters: int = 4
63
- act_threshold: float = 0.5
64
- max_position_embeddings: int = 32768
65
- rope_theta: float = 10000000.0
66
- rope_scaling: dict = None
67
- sliding_window: int = 4096
68
- attention_dropout: float = 0.0
69
- rms_norm_eps: float = 1e-6
70
- initializer_range: float = 0.02
71
- use_cache: bool = True
72
- tie_word_embeddings: bool = True
73
- prelude_layers: int = 2
74
- coda_layers: int = 2
75
- lora_rank: int = 128
76
- loop_embed_dim: int = 128
77
- vision_hidden_size: int = 2048
78
- audio_hidden_size: int = 512
79
- vision_num_frames: int = 60
80
- vision_tokens_per_frame: int = 256
81
- vision_temporal_tokens: int = 64
82
- vision_temporal_layers: int = 2
83
- model_type: str = "spiderportal"
84
- torch_dtype: str = "bfloat16"
85
-
86
- # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
87
- kv_lora_rank: int = 128
88
- q_lora_rank: int = 256
89
- qk_rope_head_dim: int = 64
90
- qk_nope_head_dim: int = 64
91
- v_head_dim: int = 64
92
-
93
- # Engram parameters (DeepSeek conditional memory)
94
- engram_layers: List[int] = field(default_factory=lambda: [1, 4])
95
- engram_ngram_orders: Tuple[int, ...] = (2, 3)
96
- engram_hash_heads: int = 4
97
- engram_table_size: int = 65537 # prime number for hash table
98
- engram_conv_kernel: int = 4
99
- engram_conv_dilation: int = 3
100
- engram_dim: int = 128 # per-head embedding dimension
101
-
102
-
103
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
104
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
105
- angles = loop_t * freqs
106
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
107
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
108
- emb_full[:loop_dim] = emb
109
- return h + emb_full.unsqueeze(0).unsqueeze(0)
110
-
111
-
112
- class SpiderPortalRMSNorm(nn.Module):
113
- def __init__(self, hidden_size, eps=1e-6):
114
- super().__init__()
115
- self.weight = nn.Parameter(torch.ones(hidden_size))
116
- self.variance_epsilon = eps
117
- def forward(self, hidden_states):
118
- input_dtype = hidden_states.dtype
119
- hidden_states = hidden_states.to(torch.float32)
120
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
121
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
123
-
124
-
125
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
126
- dim = head_dim
127
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
128
- pos_freqs = torch.arange(0, dim, 2).float() / dim
129
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
130
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
131
- return orig_inv_freq * scale
132
-
133
-
134
- # ---------------------------------------------------------------------------
135
- # MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
136
- # ---------------------------------------------------------------------------
137
-
138
- class SpiderPortalMLA(nn.Module):
139
- """Multi-Latent Attention with compressed KV cache and sliding window.
140
-
141
- For hidden_size=2048, num_heads=16:
142
- - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
143
- - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
144
- - v_head_dim=64 → value projection
145
- - sliding_window=4096 → local attention range
146
- """
147
- def __init__(self, config):
148
- super().__init__()
149
- self.config = config
150
- self.hidden_size = config.hidden_size
151
- self.num_heads = config.num_attention_heads
152
- self.kv_lora_rank = config.kv_lora_rank
153
- self.q_lora_rank = config.q_lora_rank
154
- self.qk_rope_head_dim = config.qk_rope_head_dim
155
- self.qk_nope_head_dim = config.qk_nope_head_dim
156
- self.v_head_dim = config.v_head_dim
157
- self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
158
- self.sliding_window = getattr(config, 'sliding_window', None)
159
-
160
- # Q projection: optional low-rank → full Q
161
- if self.q_lora_rank > 0:
162
- self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
163
- self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
164
- self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
165
- else:
166
- self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
167
-
168
- # KV compression: hidden → kv_lora_rank (shared latent)
169
- self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
170
- self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
171
- # Decompress: kv_lora_rank → nope heads + v heads
172
- self.kv_b_proj = nn.Linear(
173
- self.kv_lora_rank,
174
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
175
- bias=False,
176
- )
177
- # Output projection
178
- self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
179
-
180
- # RoPE frequencies
181
- rope_scaling = getattr(config, 'rope_scaling', None)
182
- if rope_scaling and rope_scaling.get("type") == "yarn":
183
- factor = rope_scaling.get("factor", 1.0)
184
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
185
- inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
186
- else:
187
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
188
- self.register_buffer("inv_freq", inv_freq, persistent=False)
189
-
190
- def _rotate_half(self, x):
191
- x1 = x[..., :x.shape[-1] // 2]
192
- x2 = x[..., x.shape[-1] // 2:]
193
- return torch.cat((-x2, x1), dim=-1)
194
-
195
- def _apply_rotary(self, x, cos, sin):
196
- return (x * cos) + (self._rotate_half(x) * sin)
197
-
198
- def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
199
- """Create a sliding window causal mask."""
200
- if self.sliding_window is None or self.sliding_window <= 0:
201
- return None
202
- mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, device=device, dtype=dtype)
203
- for i in range(q_len):
204
- start = max(0, i - self.sliding_window + 1)
205
- mask[i, start:i + 1] = 0.0
206
- return mask
207
-
208
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
209
- bsz, q_len, _ = hidden_states.size()
210
-
211
- # Q projection
212
- if self.q_lora_rank > 0:
213
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
214
- else:
215
- q = self.q_proj(hidden_states)
216
- q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
217
- q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
218
-
219
- # KV: compress to latent, then decompress
220
- kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
221
- kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
222
- kv_latent_norm = self.kv_a_layernorm(kv_latent)
223
- kv_b_out = self.kv_b_proj(kv_latent_norm)
224
- k_nope, v = torch.split(kv_b_out, [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim], dim=-1)
225
-
226
- k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
227
- v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
228
- k_rope = k_rope.unsqueeze(1)
229
-
230
- # RoPE on Q and K rope parts
231
- if position_ids is None:
232
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
233
- max_pos = position_ids.max().item() + 1
234
- seq_len = max(max_pos, q_len)
235
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
236
- freqs = torch.outer(t, self.inv_freq)
237
- emb = torch.cat((freqs, freqs), dim=-1)
238
- cos, sin = emb.cos(), emb.sin()
239
- cos = cos[position_ids].unsqueeze(1)
240
- sin = sin[position_ids].unsqueeze(1)
241
-
242
- q_rope = self._apply_rotary(q_rope, cos, sin)
243
- k_rope = self._apply_rotary(k_rope, cos, sin)
244
-
245
- # Assemble full K
246
- k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
247
- k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
248
- q_full = torch.cat([q_nope, q_rope], dim=-1)
249
-
250
- # KV cache
251
- if past_key_value is not None:
252
- k_full = torch.cat([past_key_value[0], k_full], dim=2)
253
- v = torch.cat([past_key_value[1], v], dim=2)
254
- past_kv = (k_full, v) if use_cache else None
255
-
256
- # Build attention mask: user mask + sliding window
257
- final_mask = attention_mask
258
- if self.sliding_window is not None and self.sliding_window > 0:
259
- kv_len = k_full.size(2)
260
- sw_mask = self._make_sliding_window_mask(q_len, kv_len, hidden_states.device, hidden_states.dtype)
261
- if final_mask is not None:
262
- final_mask = final_mask + sw_mask
263
- else:
264
- final_mask = sw_mask
265
-
266
- # Attention with SDPA
267
- attn_output = F.scaled_dot_product_attention(
268
- q_full, k_full, v,
269
- attn_mask=final_mask,
270
- dropout_p=self.config.attention_dropout if self.training else 0.0,
271
- is_causal=(final_mask is None),
272
- )
273
- attn_output = attn_output.transpose(1, 2).contiguous()
274
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
275
- return self.o_proj(attn_output), past_kv
276
-
277
-
278
- # ---------------------------------------------------------------------------
279
- # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
280
- # ---------------------------------------------------------------------------
281
-
282
- def _tokenizer_compress(token_ids, vocab_size=50257):
283
- """Simulate NFKC + lowercase canonical ID projection."""
284
- return token_ids % (vocab_size * 77 // 100)
285
-
286
-
287
- class SpiderPortalEngram(nn.Module):
288
- """Conditional memory module via NN-gram lookup.
289
-
290
- Applied only at specific recurrent layers (config.engram_layers).
291
- """
292
- def __init__(self, config):
293
- super().__init__()
294
- self.config = config
295
- self.ngram_orders = config.engram_ngram_orders
296
- self.num_heads = config.engram_hash_heads
297
- self.table_size = config.engram_table_size
298
- self.d_mem = config.engram_dim
299
-
300
- self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem
301
-
302
- self.embed_tables = nn.ParameterDict()
303
- for n in self.ngram_orders:
304
- for h in range(self.num_heads):
305
- key = f"e_{n}_{h}"
306
- self.embed_tables[key] = nn.Parameter(
307
- torch.randn(self.table_size, self.d_mem) * 0.02
308
- )
309
-
310
- self.register_buffer("hash_seeds", torch.tensor([
311
- (h + 1) * 2654435761
312
- for _ in self.ngram_orders
313
- for h in range(self.num_heads)
314
- ], dtype=torch.int64))
315
-
316
- self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
317
- self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
318
-
319
- self.conv = nn.Conv1d(
320
- config.hidden_size, config.hidden_size,
321
- kernel_size=config.engram_conv_kernel,
322
- padding=config.engram_conv_kernel - 1,
323
- groups=config.hidden_size,
324
- )
325
- self.conv_dilation = config.engram_conv_dilation
326
-
327
- with torch.no_grad():
328
- self.conv.weight.zero_()
329
- if self.conv.bias is not None:
330
- self.conv.bias.zero_()
331
-
332
- self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
333
- self.k_norm = SpiderPortalRMSNorm(config.hidden_size)
334
-
335
- def _compute_indices(self, compressed_ids, n, head_idx):
336
- """Vectorized NN-gram hash indices for a single (order, head)."""
337
- bsz, seq_len = compressed_ids.shape
338
- pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
339
- padded = torch.cat([pad, compressed_ids], dim=1)
340
-
341
- indices_list = []
342
- for i in range(n):
343
- indices_list.append(padded[:, i:i + seq_len])
344
- ngrams = torch.stack(indices_list, dim=-1)
345
-
346
- seed = int(self.hash_seeds[head_idx].item())
347
- h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
348
- for i in range(n):
349
- h_val = h_val * 31 + ngrams[:, :, i]
350
- h_val = h_val % self.table_size
351
- h_val = (h_val * seed) % self.table_size
352
- return h_val
353
-
354
- def _retrieve(self, token_ids):
355
- """Retrieve memory vectors for a batch of token sequences."""
356
- bsz, seq_len = token_ids.shape
357
- compressed = _tokenizer_compress(token_ids)
358
-
359
- all_parts = []
360
- head_counter = 0
361
- for n in self.ngram_orders:
362
- for h in range(self.num_heads):
363
- key = f"e_{n}_{h}"
364
- table = self.embed_tables[key]
365
- indices = self._compute_indices(compressed, n, head_counter)
366
- emb = table[indices.view(-1)]
367
- all_parts.append(emb.view(bsz, seq_len, self.d_mem))
368
- head_counter += 1
369
-
370
- memory = torch.cat(all_parts, dim=-1)
371
- return memory
372
-
373
- def forward(self, hidden_states, token_ids):
374
- mem = self._retrieve(token_ids)
375
-
376
- q = hidden_states
377
- k = self.W_k(mem)
378
- v = self.W_v(mem)
379
-
380
- q_norm = self.q_norm(q)
381
- k_norm = self.k_norm(k)
382
-
383
- alpha = torch.sigmoid(
384
- (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
385
- )
386
-
387
- v_gated = alpha * v
388
-
389
- v_gated_t = v_gated.transpose(1, 2)
390
- conv_out = self.conv(v_gated_t)
391
- conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
392
- conv_out = conv_out.transpose(1, 2)
393
-
394
- y = F.silu(conv_out) + v_gated
395
-
396
- return y
397
-
398
-
399
- # ---------------------------------------------------------------------------
400
- # FFN Expert (dense)
401
- # ---------------------------------------------------------------------------
402
-
403
- class SpiderPortalExpert(nn.Module):
404
- def __init__(self, config, intermediate_size=None):
405
- super().__init__()
406
- inter_size = intermediate_size or config.intermediate_size
407
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
408
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
409
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
410
- self.act_fn = nn.SiLU()
411
- def forward(self, hidden_states):
412
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
413
-
414
-
415
- # ---------------------------------------------------------------------------
416
- # Prelude/Coda Dense Layer (uses MLA)
417
- # ---------------------------------------------------------------------------
418
-
419
- class SpiderPortalDenseLayer(nn.Module):
420
- """Prelude/coda dense layer with MLA attention."""
421
- def __init__(self, config):
422
- super().__init__()
423
- self.self_attn = SpiderPortalMLA(config)
424
- dense_intermediate = config.hidden_size * 4 // 3
425
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
426
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
427
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
429
- attn_input = self.input_layernorm(hidden_states)
430
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
431
- hidden_states = hidden_states + attn_output
432
- ffn_input = self.post_attention_layernorm(hidden_states)
433
- ffn_output = self.ffn(ffn_input)
434
- hidden_states = hidden_states + ffn_output
435
- return hidden_states, past_kv
436
-
437
-
438
- # ---------------------------------------------------------------------------
439
- # Recurrent Dense Layer (uses MLA + optional Engram)
440
- # ---------------------------------------------------------------------------
441
-
442
- class SpiderPortalRecurrentDenseLayer(nn.Module):
443
- """Recurrent layer with MLA attention and optional Engram memory."""
444
- def __init__(self, config, layer_idx, has_engram=False):
445
- super().__init__()
446
- self.layer_idx = layer_idx
447
- self.has_engram = has_engram
448
- self.self_attn = SpiderPortalMLA(config)
449
- if has_engram:
450
- self.engram = SpiderPortalEngram(config)
451
- self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
452
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
453
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
454
- self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None
455
- def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
456
- attn_input = self.input_layernorm(hidden_states)
457
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
458
- hidden_states = hidden_states + attn_output
459
-
460
- if self.has_engram and token_ids is not None:
461
- engram_out = self.engram(hidden_states, token_ids)
462
- hidden_states = hidden_states + engram_out
463
- if self.post_engram_layernorm is not None:
464
- hidden_states = self.post_engram_layernorm(hidden_states)
465
-
466
- ffn_input = self.post_attention_layernorm(hidden_states)
467
- ffn_output = self.ffn(ffn_input)
468
- hidden_states = hidden_states + ffn_output
469
- return hidden_states, 0.0, past_kv
470
-
471
-
472
- # ---------------------------------------------------------------------------
473
- # LTI Injection, ACT Halting, LoRA Adapter
474
- # ---------------------------------------------------------------------------
475
-
476
- class LTIInjection(nn.Module):
477
- def __init__(self, config):
478
- super().__init__()
479
- self.hidden_size = config.hidden_size
480
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
481
- self.delta_t = nn.Parameter(torch.tensor(1.0))
482
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
483
- with torch.no_grad():
484
- self.B.weight.data.normal_(mean=0.0, std=0.01)
485
- def get_A(self):
486
- return -torch.exp(self.log_A)
487
- def forward(self, h_t, e):
488
- A = self.get_A()
489
- return A * h_t + self.B(e)
490
-
491
-
492
- class ACTHalting(nn.Module):
493
- def __init__(self, config):
494
- super().__init__()
495
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
496
- self.threshold = config.act_threshold
497
- def forward(self, hidden_states):
498
- return torch.sigmoid(self.halt_predictor(hidden_states))
499
-
500
-
501
- class LoRAAdapter(nn.Module):
502
- def __init__(self, config):
503
- super().__init__()
504
- rank = config.lora_rank
505
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
506
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
507
- self.scale = nn.Embedding(config.max_loop_iters, rank)
508
- with torch.no_grad():
509
- self.scale.weight.data.zero_()
510
- self.down.weight.data.normal_(mean=0.0, std=0.001)
511
- def forward(self, x, loop_t):
512
- max_t = self.scale.num_embeddings - 1
513
- t_idx = min(loop_t, max_t)
514
- s = self.scale(torch.tensor(t_idx, device=x.device))
515
- down = self.down(x) * s
516
- return down @ self.B
517
-
518
-
519
- def checkpoint(func, *args, **kwargs):
520
- """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
521
- if torch.is_grad_enabled():
522
- return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
523
- return func(*args, **kwargs)
524
-
525
-
526
- # ---------------------------------------------------------------------------
527
- # Full Model
528
- # ---------------------------------------------------------------------------
529
-
530
- class SpiderPortalDenseModel(nn.Module):
531
- """Full RDT model with MLA attention + Engram memory at layers 1,4.
532
-
533
- Architecture:
534
- 2x Prelude (MLA + dense FFN)
535
- 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing
536
- 2x Coda (MLA + dense FFN)
537
- LTI Injection + ACT Halting + LoRA Adapter
538
- """
539
- def __init__(self, config):
540
- super().__init__()
541
- self.config = config
542
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
543
- self.recurrent_layers = nn.ModuleList([
544
- SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
545
- for i in range(config.num_hidden_layers)
546
- ])
547
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
548
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
549
- self.injection = LTIInjection(config)
550
- self.act_halting = ACTHalting(config)
551
- self.lora_adapter = LoRAAdapter(config)
552
- self.loop_embed_dim = config.loop_embed_dim
553
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
554
- n_loops = n_loops or self.config.max_loop_iters
555
- input_embedding = input_embedding if input_embedding is not None else hidden_states
556
- for layer in self.prelude_layers:
557
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
558
- e = hidden_states.clone()
559
- B, T_seq, D = hidden_states.shape
560
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
561
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
562
- h_out = torch.zeros_like(hidden_states)
563
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
564
- for t in range(n_loops):
565
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
566
- if t > 0:
567
- injection = self.injection(hidden_states, input_embedding)
568
- hidden_states = hidden_states + injection
569
- new_past_key_values = []
570
- for i, layer in enumerate(self.recurrent_layers):
571
- hidden_states, aux_loss, past_kv = checkpoint(
572
- layer, hidden_states,
573
- token_ids=token_ids,
574
- attention_mask=attention_mask,
575
- position_ids=position_ids,
576
- past_key_value=past_key_values[i] if t == 0 else None,
577
- use_cache=use_cache
578
- )
579
- new_past_key_values.append(past_kv)
580
- lora_delta = self.lora_adapter(hidden_states, t)
581
- hidden_states = hidden_states + lora_delta
582
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
583
- still_running = ~halted
584
- remainder = (1.0 - cumulative_p).clamp(min=0)
585
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
586
- weight = weight * still_running.to(hidden_states.dtype)
587
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
588
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
589
- halted = halted | (cumulative_p >= self.config.act_threshold)
590
- if halted.all() and not self.training:
591
- break
592
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
593
- hidden_states = h_out + never_halted * hidden_states
594
- for layer in self.coda_layers:
595
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
596
- hidden_states = self.norm(hidden_states)
597
- return hidden_states, 0.0, new_past_key_values
598
-
599
-
600
- class SpiderPortalForConditionalGeneration(nn.Module):
601
- def __init__(self, config):
602
- super().__init__()
603
- self.config = config
604
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
605
- self.model = SpiderPortalDenseModel(config)
606
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
607
- if config.tie_word_embeddings:
608
- self.lm_head.weight = self.embed_tokens.weight
609
- self.apply(self._init_weights)
610
- def _init_weights(self, module):
611
- if isinstance(module, nn.Linear):
612
- if hasattr(self, 'model') and module is self.model.injection.B:
613
- return
614
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
615
- if module.bias is not None:
616
- module.bias.data.zero_()
617
- elif isinstance(module, nn.Embedding):
618
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
619
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
620
- hidden_states = self.embed_tokens(input_ids)
621
- model_dtype = next(self.model.parameters()).dtype
622
- hidden_states = hidden_states.to(model_dtype)
623
- input_embedding = hidden_states.clone()
624
- if attention_mask is None:
625
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
626
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
627
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
628
- causal_mask = causal_mask.triu(1)
629
- hidden_states, aux_loss, past_kv = self.model(
630
- hidden_states, input_embedding=input_embedding,
631
- attention_mask=causal_mask, position_ids=position_ids,
632
- use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
633
- )
634
- logits = self.lm_head(hidden_states)
635
- loss = None
636
- if labels is not None:
637
- shift_logits = logits[..., :-1, :].contiguous()
638
- shift_labels = labels[..., 1:].contiguous()
639
- loss_fct = CrossEntropyLoss()
640
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
641
- return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
642
- def get_num_params(self):
643
- total = sum(p.numel() for p in self.parameters())
644
- return {"total": total, "trainable": total}
645
-
646
-
647
- # ---------------------------------------------------------------------------
648
- # Dataset
649
- # ---------------------------------------------------------------------------
650
-
651
- class FineWebEduDataset(IterableDataset):
652
- def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
653
- self.tokenizer = tokenizer
654
- self.seq_len = seq_len
655
- self.subset = subset
656
- self.rank = rank
657
- self.world_size = world_size
658
- def __iter__(self):
659
- worker = get_worker_info()
660
- num_workers = worker.num_workers if worker else 1
661
- worker_id = worker.id if worker else 0
662
- total_shards = self.world_size * num_workers
663
- shard_index = self.rank * num_workers + worker_id
664
- ds = load_dataset(
665
- "HuggingFaceFW/fineweb-edu",
666
- name=self.subset,
667
- split="train",
668
- streaming=True,
669
- ).shard(num_shards=total_shards, index=shard_index)
670
- buf = []
671
- for sample in ds:
672
- buf.extend(self.tokenizer.encode(sample["text"]))
673
- while len(buf) >= self.seq_len + 1:
674
- chunk = buf[: self.seq_len + 1]
675
- buf = buf[self.seq_len + 1 :]
676
- yield (
677
- torch.tensor(chunk[:-1], dtype=torch.long),
678
- torch.tensor(chunk[1:], dtype=torch.long),
679
- )
680
-
681
-
682
- # ---------------------------------------------------------------------------
683
- # LR schedule
684
- # ---------------------------------------------------------------------------
685
-
686
- def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
687
- if step < warmup:
688
- return max_lr * step / warmup
689
- if step >= total:
690
- return min_lr
691
- decay = (step - warmup) / (total - warmup)
692
- return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
693
-
694
-
695
- # ---------------------------------------------------------------------------
696
- # Checkpointing
697
- # ---------------------------------------------------------------------------
698
-
699
- def save_weights_only(model, step, epoch, ckpt_dir, ddp):
700
- if ddp:
701
- with FSDP.state_dict_type(
702
- model,
703
- StateDictType.FULL_STATE_DICT,
704
- FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
705
- ):
706
- model_state = model.state_dict()
707
- else:
708
- model_state = model.state_dict()
709
- ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
710
- tmp_path = ckpt_path + ".tmp"
711
- torch.save(model_state, tmp_path)
712
- os.replace(tmp_path, ckpt_path)
713
- size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
714
- return ckpt_path, size_mb
715
-
716
-
717
- def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
718
- if ddp:
719
- with FSDP.state_dict_type(
720
- model,
721
- StateDictType.FULL_STATE_DICT,
722
- FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
723
- ):
724
- model_state = model.state_dict()
725
- optim_state = FSDP.optim_state_dict(model, optimizer)
726
- else:
727
- model_state = model.state_dict()
728
- optim_state = optimizer.state_dict()
729
- if not master:
730
- return None, 0
731
- os.makedirs(ckpt_dir, exist_ok=True)
732
- final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
733
- tmp_path = final_path + ".tmp"
734
- torch.save(
735
- {
736
- "step": step,
737
- "epoch": epoch,
738
- "model_state_dict": model_state,
739
- "optimizer_state_dict": optim_state,
740
- "cfg": cfg,
741
- "vocab_size": vocab_size,
742
- },
743
- tmp_path,
744
- )
745
- os.replace(tmp_path, final_path)
746
- size_mb = os.path.getsize(final_path) / (1024 * 1024)
747
- return final_path, size_mb
748
-
749
-
750
- def load_checkpoint(model, optimizer, path, ddp):
751
- ckpt = torch.load(path, map_location="cpu", weights_only=False)
752
- if ddp:
753
- with FSDP.state_dict_type(
754
- model,
755
- StateDictType.FULL_STATE_DICT,
756
- FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
757
- ):
758
- model.load_state_dict(ckpt["model_state_dict"])
759
- optim_state = FSDP.optim_state_dict_to_load(
760
- model=model,
761
- optim=optimizer,
762
- optim_state_dict=ckpt["optimizer_state_dict"],
763
- )
764
- optimizer.load_state_dict(optim_state)
765
- else:
766
- model.load_state_dict(ckpt["model_state_dict"])
767
- optimizer.load_state_dict(ckpt["optimizer_state_dict"])
768
- return int(ckpt["step"]), int(ckpt.get("epoch", 0))
769
-
770
-
771
- # ---------------------------------------------------------------------------
772
- # Main
773
- # ---------------------------------------------------------------------------
774
-
775
- def main():
776
- # ------------------------------------------------------------------
777
- # Distributed init
778
- # ------------------------------------------------------------------
779
- ddp = int(os.environ.get("RANK", -1)) != -1
780
- if ddp:
781
- dist.init_process_group("nccl")
782
- rank = int(os.environ["RANK"])
783
- local_rank = int(os.environ["LOCAL_RANK"])
784
- world_size = int(os.environ["WORLD_SIZE"])
785
- device = f"cuda:{local_rank}"
786
- torch.cuda.set_device(device)
787
- else:
788
- rank = local_rank = 0
789
- world_size = 1
790
- device = "cuda" if torch.cuda.is_available() else "cpu"
791
- master = rank == 0
792
- if master:
793
- logger.info(
794
- f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
795
- )
796
-
797
- # ------------------------------------------------------------------
798
- # Tokenizer
799
- # ------------------------------------------------------------------
800
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
801
- tokenizer.pad_token = tokenizer.eos_token
802
- vocab_size = tokenizer.vocab_size
803
- if master:
804
- logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
805
-
806
- # ------------------------------------------------------------------
807
- # Hyperparameters
808
- # ------------------------------------------------------------------
809
- seq_len = 2048
810
- micro_batch = 16
811
- target_tokens = 20_000_000_000
812
- grad_accum = 2
813
- global_batch_tok = world_size * micro_batch * grad_accum * seq_len
814
- total_steps = target_tokens // global_batch_tok
815
- warmup_steps = 200
816
- lr = 3e-4
817
- wd = 0.1
818
- log_every = 10
819
- ckpt_every = 500
820
- ckpt_dir = "checkpoints-dense"
821
- dataset_subset = "sample-10BT"
822
-
823
- if master:
824
- logger.info(
825
- f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
826
- f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
827
- )
828
- logger.info(
829
- f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | "
830
- f"Engram: layers [1,4] | Context: 32k | "
831
- f"Gradient checkpointing: enabled"
832
- )
833
-
834
- # ------------------------------------------------------------------
835
- # Model
836
- # ------------------------------------------------------------------
837
- cfg = SpiderPortalConfig(
838
- hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
839
- num_key_value_heads=4, intermediate_size=8192,
840
- num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
841
- router_aux_loss_coef=0.05, max_loop_iters=4,
842
- prelude_layers=2, coda_layers=2, lora_rank=128,
843
- rope_theta=10000000.0,
844
- rope_scaling=None,
845
- max_position_embeddings=32768, sliding_window=4096,
846
- tie_word_embeddings=True,
847
- kv_lora_rank=128, q_lora_rank=256,
848
- qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
849
- engram_layers=[1, 4],
850
- engram_ngram_orders=(2, 3),
851
- engram_hash_heads=4,
852
- engram_table_size=65537,
853
- engram_dim=128,
854
- )
855
- cfg.vocab_size = vocab_size
856
-
857
- bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
858
- amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
859
-
860
- model = SpiderPortalForConditionalGeneration(cfg)
861
-
862
- if ddp:
863
- mp_policy = MixedPrecision(
864
- param_dtype=amp_dtype,
865
- reduce_dtype=amp_dtype,
866
- buffer_dtype=amp_dtype,
867
- )
868
- wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
869
- model = FSDP(
870
- model,
871
- sharding_strategy=ShardingStrategy.FULL_SHARD,
872
- mixed_precision=mp_policy,
873
- auto_wrap_policy=wrap_policy,
874
- device_id=local_rank,
875
- )
876
- else:
877
- model = model.to(device)
878
- amp_ctx = (
879
- torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
880
- if "cuda" in device
881
- else nullcontext()
882
- )
883
-
884
- amp_ctx = nullcontext() if ddp else amp_ctx
885
-
886
- if master:
887
- n_params = sum(p.numel() for p in model.parameters())
888
- engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
889
- mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
890
- embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
891
- ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
892
- other_params = n_params - engram_params - mla_params - embed_params - ffn_params
893
- logger.info(
894
- f"Parameters: {n_params:,} (all active) | "
895
- f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
896
- f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
897
- f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
898
- )
899
-
900
- # ------------------------------------------------------------------
901
- # Optimizer — dual optimizer for Engram embeddings
902
- # ------------------------------------------------------------------
903
- engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed_tables' in n]
904
- backbone_params = [p for n, p in model.named_parameters() if 'engram' not in n or 'embed_tables' not in n]
905
-
906
- optimizer = torch.optim.AdamW(
907
- backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
908
- )
909
- if engram_params_list:
910
- engram_optimizer = torch.optim.Adam(
911
- engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
912
- )
913
- else:
914
- engram_optimizer = None
915
-
916
- # ------------------------------------------------------------------
917
- # Resume from latest checkpoint
918
- # ------------------------------------------------------------------
919
- start_step = 0
920
- start_epoch = 1
921
- best_loss = float("inf")
922
- existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
923
- if existing_ckpts:
924
- latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
925
- if master:
926
- logger.info(f"Resuming from checkpoint: {latest}")
927
- start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
928
- if master:
929
- logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
930
-
931
- # ------------------------------------------------------------------
932
- # Dataset + DataLoader
933
- # ------------------------------------------------------------------
934
- dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
935
- loader = DataLoader(dataset, batch_size=micro_batch, num_workers=8, pin_memory=True, prefetch_factor=2)
936
-
937
- # ------------------------------------------------------------------
938
- # Training loop
939
- # ------------------------------------------------------------------
940
- if master:
941
- os.makedirs(ckpt_dir, exist_ok=True)
942
-
943
- model.train()
944
- data_iter = iter(loader)
945
- t0 = time.perf_counter()
946
- step = start_step
947
- epoch = start_epoch
948
- step_ckpt_files = []
949
- tokens_in_epoch = 0
950
- tokens_per_epoch = target_tokens
951
-
952
- while step < total_steps:
953
- cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
954
- for g in optimizer.param_groups:
955
- g["lr"] = cur_lr
956
- if engram_optimizer:
957
- for g in engram_optimizer.param_groups:
958
- g["lr"] = cur_lr * 5
959
-
960
- optimizer.zero_grad()
961
- if engram_optimizer:
962
- engram_optimizer.zero_grad()
963
- loss_accum = 0.0
964
-
965
- for micro_step in range(grad_accum):
966
- try:
967
- x, y = next(data_iter)
968
- except StopIteration:
969
- data_iter = iter(loader)
970
- x, y = next(data_iter)
971
-
972
- x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
973
- y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
974
-
975
- sync = (
976
- nullcontext()
977
- if (not ddp or micro_step == grad_accum - 1)
978
- else model.no_sync()
979
- )
980
- with sync, amp_ctx:
981
- output = model(x)
982
- if isinstance(output, dict):
983
- logits = output["logits"]
984
- else:
985
- logits = output
986
- loss = nn.functional.cross_entropy(
987
- logits.view(-1, vocab_size), y.view(-1)
988
- )
989
- loss = loss / grad_accum
990
-
991
- loss.backward()
992
- loss_accum += loss.item()
993
-
994
- if ddp:
995
- grad_norm = model.clip_grad_norm_(1.0)
996
- else:
997
- grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
998
- optimizer.step()
999
- if engram_optimizer:
1000
- engram_optimizer.step()
1001
- step += 1
1002
- tokens_in_epoch += global_batch_tok
1003
-
1004
- if master and step % log_every == 0:
1005
- dt = time.perf_counter() - t0
1006
- tok_per_sec = global_batch_tok * log_every / dt
1007
- tokens_seen = step * global_batch_tok
1008
- logger.info(
1009
- f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
1010
- f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
1011
- f"| {tok_per_sec / 1e6:.2f}M tok/s "
1012
- f"| {tokens_seen / 1e9:.2f}B tokens seen"
1013
- )
1014
- t0 = time.perf_counter()
1015
-
1016
- if step % ckpt_every == 0 and master:
1017
- ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
1018
- step_ckpt_files.append(ckpt_path)
1019
- logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1020
-
1021
- if tokens_in_epoch >= tokens_per_epoch:
1022
- epoch_loss = loss_accum
1023
- if master:
1024
- epoch_time = (time.perf_counter() - t0) / 60
1025
- logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
1026
-
1027
- for f in step_ckpt_files:
1028
- if os.path.exists(f):
1029
- os.remove(f)
1030
- logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
1031
- step_ckpt_files.clear()
1032
-
1033
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
1034
- if ckpt_path:
1035
- logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1036
-
1037
- if epoch_loss < best_loss:
1038
- best_loss = epoch_loss
1039
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
1040
- if ckpt_path:
1041
- logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1042
-
1043
- epoch += 1
1044
- tokens_in_epoch = 0
1045
-
1046
- if step > start_step and master:
1047
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
1048
- if ckpt_path:
1049
- logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1050
-
1051
- if ddp:
1052
- dist.barrier()
1053
- dist.destroy_process_group()
1054
-
1055
- if master:
1056
- logger.success("Training complete.")
1057
-
1058
-
1059
- if __name__ == "__main__":
1060
- main()