| |
| """ |
| SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW. |
| Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with: |
| - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window |
| - Engram conditional memory at recurrent layers 1 and 4 |
| - Dense FFN (all params active, MoE conversion in Phase 2) |
| - LTI Injection + ACT Halting + LoRA Adapter |
| - 32k context (extendable to 256k at inference via YaRN) |
| Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing |
| Single GPU: |
| python mythos-fineweb-dense.py |
| Multi-GPU: |
| torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py |
| """ |
| import os |
| import math |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| import sys |
|
|
| |
def log(msg, level="INFO"):
    """Print one timestamped log line to stdout.

    The line has the form ``YYYY-mm-dd HH:MM:SS | LEVEL | msg``. It is flushed
    immediately so that output from long-running jobs appears without delay.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"{stamp} | {level} | {msg}", flush=True)
|
|
| |
# Tune the CUDA caching allocator: cap block splitting at 512 MB and enable
# expandable segments. The variable only has an effect if it is set before the
# allocator is initialised.
# NOTE(review): this unconditionally overwrites any value the user already set.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512,expandable_segments:True")
| from torch.distributed.fsdp import ( |
| FullyShardedDataParallel as FSDP, |
| ShardingStrategy, |
| MixedPrecision, |
| FullStateDictConfig, |
| StateDictType, |
| ) |
| from torch.distributed.fsdp.wrap import ModuleWrapPolicy |
| from torch.utils.data import IterableDataset, DataLoader, get_worker_info |
| from contextlib import nullcontext |
| from dataclasses import dataclass, field |
| from typing import Optional, Tuple, Dict, List |
| from torch.nn import CrossEntropyLoss |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class SpiderPortalConfig:
    """Hyper-parameters for the SpiderPortal RDT model.

    Field order is part of the interface, because instances can be built
    positionally. New fields must be appended, never inserted.
    """

    # --- core transformer dimensions ---
    vocab_size: int = 50257
    hidden_size: int = 2048
    num_hidden_layers: int = 6  # number of recurrent layers in the looped block
    num_attention_heads: int = 16
    num_key_value_heads: int = 4  # not read by the MLA attention in this file
    intermediate_size: int = 8192  # FFN width of the recurrent layers
    hidden_act: str = "silu"

    # --- MoE settings (the dense model in this file does not read them) ---
    num_experts: int = 32
    num_experts_per_tok: int = 2
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05

    # --- recurrence / adaptive computation ---
    max_loop_iters: int = 4  # default number of passes over the recurrent stack
    act_threshold: float = 0.5  # cumulative halting probability at which a token halts

    # --- positions / attention ---
    max_position_embeddings: int = 32768
    rope_theta: float = 10000000.0
    # e.g. {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 32768}
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096
    attention_dropout: float = 0.0

    # --- misc model settings ---
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    use_cache: bool = True
    tie_word_embeddings: bool = True
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 128
    loop_embed_dim: int = 128

    # --- multimodal settings (not used by the text-only model in this file) ---
    vision_hidden_size: int = 2048
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"

    # --- Multi-Latent Attention ---
    kv_lora_rank: int = 128  # width of the shared KV latent
    q_lora_rank: int = 256  # query bottleneck width; 0 disables the bottleneck
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 64
    v_head_dim: int = 64

    # --- Engram conditional memory ---
    engram_layers: List[int] = field(default_factory=lambda: [1, 4])  # recurrent layer indices
    engram_ngram_orders: Tuple[int, ...] = (2, 3)
    engram_hash_heads: int = 4
    engram_table_size: int = 65537  # rows per hashed embedding table
    engram_conv_kernel: int = 4
    engram_conv_dilation: int = 3
    engram_dim: int = 128  # width of each retrieved memory slice
|
|
|
|
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of the recurrence step ``loop_t`` to ``h``.

    Only the first ``loop_dim`` channels of the last dimension receive the
    encoding (sines followed by cosines). The remaining channels are unchanged.

    Args:
        h: tensor whose last dimension is the hidden size. The offset is
           broadcast as ``(1, 1, D)``, so ``h`` is expected to be
           ``(batch, seq, D)``.
        loop_t: loop index (int or scalar tensor).
        loop_dim: number of leading channels that receive the encoding.
        theta: frequency base.

    Returns:
        ``h + encoding``, in ``h.dtype``.
    """
    # Frequencies, angles and sin/cos are computed in float32 and cast only at
    # the end. Doing this math directly in bf16/fp16 loses most of its precision.
    steps = torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32)
    freqs = 1.0 / (theta ** (steps / loop_dim))
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=torch.float32)
    emb_full[:loop_dim] = emb
    return h + emb_full.to(h.dtype).unsqueeze(0).unsqueeze(0)
|
|
|
|
class SpiderPortalRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistics are computed in float32. The result is cast back to the
    input dtype before the per-channel ``weight`` is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.float()
        inv_rms = torch.rsqrt(x32.square().mean(dim=-1, keepdim=True) + self.variance_epsilon)
        normed = (x32 * inv_rms).to(orig_dtype)
        return self.weight.to(orig_dtype) * normed
|
|
|
|
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return YaRN ("NTK-by-parts") RoPE inverse frequencies of length ``ceil(head_dim / 2)``.

    Each frequency pair ``i`` is classified by how many full rotations its
    wavelength completes over the original context ``orig_max``:

    - more than ``beta_fast`` rotations: the original frequency is kept
      (extrapolation);
    - fewer than ``beta_slow`` rotations: the frequency is divided by
      ``factor`` (position interpolation);
    - in between: a linear ramp blends the two.

    With ``factor == 1`` this returns the plain RoPE inverse frequencies.
    YaRN's attention-temperature (mscale) term is NOT applied here.
    """
    dim = head_dim
    exponents = torch.arange(0, dim, 2).float() / dim
    pos_freqs = rope_theta ** exponents
    inv_freq_extrapolation = 1.0 / pos_freqs  # original RoPE
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)  # interpolated by `factor`

    def _correction_dim(num_rotations):
        # Dimension index whose wavelength completes `num_rotations` turns over orig_max tokens.
        return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_theta))

    low = max(math.floor(_correction_dim(beta_fast)), 0)
    high = min(math.ceil(_correction_dim(beta_slow)), dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp

    ramp = ((torch.arange(exponents.numel()).float() - low) / (high - low)).clamp(0.0, 1.0)
    extrapolation_weight = 1.0 - ramp
    return inv_freq_interpolation * (1.0 - extrapolation_weight) + inv_freq_extrapolation * extrapolation_weight
|
|
|
|
| |
| |
| |
|
|
class SpiderPortalMLA(nn.Module):
    """Multi-Latent Attention with a low-rank KV path and sliding-window masking.

    Keys and values are expanded from a ``kv_lora_rank`` latent. A single
    ``qk_rope_head_dim`` rotary key slice is shared by all heads. Queries can
    optionally go through a ``q_lora_rank`` bottleneck.

    For hidden_size=2048, num_heads=16:
      - qk_nope_head_dim=64, qk_rope_head_dim=64 -> q/k head_dim=128
      - kv_lora_rank=128 (+64 rope dims): the KV latent is 2048/192 ~ 10.7x
        narrower than the hidden state
      - v_head_dim=64 -> per-head value width
      - sliding_window=4096 -> each query sees at most the 4096 most recent keys

    NOTE: the ``past_kv`` returned by ``forward`` holds the decompressed
    per-head keys and values, not the latent. The KV cache is therefore not
    compressed in this implementation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim
        self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.sliding_window = getattr(config, "sliding_window", None)

        # Query path: optional low-rank bottleneck (hidden -> q_lora_rank -> heads).
        if self.q_lora_rank > 0:
            self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
            self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # KV path: one projection yields the latent plus the shared rope key slice.
        self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
        # Expands the normalised latent into per-head no-rope keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)

        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        """Map the halves (x1, x2) of the last dim to (-x2, x1)."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def _rope_cos_sin(self, position_ids):
        """Return float32 ``(cos, sin)`` of shape ``(bsz, 1, seq, qk_rope_head_dim)``.

        Angles are formed in float32 even if ``inv_freq`` has been cast to a
        low-precision dtype (e.g. by FSDP buffer mixed precision). Integer
        positions are therefore exact, and no host sync (``.item()``) is needed.
        """
        angles = position_ids.to(torch.float32).unsqueeze(-1) * self.inv_freq.to(torch.float32)
        emb = torch.cat((angles, angles), dim=-1)
        return emb.cos().unsqueeze(1), emb.sin().unsqueeze(1)

    def _local_causal_mask(self, q_len, kv_len, window, device, dtype):
        """Build an additive (q_len, kv_len) causal mask, optionally limited to ``window`` keys.

        The queries are taken to be the last ``q_len`` of the ``kv_len`` key
        positions, so the mask is bottom-right aligned (correct with cached
        keys). Allowed entries are 0; blocked entries are ``finfo(dtype).min``.
        """
        q_pos = torch.arange(kv_len - q_len, kv_len, device=device).unsqueeze(1)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        allowed = k_pos <= q_pos
        if window is not None and window > 0:
            allowed = allowed & (k_pos > q_pos - window)
        mask = torch.zeros(q_len, kv_len, device=device, dtype=dtype)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
        """Create a sliding-window causal mask, or None when no window is configured."""
        if self.sliding_window is None or self.sliding_window <= 0:
            return None
        return self._local_causal_mask(q_len, kv_len, self.sliding_window, device, dtype)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Attend over ``hidden_states`` (bsz, q_len, hidden).

        Returns:
            ``(output, past_kv)``. ``past_kv`` is ``(keys, values)`` including any
            cached prefix when ``use_cache`` is true, else None.
        """
        bsz, q_len, _ = hidden_states.size()

        # Queries, split into no-rope and rope parts.
        if self.q_lora_rank > 0:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        else:
            q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: latent + shared rope key, then per-head expansion.
        kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
        kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_b_out = self.kv_b_proj(self.kv_a_layernorm(kv_latent))
        k_nope, v = torch.split(
            kv_b_out,
            [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim],
            dim=-1,
        )
        k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_rope = k_rope.unsqueeze(1)  # (bsz, 1, q_len, rope_dim): shared by all heads

        # New tokens continue after any cached prefix.
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0
        if position_ids is None:
            position_ids = torch.arange(
                past_len, past_len + q_len, device=hidden_states.device
            ).unsqueeze(0).expand(bsz, -1)

        cos, sin = self._rope_cos_sin(position_ids)
        q_rope = self._apply_rotary(q_rope.float(), cos, sin).to(q_rope.dtype)
        k_rope = self._apply_rotary(k_rope.float(), cos, sin).to(k_rope.dtype)

        k_full = torch.cat([k_nope, k_rope.expand(-1, self.num_heads, -1, -1)], dim=-1)
        q_full = torch.cat([q_nope, q_rope], dim=-1)

        if past_key_value is not None:
            k_full = torch.cat([past_key_value[0], k_full], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_kv = (k_full, v) if use_cache else None
        kv_len = k_full.size(2)

        final_mask = attention_mask
        sw_mask = self._make_sliding_window_mask(q_len, kv_len, hidden_states.device, hidden_states.dtype)
        if sw_mask is not None:
            if final_mask is None:
                final_mask = sw_mask
            else:
                final_mask = final_mask + sw_mask
                if final_mask.is_floating_point():
                    # finfo.min + finfo.min overflows to -inf; keep masked scores finite.
                    final_mask = final_mask.clamp(min=torch.finfo(final_mask.dtype).min)
        elif final_mask is None and kv_len != q_len:
            # SDPA's is_causal aligns the diagonal top-left, which is wrong when cached
            # keys precede the queries, so build a bottom-right aligned mask instead.
            final_mask = self._local_causal_mask(q_len, kv_len, None, hidden_states.device, hidden_states.dtype)

        attn_output = F.scaled_dot_product_attention(
            q_full, k_full, v,
            attn_mask=final_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=final_mask is None,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        return self.o_proj(attn_output), past_kv
|
|
|
|
| |
| |
| |
|
|
| def _tokenizer_compress(token_ids, vocab_size=50257): |
| """Simulate NFKC + lowercase canonical ID projection.""" |
| return token_ids % (vocab_size * 77 // 100) |
|
|
|
|
class SpiderPortalEngram(nn.Module):
    """Conditional memory module: gated lookup of hashed N-gram embeddings.

    Applied only at the recurrent layers listed in ``config.engram_layers``.

    For each N-gram order in ``engram_ngram_orders`` and each of the
    ``engram_hash_heads`` heads, the compressed token-id window ending at every
    position is hashed into an ``engram_table_size``-row embedding table. The
    retrieved vectors are then processed as follows:

    - they are concatenated and projected to keys and values;
    - the values are gated per token by ``sigmoid(<rmsnorm(h), rmsnorm(k)> / sqrt(D))``;
    - the result goes through a causal depthwise convolution (kernel
      ``engram_conv_kernel``, dilation ``engram_conv_dilation``), and a
      residual connection adds the gated values back.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.ngram_orders = config.engram_ngram_orders
        self.num_heads = config.engram_hash_heads
        self.table_size = config.engram_table_size
        self.d_mem = config.engram_dim

        # Width of the concatenated retrieval: one d_mem slice per (order, head).
        self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem

        self.embed_tables = nn.ParameterDict()
        for n in self.ngram_orders:
            for h in range(self.num_heads):
                key = f"e_{n}_{h}"
                self.embed_tables[key] = nn.Parameter(
                    torch.randn(self.table_size, self.d_mem) * 0.02
                )

        # One multiplicative seed per (order, head), in the order _retrieve visits them.
        self.register_buffer("hash_seeds", torch.tensor([
            (h + 1) * 2654435761
            for _ in self.ngram_orders
            for h in range(self.num_heads)
        ], dtype=torch.int64))

        self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
        self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)

        self.conv_dilation = getattr(config, "engram_conv_dilation", 1)
        # Causal depthwise conv. The padding covers the (kernel - 1) * dilation
        # positions of left context; forward() trims the right-side overhang so
        # position t never sees t + 1.
        self.conv = nn.Conv1d(
            config.hidden_size, config.hidden_size,
            kernel_size=config.engram_conv_kernel,
            dilation=self.conv_dilation,
            padding=(config.engram_conv_kernel - 1) * self.conv_dilation,
            groups=config.hidden_size,
        )

        # Zero init: the conv branch outputs silu(0) = 0 at the start of training.
        with torch.no_grad():
            self.conv.weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()

        self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
        self.k_norm = SpiderPortalRMSNorm(config.hidden_size)

    def _compute_indices(self, compressed_ids, n, head_idx):
        """Hash the length-``n`` window ending at each position into ``[0, table_size)``.

        Positions with fewer than ``n - 1`` predecessors are left-padded with id 0.
        Returns an int64 tensor with the same ``(bsz, seq_len)`` shape as
        ``compressed_ids``.

        NOTE(review): the seed is applied as ``(h % table_size) * seed % table_size``.
        For the default prime ``table_size`` (65537) every seed is non-zero modulo
        it, so this is a bijection. All heads of one order therefore collide on
        exactly the same n-grams and differ only by a row permutation. Changing
        the hash would invalidate trained tables, so it is left as is.
        """
        bsz, seq_len = compressed_ids.shape
        pad = compressed_ids.new_zeros(bsz, n - 1)
        padded = torch.cat([pad, compressed_ids], dim=1)

        # Polynomial rolling hash (base 31) over the window, reduced mod table_size.
        h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
        for i in range(n):
            h_val = h_val * 31 + padded[:, i:i + seq_len]
        h_val = h_val % self.table_size
        # Multiply by the seed as a device tensor, avoiding a per-head .item() host sync.
        return (h_val * self.hash_seeds[head_idx]) % self.table_size

    def _retrieve(self, token_ids):
        """Return concatenated memory vectors of shape ``(bsz, seq_len, total_mem_dim)``."""
        bsz, seq_len = token_ids.shape
        compressed = _tokenizer_compress(token_ids)

        all_parts = []
        head_counter = 0  # flat (order, head) index into hash_seeds
        for n in self.ngram_orders:
            for h in range(self.num_heads):
                table = self.embed_tables[f"e_{n}_{h}"]
                indices = self._compute_indices(compressed, n, head_counter)
                emb = table[indices.view(-1)]
                all_parts.append(emb.view(bsz, seq_len, self.d_mem))
                head_counter += 1

        return torch.cat(all_parts, dim=-1)

    def forward(self, hidden_states, token_ids):
        """Return the gated, convolved memory read-out (same shape as ``hidden_states``)."""
        mem = self._retrieve(token_ids)

        q = hidden_states
        k = self.W_k(mem)
        v = self.W_v(mem)

        # Scalar gate per token from the normalised query/key agreement.
        alpha = torch.sigmoid(
            (self.q_norm(q) * self.k_norm(k)).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
        )
        v_gated = alpha * v

        # Conv1d expects (bsz, channels, seq); keep only the first seq_len outputs (causal).
        v_gated_t = v_gated.transpose(1, 2)
        conv_out = self.conv(v_gated_t)[:, :, :v_gated_t.shape[-1]]
        conv_out = conv_out.transpose(1, 2)

        return F.silu(conv_out) + v_gated
|
|
|
|
| |
| |
| |
|
|
class SpiderPortalExpert(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``.

    ``intermediate_size`` overrides ``config.intermediate_size`` when given
    (and truthy).
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        width = config.intermediate_size if not intermediate_size else intermediate_size
        dim = config.hidden_size
        self.gate_proj = nn.Linear(dim, width, bias=False)
        self.up_proj = nn.Linear(dim, width, bias=False)
        self.down_proj = nn.Linear(width, dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        gated = self.act_fn(self.gate_proj(hidden_states))
        lifted = self.up_proj(hidden_states)
        return self.down_proj(gated * lifted)
|
|
|
|
| |
| |
| |
|
|
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm transformer layer used for the prelude and coda stacks.

    ``x -> x + MLA(RMSNorm(x)) -> x + SwiGLU(RMSNorm(x))``. The FFN width is
    ``hidden_size * 4 // 3``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalMLA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=(config.hidden_size * 4) // 3)
        eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, past_kv)``, where ``past_kv`` comes from the attention."""
        attn_out, past_kv = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        residual = hidden_states + attn_out
        out = residual + self.ffn(self.post_attention_layernorm(residual))
        return out, past_kv
|
|
|
|
| |
| |
| |
|
|
class SpiderPortalRecurrentDenseLayer(nn.Module):
    """Recurrent-stack layer: MLA attention, optional Engram memory, then a SwiGLU FFN.

    When ``has_engram`` is set and ``token_ids`` are supplied, the Engram
    read-out is added after attention. The whole residual stream is then passed
    through ``post_engram_layernorm``.
    NOTE(review): that norm rescales the residual itself, not just the Engram branch.
    """

    def __init__(self, config, layer_idx, has_engram=False):
        super().__init__()
        self.layer_idx = layer_idx
        self.has_engram = has_engram
        self.self_attn = SpiderPortalMLA(config)
        if has_engram:
            self.engram = SpiderPortalEngram(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
        eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_engram_layernorm = (
            SpiderPortalRMSNorm(config.hidden_size, eps=eps) if has_engram else None
        )

    def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, aux_loss, past_kv)``. ``aux_loss`` is always 0.0 (dense FFN)."""
        attn_out, past_kv = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        use_engram = self.has_engram and token_ids is not None
        if use_engram:
            hidden_states = hidden_states + self.engram(hidden_states, token_ids)
            if self.post_engram_layernorm is not None:
                hidden_states = self.post_engram_layernorm(hidden_states)

        hidden_states = hidden_states + self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states, 0.0, past_kv
|
|
|
|
| |
| |
| |
|
|
class LTIInjection(nn.Module):
    """Linear time-invariant injection term ``A * h_t + B(e)``.

    ``A = -exp(log_A)`` is elementwise and strictly negative (about -0.135 at
    init). ``B`` is a bias-free linear map initialised with std 0.01.
    ``delta_t`` is a learnable parameter that ``forward`` does not currently use.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.hidden_size = dim
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(dim, dim, bias=False)
        with torch.no_grad():
            self.B.weight.data.normal_(mean=0.0, std=0.01)

    def get_A(self):
        return self.log_A.exp().neg()

    def forward(self, h_t, e):
        decay = self.get_A()
        return decay * h_t + self.B(e)
|
|
|
|
class ACTHalting(nn.Module):
    """Per-token halting-probability head for Adaptive Computation Time.

    Maps each hidden vector to ``sigmoid(w . h + b)``, with shape ``(..., 1)``.
    ``threshold`` stores ``config.act_threshold`` for callers; ``forward``
    does not use it.
    """

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        logits = self.halt_predictor(hidden_states)
        return logits.sigmoid()
|
|
|
|
class LoRAAdapter(nn.Module):
    """Loop-conditioned low-rank update ``((x @ down^T) * scale[t]) @ B``.

    ``scale`` holds one rank-sized gain vector per loop index and is
    zero-initialised, so the update starts at exactly 0. Loop indices beyond
    the table are clamped to the last row.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.data.zero_()
            self.down.weight.data.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        row = min(loop_t, self.scale.num_embeddings - 1)
        gain = self.scale(torch.tensor(row, device=x.device))
        projected = self.down(x) * gain
        return projected @ self.B
|
|
|
|
def checkpoint(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` with non-reentrant gradient checkpointing when autograd is on.

    Activations are recomputed during backward instead of being stored, which
    trades extra compute for VRAM. When grad mode is disabled (e.g. under
    ``torch.no_grad()``), ``func`` is called directly. Keyword arguments are
    forwarded to ``func``.
    """
    # Import the submodule explicitly: ``import torch`` alone does not guarantee
    # that ``torch.utils.checkpoint`` is available as an attribute.
    import torch.utils.checkpoint as _ckpt

    if not torch.is_grad_enabled():
        return func(*args, **kwargs)
    return _ckpt.checkpoint(func, *args, use_reentrant=False, **kwargs)
|
|
|
|
| |
| |
| |
|
|
class SpiderPortalDenseModel(nn.Module):
    """Full RDT model with MLA attention and Engram memory at the configured recurrent layers.

    Architecture (layer counts from the config; defaults shown):
      2x Prelude   (MLA + dense FFN)
      6x Recurrent (MLA + Engram@L1,L4 + dense FFN), looped up to
         ``max_loop_iters`` times with gradient checkpointing
      2x Coda      (MLA + dense FFN)
      LTI Injection + ACT Halting + LoRA Adapter around the recurrent loop
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([
            SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
            for i in range(config.num_hidden_layers)
        ])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None,
                past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
        """Run prelude -> looped recurrent stack with ACT -> coda -> final RMSNorm.

        Returns:
            ``(hidden_states, 0.0, past_key_values)``. ``past_key_values``
            holds the per-recurrent-layer caches from the last executed loop.
            Supplied caches are only fed to loop 0, and prelude/coda layers
            are not cached.
        """
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        bsz, seq_len, _ = hidden_states.shape
        act_dtype = hidden_states.dtype
        device = hidden_states.device
        halted = torch.zeros(bsz, seq_len, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(bsz, seq_len, device=device, dtype=act_dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        # Defined up front so the return value exists even if the loop body never runs.
        new_past_key_values = [None] * len(self.recurrent_layers)

        # NOTE(review): loop_index_embedding / config.loop_embed_dim are not applied
        # here. The loop index reaches the recurrent stack only through
        # self.lora_adapter(hidden_states, t). If a loop embedding is wanted, wire it
        # in deliberately and check its scale against the residual stream.
        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, _aux_loss, past_kv = checkpoint(
                    layer, hidden_states,
                    token_ids=token_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # ACT: a running token contributes halt_prob, or the remaining probability
            # mass on the step where it crosses the threshold.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(act_dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            crosses = cumulative_p + halt_prob >= self.config.act_threshold
            weight = torch.where(crosses, remainder, halt_prob) * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= self.config.act_threshold)
            if halted.all() and not self.training:
                break

        # Tokens that never reached the threshold give their leftover probability mass
        # to the final state, so every token's ACT weights sum to 1 (halted tokens
        # already received their remainder at the halting step).
        leftover = ((1.0 - cumulative_p).clamp(min=0) * (~halted).to(act_dtype)).unsqueeze(-1)
        hidden_states = h_out + leftover * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, 0.0, new_past_key_values
|
|
|
|
class SpiderPortalForConditionalGeneration(nn.Module):
    """Causal LM wrapper: token embedding -> SpiderPortalDenseModel -> (optionally tied) LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = SpiderPortalDenseModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights as Normal(0, initializer_range) and zero Linear biases.

        Submodules that perform a deliberate custom init in their own
        constructors are skipped, so that init is not overwritten:
          - the LTI injection's ``B`` (std 0.01);
          - everything inside the LoRA adapter (zero ``scale``, std 0.001 ``down``).
        """
        if hasattr(self, "model"):
            if module is self.model.injection.B:
                return
            if any(module is sub for sub in self.model.lora_adapter.modules()):
                return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_causal_mask(attention_mask, dtype, device):
        """Return an additive ``(B, 1, T, T)`` mask.

        Entries are 0 where query i may attend key j, and ``finfo(dtype).min``
        otherwise. Key j is blocked for query i when j > i (future token) or
        when ``attention_mask[b, j]`` is 0/False (padding). Bool or 0/1 integer
        masks are accepted.
        """
        key_is_pad = ~attention_mask.to(device=device, dtype=torch.bool)  # (B, T)
        seq_len = key_is_pad.size(1)
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)  # (T, T)
        blocked = future[None, None, :, :] | key_is_pad[:, None, None, :]
        mask = torch.zeros(blocked.shape, dtype=dtype, device=device)
        return mask.masked_fill(blocked, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        """Compute logits and, if ``labels`` is given, the next-token cross-entropy loss.

        ``labels`` is expected to be aligned with ``input_ids`` (same positions).
        The one-step shift is done here.
        NOTE(review): FineWebEduDataset yields targets that are already shifted by
        one. Passing those as ``labels`` would shift twice; confirm the training loop.
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        causal_mask = self._build_causal_mask(attention_mask, hidden_states.dtype, hidden_states.device)

        hidden_states, aux_loss, past_kv = self.model(
            hidden_states, input_embedding=input_embedding,
            attention_mask=causal_mask, position_ids=position_ids,
            use_cache=use_cache, n_loops=n_loops, token_ids=input_ids,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}

    def get_num_params(self):
        """Return total and trainable (``requires_grad``) parameter counts; tied weights are counted once."""
        params = list(self.parameters())
        total = sum(p.numel() for p in params)
        trainable = sum(p.numel() for p in params if p.requires_grad)
        return {"total": total, "trainable": trainable}
|
|
|
|
| |
| |
| |
|
|
class FineWebEduDataset(IterableDataset):
    """One pass of ``(input_ids, target_ids)`` pairs from FineWeb-Edu.

    Each yielded pair is two int64 tensors of length ``seq_len``. The target is
    the input shifted by one token. Work is split into
    ``world_size * num_workers`` disjoint shards (one per rank/worker).

    Two sources are supported:
      - local: a pre-tokenised little-endian uint32 file, read through a
        memmap in overlapping ``seq_len + 1`` windows with stride ``seq_len``;
      - streaming: the HF ``HuggingFaceFW/fineweb-edu`` stream, tokenised on
        the fly. Documents are separated by the tokenizer's EOS token when it
        has one.
    """

    def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int, local_token_file=None):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.subset = subset
        self.rank = rank
        self.world_size = world_size

        if local_token_file and os.path.exists(local_token_file):
            import numpy as np
            self.use_local = True
            self.local_file = local_token_file
            self.mm = np.memmap(local_token_file, dtype='<u4', mode='r')
            self.num_tokens = len(self.mm)
            # Each sample reads seq_len + 1 tokens (inputs plus the shifted target),
            # so the last sample must still have one token after its window.
            self.num_samples = max(self.num_tokens - 1, 0) // seq_len
            log(f"Using pre-tokenized binary: {local_token_file} ({self.num_tokens:,} tokens)")
        else:
            self.use_local = False
            log("WARNING: No pre-tokenized binary found. Using streaming tokenizer (SLOW).")
            log("Run pretokenize_fineweb.py first for 50-100x speedup.")

    def __getstate__(self):
        # A pickled np.memmap becomes a plain in-memory array (a copy of the whole
        # token file), so drop it when the dataset is sent to spawned DataLoader
        # workers. __iter__ reopens the file on demand.
        state = self.__dict__.copy()
        state.pop("mm", None)
        return state

    def _shard_info(self):
        """Return ``(total_shards, shard_index)`` for this rank/worker."""
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0
        return self.world_size * num_workers, self.rank * num_workers + worker_id

    def __iter__(self):
        total_shards, shard_index = self._shard_info()
        if self.use_local:
            yield from self._iter_local(total_shards, shard_index)
        else:
            yield from self._iter_streaming(total_shards, shard_index)

    def _iter_local(self, total_shards, shard_index):
        """Yield this shard's contiguous range of windows from the token file."""
        mm = getattr(self, "mm", None)
        if mm is None:
            import numpy as np
            mm = np.memmap(self.local_file, dtype='<u4', mode='r')

        samples_per_shard = self.num_samples // total_shards
        start_sample = shard_index * samples_per_shard
        for i in range(start_sample, start_sample + samples_per_shard):
            start_idx = i * self.seq_len
            tokens = mm[start_idx:start_idx + self.seq_len + 1]
            yield (
                torch.from_numpy(tokens[:-1].astype('int64')),
                torch.from_numpy(tokens[1:].astype('int64')),
            )

    def _iter_streaming(self, total_shards, shard_index):
        """Tokenise this shard of the HF stream and yield fixed-length windows."""
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.subset,
            split="train",
            streaming=True,
        ).shard(num_shards=total_shards, index=shard_index)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        window = self.seq_len + 1
        buf = []
        for sample in ds:
            buf.extend(self.tokenizer.encode(sample["text"]))
            if eos_id is not None:
                buf.append(eos_id)  # document boundary
            while len(buf) >= window:
                chunk = buf[:window]
                # Advance by seq_len (not seq_len + 1): the last token of this chunk is
                # the first input of the next one, so no next-token transition is
                # skipped. This matches the overlapping windows of the local path.
                buf = buf[self.seq_len:]
                yield (
                    torch.tensor(chunk[:-1], dtype=torch.long),
                    torch.tensor(chunk[1:], dtype=torch.long),
                )
|
|
|
|
| |
| |
| |
|
|
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Linear warmup followed by cosine decay to ``min_lr``.

    - ``step < warmup``: ramps linearly from 0 to ``max_lr``.
    - ``warmup <= step < total``: half-cosine from ``max_lr`` down to ``min_lr``.
    - ``step >= total`` (after warmup): constant ``min_lr``.
    """
    if step >= warmup:
        if step >= total:
            return min_lr
        frac = (step - warmup) / (total - warmup)
        return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * frac))
    return max_lr * step / warmup
|
|
|
|
| |
| |
| |
|
|
def save_weights_only(model, step, epoch, ckpt_dir, ddp):
    """Save only the model weights to ``<ckpt_dir>/spiderportal-v5-dense-ep{epoch}-step{step}.pt``.

    Under FSDP (``ddp=True``) this must be called on every rank, because
    gathering the full state dict is a collective. Only rank 0 writes the file;
    the other ranks return ``(None, 0)``. The file is written to ``.tmp`` first
    and then atomically renamed into place.

    Returns:
        ``(path, size_mb)`` on the writing rank, ``(None, 0)`` elsewhere.
    """
    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
        # With rank0_only=True the other ranks hold an empty dict. Writing it would
        # race rank 0 on the same path and could replace the real checkpoint.
        if dist.get_rank() != 0:
            return None, 0
    else:
        model_state = model.state_dict()

    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
    tmp_path = ckpt_path + ".tmp"
    torch.save(model_state, tmp_path)
    os.replace(tmp_path, ckpt_path)
    size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
    return ckpt_path, size_mb
|
|
|
|
def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
    """Save model, optimizer, step/epoch, config and vocab size for resuming training.

    Under FSDP (``ddp=True``) this must be called on every rank, because the
    model and optimizer gathers are collectives. Both full state dicts are
    materialised on CPU on rank 0 only. Only ``master`` writes
    ``<ckpt_dir>/spiderportal-v5-dense-{ckpt_name}.pt``, atomically via a
    ``.tmp`` file.

    Returns:
        ``(path, size_mb)`` on master, ``(None, 0)`` elsewhere.
    """
    if ddp:
        from torch.distributed.fsdp import FullOptimStateDictConfig

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            # Without rank0_only, every rank would hold a full CPU copy of the
            # optimizer state (several times the model size) at the same time.
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
            optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()
    if not master:
        return None, 0

    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
    tmp_path = final_path + ".tmp"
    torch.save(
        {
            "step": step,
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
            "cfg": cfg,
            "vocab_size": vocab_size,
        },
        tmp_path,
    )
    os.replace(tmp_path, final_path)
    size_mb = os.path.getsize(final_path) / (1024 * 1024)
    return final_path, size_mb
|
|
|
|
def load_checkpoint(model, optimizer, path, ddp):
    """Restore model and optimizer state from a full training checkpoint.

    Under FSDP (``ddp=True``) every rank reads the whole file and loads through
    the FULL_STATE_DICT path. The optimizer state is converted with
    ``FSDP.optim_state_dict_to_load`` first.

    Returns:
        ``(step, epoch)``; ``epoch`` defaults to 0 if it is absent.
    """
    # SECURITY: weights_only=False uses the full pickle loader (needed for the
    # pickled ``cfg`` object). Only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if not ddp:
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    else:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        ):
            model.load_state_dict(ckpt["model_state_dict"])
            sharded_optim = FSDP.optim_state_dict_to_load(
                model=model,
                optim=optimizer,
                optim_state_dict=ckpt["optimizer_state_dict"],
            )
            optimizer.load_state_dict(sharded_optim)
    return int(ckpt["step"]), int(ckpt.get("epoch", 0))
|
|
|
|
| |
| |
| |
|
|
def _mean_across_ranks(value: torch.Tensor, ddp: bool) -> float:
    """Average a scalar tensor over all ranks and return it as a Python float.

    When ``ddp`` is true this issues an ``all_reduce`` collective, so every rank
    must call it at the same point in the program.
    """
    value = value.detach().float().clone()
    if ddp:
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value.item()


def _make_grad_scaler(ddp: bool, enabled: bool):
    """Build a loss scaler for fp16 training; a pass-through when ``enabled`` is False.

    FSDP needs ``ShardedGradScaler`` because each rank only holds a gradient shard.
    """
    if ddp:
        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
        return ShardedGradScaler(enabled=enabled)
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def _latest_epoch_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered epoch checkpoint in ``ckpt_dir``, or None.

    Only files named ``spiderportal-v5-dense-ep*.pt`` that do not contain ``-step``
    qualify. Epoch numbers are compared numerically, so ``ep10`` sorts after ``ep2``
    (a plain string sort would pick ``ep9`` over ``ep10``).
    """
    import re

    if not os.path.isdir(ckpt_dir):
        return None
    candidates = [
        f for f in os.listdir(ckpt_dir)
        if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f
    ]
    if not candidates:
        return None

    def epoch_key(name):
        match = re.search(r"-ep(\d+)", name)
        return (int(match.group(1)) if match else -1, name)

    return os.path.join(ckpt_dir, max(candidates, key=epoch_key))


def main():
    """Run SpiderPortal v5-Dense pretraining on FineWeb-Edu.

    Runs single-process when ``RANK`` is unset. Otherwise (``torchrun``) it initializes
    NCCL and shards the model with FSDP. Trains for ``target_tokens`` tokens using
    AdamW on the backbone and a separate Adam (``engram_lr_mult`` x LR) on the Engram
    ``embed_tables`` parameters. Logs every ``log_every`` steps and writes weights-only
    snapshots every ``ckpt_every`` steps. Writes full checkpoints at epoch end
    (``ep{N}``, ``best``) and on completion (``final-ep{N}``).
    """
    # ----- Distributed setup -----
    ddp = int(os.environ.get("RANK", -1)) != -1
    if ddp:
        dist.init_process_group("nccl")
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(device)
    else:
        rank = local_rank = 0
        world_size = 1
        device = "cuda" if torch.cuda.is_available() else "cpu"
    master = rank == 0
    if master:
        log(
            f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
        )

    # ----- Tokenizer -----
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size
    if master:
        log(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")

    # ----- Hyperparameters -----
    seq_len = 2048
    micro_batch = 32
    target_tokens = 20_000_000_000
    grad_accum = 2
    global_batch_tok = world_size * micro_batch * grad_accum * seq_len
    total_steps = target_tokens // global_batch_tok
    warmup_steps = 200
    lr = 3e-4
    engram_lr_mult = 5  # Engram hash tables train at this multiple of the backbone LR
    wd = 0.1
    log_every = 10
    ckpt_every = 500
    ckpt_dir = "checkpoints-dense"
    dataset_subset = "sample-10BT"

    if master:
        log(
            f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
            f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
        )
        log(
            f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | "
            f"Engram: layers [1,4] | Context: 32k | "
            f"Gradient checkpointing: enabled"
        )

    # ----- Model config -----
    cfg = SpiderPortalConfig(
        hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
        num_key_value_heads=4, intermediate_size=8192,
        num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
        router_aux_loss_coef=0.05, max_loop_iters=4,
        prelude_layers=2, coda_layers=2, lora_rank=128,
        rope_theta=10000000.0,
        rope_scaling=None,
        max_position_embeddings=32768, sliding_window=4096,
        tie_word_embeddings=True,
        kv_lora_rank=128, q_lora_rank=256,
        qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
        engram_layers=[1, 4],
        engram_ngram_orders=(2, 3),
        engram_hash_heads=4,
        engram_table_size=65537,
        engram_dim=128,
    )
    cfg.vocab_size = vocab_size

    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
    # fp16 (the fallback when bf16 is unsupported) needs loss scaling, otherwise
    # small gradients underflow to zero. bf16 and CPU runs use a pass-through scaler.
    use_scaler = torch.cuda.is_available() and amp_dtype == torch.float16

    model = SpiderPortalForConditionalGeneration(cfg)

    # Count parameters on the unwrapped model. After FSDP wrapping each rank only sees
    # its local shard, so the names and numels no longer describe the full model.
    if master:
        n_params = sum(p.numel() for p in model.parameters())
        engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
        mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
        embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
        ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
        other_params = n_params - engram_params - mla_params - embed_params - ffn_params
        log(
            f"Parameters: {n_params:,} (all active) | "
            f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
            f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
            f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
        )

    if ddp:
        mp_policy = MixedPrecision(
            param_dtype=amp_dtype,
            reduce_dtype=amp_dtype,
            buffer_dtype=amp_dtype,
        )
        wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mp_policy,
            auto_wrap_policy=wrap_policy,
            device_id=local_rank,
            # Keep the original parameter names and tensors visible. Without this,
            # named_parameters() only yields flattened params, so the Engram-table
            # filter below would match nothing. torch.compile over FSDP also expects it.
            use_orig_params=True,
        )
        amp_ctx = nullcontext()
    else:
        model = model.to(device)
        amp_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype) if torch.cuda.is_available() else nullcontext()

    # torch.compile compiles lazily on the first forward call, so this try only catches
    # setup errors. NOTE(review): "reduce-overhead" uses CUDA graphs, which may not
    # work with FSDP collectives; confirm on the multi-GPU path.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        if master:
            log("torch.compile: enabled (reduce-overhead)")
    except Exception as e:
        if master:
            log(f"torch.compile failed ({e}), using eager mode")

    # ----- Optimizers -----
    # Split Engram hash-embedding tables from the rest of the parameters in one pass.
    engram_params_list, backbone_params = [], []
    for name, param in model.named_parameters():
        is_engram_table = 'engram' in name and 'embed_tables' in name
        (engram_params_list if is_engram_table else backbone_params).append(param)

    optimizer = torch.optim.AdamW(
        backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95),
        # Fused kernels need accelerator tensors; disable on CPU-only runs.
        fused=torch.cuda.is_available(),
    )
    if engram_params_list:
        engram_optimizer = torch.optim.Adam(
            engram_params_list, lr=lr * engram_lr_mult, betas=(0.9, 0.95), eps=1e-8
        )
    else:
        engram_optimizer = None
    optimizers = [optimizer] if engram_optimizer is None else [optimizer, engram_optimizer]
    scaler = _make_grad_scaler(ddp, use_scaler)

    # ----- Resume -----
    start_step = 0
    start_epoch = 1
    best_loss = float("inf")
    latest = _latest_epoch_checkpoint(ckpt_dir)
    if latest is not None:
        if master:
            log(f"Resuming from checkpoint: {latest}")
        # NOTE(review): only `optimizer` is passed to load_checkpoint. The Engram
        # optimizer state and best_loss start fresh on resume.
        start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
        if master:
            log(f"Resumed at step {start_step}, epoch {start_epoch}")

    # ----- Data -----
    local_token_file = os.environ.get("TOKEN_FILE", "data/fineweb-edu-sample-10BT.bin")

    def build_loader():
        """Create a fresh dataset and DataLoader for this rank (at start and on exhaustion)."""
        ds = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file)
        workers = 16 if ds.use_local else 4
        prefetch_n = 8 if ds.use_local else 2
        dl = DataLoader(ds, batch_size=micro_batch, num_workers=workers, pin_memory=True, prefetch_factor=prefetch_n)
        return ds, dl, workers, prefetch_n

    dataset, loader, num_workers, prefetch = build_loader()
    if master:
        log(f"DataLoader: num_workers={num_workers}, prefetch={prefetch}, use_local={dataset.use_local}")

    if master:
        os.makedirs(ckpt_dir, exist_ok=True)

    # ----- Training loop -----
    model.train()
    data_iter = iter(loader)
    t0 = time.perf_counter()
    epoch_t0 = t0
    step = start_step
    epoch = start_epoch
    step_ckpt_files = []
    tokens_in_epoch = 0
    # One epoch is the whole run. Use the tokens the step budget actually covers:
    # total_steps is floor-divided, so tokens_in_epoch can never reach target_tokens
    # exactly, and comparing against it would skip the epoch-end checkpoints.
    tokens_per_epoch = total_steps * global_batch_tok
    # Running per-step loss sum for the current epoch, kept on device so the loop
    # does not force a host sync on every micro-step.
    epoch_loss_sum = torch.zeros((), device=device)
    epoch_steps = 0

    while step < total_steps:
        cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
        for g in optimizer.param_groups:
            g["lr"] = cur_lr
        if engram_optimizer is not None:
            for g in engram_optimizer.param_groups:
                g["lr"] = cur_lr * engram_lr_mult

        for opt in optimizers:
            opt.zero_grad()
        loss_accum = torch.zeros((), device=device)

        for micro_step in range(grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                if master:
                    log(f"Dataset exhausted at step {step}, restarting DataLoader")
                dataset, loader, num_workers, prefetch = build_loader()
                data_iter = iter(loader)
                try:
                    x, y = next(data_iter)
                except StopIteration:
                    raise RuntimeError("Training dataset produced no batches after restart") from None

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Skip FSDP gradient reduction on all but the last micro-step. no_sync only
            # takes effect if backward() runs inside it, so backward sits inside the
            # context. Autocast wraps just the forward, as AMP recommends. Under
            # FULL_SHARD this keeps unsharded grads between micro-steps (more memory).
            sync = (
                nullcontext()
                if (not ddp or micro_step == grad_accum - 1)
                else model.no_sync()
            )
            with sync:
                with amp_ctx:
                    output = model(x)
                    logits = output["logits"] if isinstance(output, dict) else output
                    loss = F.cross_entropy(
                        logits.reshape(-1, logits.size(-1)), y.reshape(-1)
                    )
                    loss = loss / grad_accum
                scaler.scale(loss).backward()
            loss_accum += loss.detach()

        # Unscale before clipping so the max-norm applies to the true gradients.
        for opt in optimizers:
            scaler.unscale_(opt)
        if ddp:
            grad_norm = model.clip_grad_norm_(1.0)
        else:
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        for opt in optimizers:
            scaler.step(opt)
        scaler.update()
        step += 1
        tokens_in_epoch += global_batch_tok
        epoch_loss_sum += loss_accum
        epoch_steps += 1

        if step % log_every == 0:
            # Collective: every rank participates, and only rank 0 prints.
            step_loss = _mean_across_ranks(loss_accum, ddp)
            if master:
                dt = time.perf_counter() - t0
                tok_per_sec = global_batch_tok * log_every / dt
                tokens_seen = step * global_batch_tok
                log(
                    f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {step_loss:.4f} "
                    f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
                    f"| {tok_per_sec / 1e6:.2f}M tok/s "
                    f"| {tokens_seen / 1e9:.2f}B tokens seen"
                )
                t0 = time.perf_counter()

        if step % ckpt_every == 0 and master:
            # NOTE(review): called on rank 0 only. If save_weights_only gathers a full
            # FSDP state dict (a collective), the other ranks must call it too; confirm.
            ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
            step_ckpt_files.append(ckpt_path)
            log(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")

        if tokens_in_epoch >= tokens_per_epoch:
            # Mean per-step loss over the epoch, averaged across ranks. Every rank gets
            # the same value, so all ranks agree on the "best" save below (which, like
            # the epoch save, is called on every rank).
            epoch_loss = _mean_across_ranks(epoch_loss_sum / epoch_steps, ddp)
            if master:
                epoch_time = (time.perf_counter() - epoch_t0) / 60
                log(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")

                # Weights-only step snapshots are superseded by the full epoch checkpoint.
                for f in step_ckpt_files:
                    if os.path.exists(f):
                        os.remove(f)
                        log(f"  Deleted step checkpoint: {os.path.basename(f)}")
                step_ckpt_files.clear()

            ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
            if master and ckpt_path:
                log(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
                if master and ckpt_path:
                    log(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")

            epoch += 1
            tokens_in_epoch = 0
            epoch_loss_sum.zero_()
            epoch_steps = 0
            epoch_t0 = time.perf_counter()

    # Called on every rank, like the epoch-end saves: save_full_checkpoint receives
    # `master` itself, and gating it on rank 0 would stall if it gathers FSDP state.
    if step > start_step:
        ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
        if master and ckpt_path:
            log(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")

    if ddp:
        dist.barrier()
        dist.destroy_process_group()

    if master:
        log("Training complete.")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start training only when run directly (python / torchrun),
    # not when this module is imported.
    main()
|
|