#!/usr/bin/env python3
# Source: agent-mythos-edit.py (uploaded by CLIWorks, revision c08afec).
"""
SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
- MLA (Multi-Latent Attention): low-rank KV latent (128 + 64 rotary dims per token, ~10.7x smaller than hidden_size) + sliding window
- Engram conditional memory at recurrent layers 1 and 4
- Dense FFN (all params active, MoE conversion in Phase 2)
- LTI Injection + ACT Halting + LoRA Adapter
- 32k context (extendable to 256k at inference via YaRN)
Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
Single GPU:
python mythos-fineweb-dense.py
Multi-GPU:
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
"""
import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import sys
# Simple print-based logging — no file rotation, no hanging
def log(msg, level="INFO"):
    """Print one ``<timestamp> | <level> | <msg>`` line to stdout and flush it immediately."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    line = f"{stamp} | {level} | {msg}"
    print(line, flush=True)
# Speed up CUDA memory allocation (expandable segments reduce allocator fragmentation).
# setdefault keeps any allocator configuration the user already exported instead of
# silently overwriting it; the value only takes effect if set before CUDA memory is allocated.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512,expandable_segments:True")
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Optional, Tuple, Dict, List
from torch.nn import CrossEntropyLoss
from datasets import load_dataset
from transformers import AutoTokenizer
# ---------------------------------------------------------------------------
# SpiderPortal Model Architecture (Dense + MLA + Engram)
# ---------------------------------------------------------------------------
@dataclass
class SpiderPortalConfig:
    """Hyperparameters for the SpiderPortal recurrent-depth transformer.

    Defaults describe the v5-dense pretraining configuration. The model classes read
    these values as plain attributes (``config.<name>``).
    """
    vocab_size: int = 50257
    hidden_size: int = 2048
    num_hidden_layers: int = 6  # number of (looped) recurrent layers
    num_attention_heads: int = 16
    num_key_value_heads: int = 4  # NOTE(review): not read by SpiderPortalMLA
    intermediate_size: int = 8192  # FFN width of the recurrent layers
    hidden_act: str = "silu"  # NOTE(review): SpiderPortalExpert always uses SiLU
    # MoE routing (Phase 2); not read by the dense model classes.
    num_experts: int = 32
    num_experts_per_tok: int = 2
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05
    max_loop_iters: int = 4  # default recurrent loop count; also the LoRA per-loop table size
    act_threshold: float = 0.5  # cumulative halting probability at which a position halts
    max_position_embeddings: int = 32768
    rope_theta: float = 10000000.0
    # Optional RoPE scaling. SpiderPortalMLA recognizes
    # {"type": "yarn", "factor": float, "original_max_position_embeddings": int}.
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096  # None or <= 0 disables the sliding-window mask
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02  # std of the default normal init
    use_cache: bool = True
    tie_word_embeddings: bool = True  # share embed_tokens and lm_head weights
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 128
    loop_embed_dim: int = 128  # leading hidden channels that receive the loop-index embedding
    # Multimodal sizes; not read by the text-only model classes.
    vision_hidden_size: int = 2048
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"
    # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
    kv_lora_rank: int = 128
    q_lora_rank: int = 256  # 0 disables the low-rank query bottleneck
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 64
    v_head_dim: int = 64
    # Engram parameters (DeepSeek conditional memory)
    engram_layers: List[int] = field(default_factory=lambda: [1, 4])  # recurrent layer indices
    engram_ngram_orders: Tuple[int, ...] = (2, 3)
    engram_hash_heads: int = 4
    engram_table_size: int = 65537  # prime number for hash table
    engram_conv_kernel: int = 4
    engram_conv_dilation: int = 3  # NOTE(review): stored by SpiderPortalEngram but not applied to its Conv1d
    engram_dim: int = 128  # per-head embedding dimension
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of loop iteration ``loop_t`` to the first ``loop_dim`` channels of ``h``.

    The encoding is ``[sin(loop_t * f), cos(loop_t * f)]`` truncated to ``loop_dim`` values,
    with ``f_k = 1 / theta ** (2k / loop_dim)``. Channels at index >= loop_dim are unchanged.
    Frequencies and angles are computed in float32 and only the final encoding is cast to
    ``h.dtype``, so low-precision (e.g. bf16) inputs do not degrade the sin/cos arguments.

    Args:
        h: tensor whose last dimension is the hidden size (must be >= loop_dim),
           typically (bsz, seq_len, hidden).
        loop_t: loop iteration index.
        loop_dim: number of leading channels that receive the encoding.
        theta: frequency base.

    Returns:
        ``h`` plus the encoding, broadcast as a (1, 1, hidden) tensor.
    """
    steps = torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32)
    freqs = 1.0 / (theta ** (steps / loop_dim))
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    emb_full[:loop_dim] = emb.to(h.dtype)
    return h + emb_full.unsqueeze(0).unsqueeze(0)
class SpiderPortalRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain (no bias, no centering).

    The RMS statistic is computed in float32; the normalized values and the gain are
    cast back to the input dtype before being multiplied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.variance_epsilon)
        normalized = (x32 * inv_rms).to(orig_dtype)
        return self.weight.to(orig_dtype) * normalized
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return YaRN ("NTK-by-parts") RoPE inverse frequencies for context extension.

    Rotary pair ``i`` has wavelength ``2*pi*rope_theta**(2i/head_dim)``. Pairs that complete
    more than ``beta_fast`` rotations within the original context ``orig_max`` keep their
    frequency (extrapolation). Pairs that complete fewer than ``beta_slow`` rotations have
    their frequency divided by ``factor`` (interpolation). Pairs in between are blended
    linearly.

    Args:
        head_dim: rotary dimension (even); ``head_dim // 2`` frequencies are returned.
        rope_theta: RoPE base.
        factor: context-extension factor (1.0 reproduces the unscaled frequencies).
        orig_max: context length used in pretraining.
        beta_fast, beta_slow: rotation-count thresholds bounding the blend region.

    Returns:
        float32 tensor of shape ``(head_dim // 2,)``.

    Note: full YaRN also rescales attention logits by ``0.1 * ln(factor) + 1``; this
    function only produces the frequencies.
    """
    dim = head_dim

    def _correction_dim(num_rotations):
        # Pair index whose wavelength fits `num_rotations` times into orig_max.
        return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_theta))

    low = max(math.floor(_correction_dim(beta_fast)), 0)
    high = min(math.ceil(_correction_dim(beta_slow)), dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp
    pos_freqs = rope_theta ** (torch.arange(0, dim, 2).float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
    # 0 below `low` (keep original), 1 above `high` (fully interpolated).
    ramp = ((torch.arange(dim // 2).float() - low) / (high - low)).clamp(0.0, 1.0)
    extrapolation_weight = 1.0 - ramp
    return (
        inv_freq_interpolation * (1.0 - extrapolation_weight)
        + inv_freq_extrapolation * extrapolation_weight
    )
# ---------------------------------------------------------------------------
# MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
# ---------------------------------------------------------------------------
class SpiderPortalMLA(nn.Module):
    """Multi-Latent Attention (DeepSeek-V2 style projections) with an optional sliding window.

    For hidden_size=2048, num_heads=16 (the config defaults):
    - qk_nope_head_dim=64 + qk_rope_head_dim=64 -> per-head query/key dim 128
    - keys/values are decoded from a kv_lora_rank=128 latent plus one 64-dim rotary key
      shared by all heads: (128 + 64) = 192 dims per token, ~10.7x smaller than 2048
    - v_head_dim=64 -> per-head value dim
    - sliding_window=4096 -> each query sees at most the 4096 most recent keys (itself included)

    NOTE: the KV cache returned by ``forward`` stores the decompressed per-head keys and
    values ``(k_full, v)``, not the latent, so the cache itself is not compressed.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim
        self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.sliding_window = getattr(config, 'sliding_window', None)
        # Q projection: optional low-rank -> full Q
        if self.q_lora_rank > 0:
            self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
            self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        # KV compression: hidden -> kv_lora_rank latent (+ shared rotary key)
        self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
        # Decompress: kv_lora_rank -> per-head non-rotary keys + values
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)

    def _compute_inv_freq(self):
        """RoPE inverse frequencies (float32, CPU), YaRN-scaled if config.rope_scaling asks for it."""
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", self.config.max_position_embeddings)
            return compute_yarn_inv_freq(self.qk_rope_head_dim, self.config.rope_theta, factor, orig_max_pos)
        steps = torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim
        return 1.0 / (self.config.rope_theta ** steps)

    def _rope_inv_freq(self):
        """Inverse frequencies in float32 on the buffer's device.

        If the buffer was down-cast (e.g. by ``model.to(torch.bfloat16)``) it is recomputed,
        because bf16-rounded frequencies give large angle errors at long positions.
        """
        if self.inv_freq.dtype == torch.float32:
            return self.inv_freq
        return self._compute_inv_freq().to(self.inv_freq.device)

    def _rotate_half(self, x):
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def _causal_mask(self, q_len, kv_len, device, dtype, window=None):
        """Additive (q_len, kv_len) causal mask, optionally limited to a sliding window.

        Query row i is at absolute position ``kv_len - q_len + i`` (queries are the newest
        tokens), so cached keys remain visible. Allowed entries are 0 and blocked entries
        are ``torch.finfo(dtype).min``.
        """
        q_pos = torch.arange(q_len, device=device).unsqueeze(1) + (kv_len - q_len)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        allowed = k_pos <= q_pos
        if window is not None and window > 0:
            allowed = allowed & (k_pos > q_pos - window)
        mask = torch.zeros(q_len, kv_len, device=device, dtype=dtype)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
        """Create a sliding window causal mask (None when the window is disabled)."""
        if self.sliding_window is None or self.sliding_window <= 0:
            return None
        return self._causal_mask(q_len, kv_len, device, dtype, window=self.sliding_window)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Attend over the current tokens plus any cached keys/values.

        Args:
            hidden_states: (bsz, q_len, hidden_size).
            attention_mask: optional additive float mask broadcastable to
                (bsz, num_heads, q_len, kv_len); 0 = attend, large negative = blocked.
            position_ids: optional RoPE positions of shape (bsz or 1, q_len); defaults to
                past_len .. past_len + q_len - 1.
            past_key_value: optional (k, v) from a previous call, each
                (bsz, num_heads, past_len, dim).
            use_cache: if True, also return the concatenated (k, v).

        Returns:
            (attention output of shape (bsz, q_len, hidden_size), (k, v) or None)
        """
        bsz, q_len, _ = hidden_states.size()
        device = hidden_states.device
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0

        # Queries: optional low-rank bottleneck, then split into non-rotary / rotary parts.
        if self.q_lora_rank > 0:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        else:
            q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: compress to the latent (+ shared rotary key), then decompress per head.
        kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
        kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_b_out = self.kv_b_proj(self.kv_a_layernorm(kv_latent))
        k_nope, v = torch.split(
            kv_b_out,
            [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim],
            dim=-1,
        )
        k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_rope = k_rope.unsqueeze(1)  # (bsz, 1, q_len, rope_dim): shared by all heads

        # RoPE, computed in float32. Positions continue after cached tokens. No .item()
        # host sync is needed because angles come straight from position_ids.
        if position_ids is None:
            position_ids = torch.arange(past_len, past_len + q_len, device=device).unsqueeze(0).expand(bsz, -1)
        inv_freq = self._rope_inv_freq()
        angles = position_ids.to(device=inv_freq.device, dtype=torch.float32).unsqueeze(-1) * inv_freq
        emb = torch.cat((angles, angles), dim=-1)
        cos = emb.cos().unsqueeze(1)
        sin = emb.sin().unsqueeze(1)
        q_rope = self._apply_rotary(q_rope.float(), cos, sin).to(q_nope.dtype)
        k_rope = self._apply_rotary(k_rope.float(), cos, sin).to(k_nope.dtype)

        k_full = torch.cat([k_nope, k_rope.expand(-1, self.num_heads, -1, -1)], dim=-1)
        q_full = torch.cat([q_nope, q_rope], dim=-1)

        # KV cache
        if past_key_value is not None:
            k_full = torch.cat([past_key_value[0], k_full], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_kv = (k_full, v) if use_cache else None

        # Mask: caller's additive mask + our (sliding-window) causal mask.
        kv_len = k_full.size(2)
        mask_dtype = q_full.dtype
        local_mask = self._make_sliding_window_mask(q_len, kv_len, device, mask_dtype)
        if local_mask is None and attention_mask is None and past_len > 0:
            # SDPA's is_causal aligns to the top-left corner, which is wrong when
            # q_len != kv_len, so build the end-aligned causal mask explicitly.
            local_mask = self._causal_mask(q_len, kv_len, device, mask_dtype)
        final_mask = attention_mask
        if local_mask is not None:
            final_mask = local_mask if final_mask is None else final_mask + local_mask
        if final_mask is not None and final_mask.is_floating_point():
            # Two "min" entries summed overflow to -inf; a fully -inf row would make softmax NaN.
            final_mask = final_mask.to(mask_dtype).clamp(min=torch.finfo(mask_dtype).min)

        attn_output = F.scaled_dot_product_attention(
            q_full, k_full, v,
            attn_mask=final_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=(final_mask is None),
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        return self.o_proj(attn_output), past_kv
# ---------------------------------------------------------------------------
# Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
# ---------------------------------------------------------------------------
def _tokenizer_compress(token_ids, vocab_size=50257):
"""Simulate NFKC + lowercase canonical ID projection."""
return token_ids % (vocab_size * 77 // 100)
class SpiderPortalEngram(nn.Module):
    """Conditional memory via hashed n-gram embedding lookup.

    At every position, the trailing n-gram of (folded) token ids ending at that position is
    hashed, for each order in ``config.engram_ngram_orders`` and each of
    ``config.engram_hash_heads`` heads, into its own table of ``config.engram_table_size``
    rows of size ``config.engram_dim``. The concatenated lookups are projected to keys and
    values. A per-position sigmoid gate from the RMS-normalized hidden state and key scales
    the values, which then pass through a causal depthwise convolution with a SiLU branch
    added back to the gated values.

    Applied only at specific recurrent layers (config.engram_layers).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.ngram_orders = config.engram_ngram_orders
        self.num_heads = config.engram_hash_heads
        self.table_size = config.engram_table_size
        self.d_mem = config.engram_dim
        self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem
        # One table per (n-gram order, hash head).
        self.embed_tables = nn.ParameterDict()
        for n in self.ngram_orders:
            for h in range(self.num_heads):
                key = f"e_{n}_{h}"
                self.embed_tables[key] = nn.Parameter(
                    torch.randn(self.table_size, self.d_mem) * 0.02
                )
        # One multiplicative seed per (order, head), in the same order as the tables above.
        self.register_buffer("hash_seeds", torch.tensor([
            (h + 1) * 2654435761
            for _ in self.ngram_orders
            for h in range(self.num_heads)
        ], dtype=torch.int64))
        self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
        self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
        # Depthwise conv; padding (kernel - 1) plus trimming in forward keeps it causal.
        self.conv = nn.Conv1d(
            config.hidden_size, config.hidden_size,
            kernel_size=config.engram_conv_kernel,
            padding=config.engram_conv_kernel - 1,
            groups=config.hidden_size,
        )
        # NOTE(review): stored but never passed to nn.Conv1d, so the conv is undilated.
        self.conv_dilation = config.engram_conv_dilation
        with torch.no_grad():
            self.conv.weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()
        self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
        self.k_norm = SpiderPortalRMSNorm(config.hidden_size)

    def _compute_indices(self, compressed_ids, n, head_idx):
        """Hash the n-gram ending at each position into ``[0, table_size)`` for one (order, head).

        compressed_ids: (bsz, seq_len) integer ids; positions before the sequence start
        are padded with id 0. Returns an int64 tensor of shape (bsz, seq_len).
        """
        bsz, seq_len = compressed_ids.shape
        pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
        padded = torch.cat([pad, compressed_ids], dim=1)
        h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
        # Polynomial rolling hash over the n-gram, oldest token first.
        for i in range(n):
            h_val = (h_val * 31 + padded[:, i:i + seq_len]) % self.table_size
        # Seed stays on-device: no per-head .item() host synchronization.
        # NOTE(review): with a prime table_size this multiply is a bijection of h_val, so all
        # heads of one order share the same collision sets (seeds only permute rows).
        seed = self.hash_seeds[head_idx].to(torch.int64)
        return (h_val * seed) % self.table_size

    def _retrieve(self, token_ids):
        """Look up and concatenate every (order, head) memory vector.

        token_ids: (bsz, seq_len) integer ids. Returns (bsz, seq_len, total_mem_dim).
        """
        bsz, seq_len = token_ids.shape
        compressed = _tokenizer_compress(token_ids, getattr(self.config, "vocab_size", 50257))
        all_parts = []
        head_counter = 0
        for n in self.ngram_orders:
            for h in range(self.num_heads):
                table = self.embed_tables[f"e_{n}_{h}"]
                indices = self._compute_indices(compressed, n, head_counter)
                emb = table[indices.view(-1)]
                all_parts.append(emb.view(bsz, seq_len, self.d_mem))
                head_counter += 1
        return torch.cat(all_parts, dim=-1)

    def forward(self, hidden_states, token_ids):
        """Return the gated memory contribution, shaped like ``hidden_states``.

        hidden_states: (bsz, seq_len, hidden_size); token_ids: (bsz, seq_len).
        """
        mem = self._retrieve(token_ids)
        q = hidden_states
        k = self.W_k(mem)
        v = self.W_v(mem)
        q_norm = self.q_norm(q)
        k_norm = self.k_norm(k)
        # Scalar gate per position from the scaled dot product of normalized q and k.
        alpha = torch.sigmoid(
            (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
        )
        v_gated = alpha * v
        v_gated_t = v_gated.transpose(1, 2)
        conv_out = self.conv(v_gated_t)
        conv_out = conv_out[:, :, :v_gated_t.shape[-1]]  # keep the causal prefix
        conv_out = conv_out.transpose(1, 2)
        return F.silu(conv_out) + v_gated
# ---------------------------------------------------------------------------
# FFN Expert (dense)
# ---------------------------------------------------------------------------
class SpiderPortalExpert(nn.Module):
    """SwiGLU feed-forward network: ``down(silu(gate(x)) * up(x))``, all projections bias-free.

    ``intermediate_size`` overrides ``config.intermediate_size`` when given (and truthy).
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        width = intermediate_size or config.intermediate_size
        model_dim = config.hidden_size
        self.gate_proj = nn.Linear(model_dim, width, bias=False)
        self.up_proj = nn.Linear(model_dim, width, bias=False)
        self.down_proj = nn.Linear(width, model_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        gated = self.act_fn(self.gate_proj(hidden_states))
        lifted = self.up_proj(hidden_states)
        return self.down_proj(gated * lifted)
# ---------------------------------------------------------------------------
# Prelude/Coda Dense Layer (uses MLA)
# ---------------------------------------------------------------------------
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm prelude/coda block: MLA attention, then a SwiGLU FFN, each with a residual add.

    forward returns ``(hidden_states, past_kv)``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalMLA(config)
        # Narrow FFN for prelude/coda: 4/3 of hidden_size (2730 for hidden_size=2048).
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.hidden_size * 4 // 3)
        eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        attn_out, past_kv = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        residual = hidden_states + attn_out
        residual = residual + self.ffn(self.post_attention_layernorm(residual))
        return residual, past_kv
# ---------------------------------------------------------------------------
# Recurrent Dense Layer (uses MLA + optional Engram)
# ---------------------------------------------------------------------------
class SpiderPortalRecurrentDenseLayer(nn.Module):
    """Pre-norm recurrent block: MLA attention, optional Engram memory, dense SwiGLU FFN.

    forward returns ``(hidden_states, aux_loss, past_kv)``. ``aux_loss`` is always 0.0 here;
    presumably it is a placeholder for the Phase-2 MoE router loss.
    """

    def __init__(self, config, layer_idx, has_engram=False):
        super().__init__()
        self.layer_idx = layer_idx
        self.has_engram = has_engram
        self.self_attn = SpiderPortalMLA(config)
        if has_engram:
            self.engram = SpiderPortalEngram(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
        eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=eps)
        self.post_engram_layernorm = (
            SpiderPortalRMSNorm(config.hidden_size, eps=eps) if has_engram else None
        )

    def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        attn_out, past_kv = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        residual = hidden_states + attn_out
        if self.has_engram and token_ids is not None:
            residual = residual + self.engram(residual, token_ids)
            if self.post_engram_layernorm is not None:
                # NOTE(review): normalizes the residual stream itself, not a sublayer input.
                residual = self.post_engram_layernorm(residual)
        residual = residual + self.ffn(self.post_attention_layernorm(residual))
        return residual, 0.0, past_kv
# ---------------------------------------------------------------------------
# LTI Injection, ACT Halting, LoRA Adapter
# ---------------------------------------------------------------------------
class LTIInjection(nn.Module):
    """Linear re-injection term ``A * h_t + B(e)`` with a learned per-channel negative diagonal A.

    ``A = -exp(log_A)`` (initialized to -exp(-2)); ``B`` is a bias-free linear map with a small
    normal(0, 0.01) init.
    NOTE(review): ``delta_t`` is a registered parameter that ``forward`` never reads.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.hidden_size = dim
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.B.weight, mean=0.0, std=0.01)

    def get_A(self):
        return -self.log_A.exp()

    def forward(self, h_t, e):
        decay = self.get_A()
        return decay * h_t + self.B(e)
class ACTHalting(nn.Module):
    """Per-position halting probability ``sigmoid(Linear(hidden) -> 1)`` for adaptive computation time.

    Output has the input's leading shape with a trailing dimension of 1. ``threshold`` is stored
    for callers; ``forward`` does not use it.
    """

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        logits = self.halt_predictor(hidden_states)
        return logits.sigmoid()
class LoRAAdapter(nn.Module):
    """Low-rank per-loop delta ``(down(x) * scale[loop_t]) @ B``.

    ``scale`` holds one rank-sized gain vector per loop iteration and is zero-initialized in this
    constructor. Loop indices past the end of the table reuse its last row.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.data.zero_()
            self.down.weight.data.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        last_row = self.scale.num_embeddings - 1
        row = torch.tensor(min(loop_t, last_row), device=x.device)
        gain = self.scale(row)
        projected = self.down(x) * gain
        return projected @ self.B
def checkpoint(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` under activation checkpointing when autograd is enabled.

    Activations inside ``func`` are recomputed during backward instead of being stored, which
    trades compute for memory. With grad disabled (e.g. inside ``torch.no_grad()``), ``func`` is
    called directly. The non-reentrant implementation is used unless the caller passes
    ``use_reentrant`` explicitly.
    """
    if not torch.is_grad_enabled():
        return func(*args, **kwargs)
    # Import the submodule explicitly; `import torch` alone does not guarantee it is loaded.
    from torch.utils.checkpoint import checkpoint as _torch_checkpoint
    kwargs.setdefault("use_reentrant", False)
    return _torch_checkpoint(func, *args, **kwargs)
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------
class SpiderPortalDenseModel(nn.Module):
    """Recurrent-depth trunk: prelude layers, a looped recurrent stack, then coda layers.

    Architecture (config defaults in parentheses):
        config.prelude_layers (2) x SpiderPortalDenseLayer (MLA + dense FFN)
        config.num_hidden_layers (6) x SpiderPortalRecurrentDenseLayer (MLA + Engram at
            config.engram_layers + dense FFN), run up to ``n_loops`` times with
            gradient checkpointing
        config.coda_layers (2) x SpiderPortalDenseLayer
        Inside the loop: LTI injection, loop-index embedding, per-loop LoRA adapter,
        ACT halting.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([
            SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
            for i in range(config.num_hidden_layers)
        ])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
        """Run prelude -> recurrent loop with ACT halting -> coda -> final RMSNorm.

        Args:
            hidden_states: token embeddings, (bsz, seq_len, hidden_size).
            input_embedding: tensor passed to ``self.injection`` on every loop after the
                first; defaults to ``hidden_states`` as given (i.e. before the prelude).
            attention_mask: additive mask forwarded to every attention layer.
            position_ids: optional RoPE positions forwarded to every attention layer.
            past_key_values: per-recurrent-layer caches; only used on the first loop.
            use_cache: ask the recurrent layers to return their KV caches.
            n_loops: recurrent iterations; falsy values fall back to config.max_loop_iters.
            token_ids: input ids for the Engram layers.

        Returns:
            (hidden_states, aux_loss=0.0, recurrent-layer caches from the last executed loop)

        NOTE(review): caching is incomplete. Prelude/coda layers never cache, loops after the
        first run without the past, and only the last loop's caches are returned.
        """
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states
        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        # NOTE(review): injection uses `input_embedding` (pre-prelude), not the prelude output.
        bsz, seq_len, _ = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        halted = torch.zeros(bsz, seq_len, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(bsz, seq_len, device=device, dtype=dtype)
        h_out = torch.zeros_like(hidden_states)
        num_recurrent = len(self.recurrent_layers)
        if past_key_values is None:
            past_key_values = [None] * num_recurrent
        # Defined up front so the return value exists even if the loop runs zero times.
        new_past_key_values = [None] * num_recurrent
        threshold = self.config.act_threshold

        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)
            # Tag the state with the loop index so the shared recurrent weights can tell
            # iterations apart.
            hidden_states = loop_index_embedding(hidden_states, t, self.loop_embed_dim)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, _aux_loss, past_kv = checkpoint(
                    layer, hidden_states,
                    token_ids=token_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # ACT: accumulate a halting-probability-weighted sum of the per-loop states.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never crossed the threshold also receive the final state in full.
        never_halted = (~halted).to(dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states
        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, 0.0, new_past_key_values
class SpiderPortalForConditionalGeneration(nn.Module):
    """Causal LM: token embedding -> SpiderPortalDenseModel -> (optionally tied) LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = SpiderPortalDenseModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Default init: normal(0, initializer_range) weights and zero bias for Linear/Embedding.

        Modules whose constructors set a deliberate custom init are skipped so ``apply`` does
        not overwrite it: the LTI injection ``B`` (std 0.01), the LoRA adapter's ``down``
        (std 0.001) and its zero-initialized per-loop ``scale``.
        """
        model = getattr(self, "model", None)
        if model is not None:
            lora = model.lora_adapter
            if module is model.injection.B or module is lora.down or module is lora.scale:
                return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_causal_mask(attention_mask, dtype, device):
        """Return an additive (bsz, 1, seq, seq) mask that blocks future positions and padded keys.

        attention_mask: (bsz, seq), nonzero/True for real tokens. Allowed entries are 0.0 and
        blocked entries are ``torch.finfo(dtype).min``.
        """
        bsz, seq_len = attention_mask.shape
        keep = attention_mask.to(device=device, dtype=torch.bool)
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        blocked = future.view(1, 1, seq_len, seq_len) | ~keep.view(bsz, 1, 1, seq_len)
        mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=device)
        return mask.masked_fill(blocked, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        """Compute logits and, if ``labels`` is given, the next-token cross-entropy loss.

        Args:
            input_ids: (bsz, seq_len) token ids.
            attention_mask: optional (bsz, seq_len) padding mask (bool or 0/1 ints).
            labels: optional (bsz, seq_len), aligned with ``input_ids``. It is shifted by one
                here, so the logits at position t are scored against labels[t + 1].
                NOTE(review): FineWebEduDataset yields already-shifted (x, y) pairs; passing
                y as labels would shift twice, so verify how the training loop calls this.

        Returns:
            dict with "loss" (None without labels), "logits", "aux_loss", "past_key_values".
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        causal_mask = self._build_causal_mask(attention_mask, hidden_states.dtype, hidden_states.device)
        hidden_states, aux_loss, past_kv = self.model(
            hidden_states, input_embedding=input_embedding,
            attention_mask=causal_mask, position_ids=position_ids,
            use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
        )
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}

    def get_num_params(self):
        total = sum(p.numel() for p in self.parameters())
        return {"total": total, "trainable": total}
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class FineWebEduDataset(IterableDataset):
    """Stream fixed-length (input, target) token pairs for next-token pretraining.

    Sources:
    * ``local_token_file`` (if it exists): a flat little-endian uint32 token binary read through
      ``np.memmap``. Sample i covers tokens ``[i*seq_len, i*seq_len + seq_len]`` and the target
      is the input shifted by one position.
    * otherwise: ``HuggingFaceFW/fineweb-edu`` (config ``subset``) streamed and tokenized on
      the fly, with the tokenizer's EOS id (if any) appended after each document.

    Samples are split disjointly across (rank, DataLoader worker). Each item is a pair of
    int64 tensors of length ``seq_len``.
    """

    def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int, local_token_file=None):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.subset = subset
        self.rank = rank
        self.world_size = world_size
        # Local tokenized data - USE mmapped binary for speed
        if local_token_file and os.path.exists(local_token_file):
            import numpy as np
            self.use_local = True
            self.local_file = local_token_file
            self.mm = np.memmap(local_token_file, dtype='<u4', mode='r')
            self.num_tokens = len(self.mm)
            # Each sample reads seq_len + 1 tokens, so one trailing token must remain.
            self.num_samples = max(self.num_tokens - 1, 0) // seq_len
            log(f"Using pre-tokenized binary: {local_token_file} ({self.num_tokens:,} tokens)")
        else:
            self.use_local = False
            log("WARNING: No pre-tokenized binary found. Using streaming tokenizer (SLOW).")
            log("Run pretokenize_fineweb.py first for 50-100x speedup.")

    def __iter__(self):
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0
        if self.use_local:
            yield from self._iter_local(num_workers, worker_id)
        else:
            yield from self._iter_streaming(num_workers, worker_id)

    def _iter_local(self, num_workers, worker_id):
        """Yield this (rank, worker)'s contiguous range of samples from the memmapped binary."""
        samples_per_worker = self.num_samples // (self.world_size * num_workers)
        start_sample = (self.rank * num_workers + worker_id) * samples_per_worker
        for i in range(start_sample, start_sample + samples_per_worker):
            start_idx = i * self.seq_len
            tokens = self.mm[start_idx:start_idx + self.seq_len + 1]
            yield (
                torch.from_numpy(tokens[:-1].astype('int64')),
                torch.from_numpy(tokens[1:].astype('int64')),
            )

    def _iter_streaming(self, num_workers, worker_id):
        """Yield chunks of this (rank, worker)'s HuggingFace shard, tokenized on the fly."""
        total_shards = self.world_size * num_workers
        shard_index = self.rank * num_workers + worker_id
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.subset,
            split="train",
            streaming=True,
        ).shard(num_shards=total_shards, index=shard_index)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        window = self.seq_len + 1
        buf = []
        for sample in ds:
            buf.extend(self.tokenizer.encode(sample["text"]))
            if eos_id is not None:
                buf.append(eos_id)  # document boundary, so documents do not run together
            start = 0
            while len(buf) - start >= window:
                chunk = buf[start:start + window]
                start += window
                yield (
                    torch.tensor(chunk[:-1], dtype=torch.long),
                    torch.tensor(chunk[1:], dtype=torch.long),
                )
            del buf[:start]  # drop consumed tokens once per document, not once per chunk
# ---------------------------------------------------------------------------
# LR schedule
# ---------------------------------------------------------------------------
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Linear warmup from 0 to ``max_lr`` over ``warmup`` steps, then cosine decay to ``min_lr``.

    The decay reaches ``min_lr`` at ``total``; any step at or beyond ``total`` (after warmup)
    returns ``min_lr``. Step 0 yields 0.0.
    """
    if step >= warmup:
        if step >= total:
            return min_lr
        progress = (step - warmup) / (total - warmup)
        cosine = 1.0 + math.cos(math.pi * progress)
        return min_lr + 0.5 * (max_lr - min_lr) * cosine
    return max_lr * step / warmup
# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------
def save_weights_only(model, step, epoch, ckpt_dir, ddp):
    """Write the model state_dict to ``<ckpt_dir>/spiderportal-v5-dense-ep{epoch}-step{step}.pt``.

    With ``ddp=True`` the model is FSDP-wrapped and this call is collective: every rank must
    enter it so the full state dict can be gathered. With ``rank0_only=True`` only rank 0
    receives the tensors, so only rank 0 writes. The file goes to ``<path>.tmp`` first and is
    then moved into place with ``os.replace``. ``ckpt_dir`` is created if needed.

    Returns:
        ``(path, size_in_MiB)`` on the writing rank, ``(None, 0)`` on every other rank.
    """
    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
        # Non-zero ranks hold an empty dict. Letting them write would race rank 0 on the same
        # tmp/final path and could leave an empty checkpoint behind.
        if dist.get_rank() != 0:
            return None, 0
    else:
        model_state = model.state_dict()
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
    tmp_path = ckpt_path + ".tmp"
    torch.save(model_state, tmp_path)
    os.replace(tmp_path, ckpt_path)
    size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
    return ckpt_path, size_mb
def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
    """Save model + optimizer state and run metadata to ``<ckpt_dir>/spiderportal-v5-dense-<ckpt_name>.pt``.

    With ``ddp=True`` the FSDP full-state gathers are collective, so every rank must call this.
    Only the ``master`` rank writes, via ``<path>.tmp`` followed by ``os.replace``.

    Returns:
        ``(path, size_in_MiB)`` on the master rank, ``(None, 0)`` elsewhere.
    """
    if not ddp:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()
    else:
        full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
            model_state = model.state_dict()
            optim_state = FSDP.optim_state_dict(model, optimizer)
    if not master:
        return None, 0
    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
    payload = {
        "step": step,
        "epoch": epoch,
        "model_state_dict": model_state,
        "optimizer_state_dict": optim_state,
        "cfg": cfg,
        "vocab_size": vocab_size,
    }
    tmp_path = final_path + ".tmp"
    torch.save(payload, tmp_path)
    os.replace(tmp_path, final_path)
    return final_path, os.path.getsize(final_path) / (1024 * 1024)
def load_checkpoint(model, optimizer, path, ddp):
    """Restore model and optimizer state from a checkpoint written by ``save_full_checkpoint``.

    With ``ddp=True`` every rank loads the full state dict (``rank0_only=False``) and the
    optimizer state is converted with ``FSDP.optim_state_dict_to_load``.

    Returns:
        ``(step, epoch)``; epoch defaults to 0 when absent.
    """
    # weights_only=False unpickles arbitrary objects (e.g. the "cfg" entry): only load trusted files.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if not ddp:
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    else:
        full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
            model.load_state_dict(ckpt["model_state_dict"])
            converted = FSDP.optim_state_dict_to_load(
                model=model,
                optim=optimizer,
                optim_state_dict=ckpt["optimizer_state_dict"],
            )
            optimizer.load_state_dict(converted)
    return int(ckpt["step"]), int(ckpt.get("epoch", 0))
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _latest_epoch_checkpoint(ckpt_dir):
    """Return the path of the newest resumable epoch checkpoint in ``ckpt_dir``.

    Only full checkpoints named ``spiderportal-v5-dense-ep{N}.pt`` qualify.
    Weights-only step files (``...-ep{N}-step{S}.pt``) and the ``best`` /
    ``final-*`` files are ignored. Epochs are compared numerically, so ``ep10``
    ranks after ``ep9``; a plain string sort gets this wrong.

    Returns None when the directory is missing or holds no epoch checkpoint.
    """
    prefix, suffix = "spiderportal-v5-dense-ep", ".pt"
    if not os.path.isdir(ckpt_dir):
        return None
    candidates = []
    for fname in os.listdir(ckpt_dir):
        if fname.startswith(prefix) and fname.endswith(suffix):
            tag = fname[len(prefix):-len(suffix)]
            if tag.isdecimal():
                candidates.append((int(tag), fname))
    if not candidates:
        return None
    return os.path.join(ckpt_dir, max(candidates)[1])


def _global_mean(value, device, ddp):
    """Average a Python float across all ranks (collective when ``ddp``).

    Every rank gets the same result, so decisions based on it, such as whether
    to run a collective checkpoint save, stay in lockstep.
    """
    if not ddp:
        return value
    t = torch.tensor(value, dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.AVG)
    return t.item()


def _save_step_weights(model, step, epoch, ckpt_dir, ddp, master):
    """Weights-only step checkpoint that is safe to call on every rank.

    Under FSDP, the full-state-dict gather is a collective, so all ranks must call
    this; only the master rank writes the file. File naming mirrors
    ``save_weights_only``. Without FSDP, this delegates to ``save_weights_only``.

    Returns:
        ``(path, size_mb)`` on the master rank, ``(None, 0)`` elsewhere.
    """
    if not ddp:
        return save_weights_only(model, step, epoch, ckpt_dir, ddp)
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        model_state = model.state_dict()
    if not master:
        return None, 0
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
    tmp_path = ckpt_path + ".tmp"
    torch.save(model_state, tmp_path)
    os.replace(tmp_path, ckpt_path)
    return ckpt_path, os.path.getsize(ckpt_path) / (1024 * 1024)


def main():
    """Train SpiderPortal v5-Dense on FineWeb-Edu, single-process or torchrun/FSDP.

    Sets up optional FSDP, builds the model and two optimizers: AdamW for the
    backbone and Adam for the Engram embedding tables. Resumes from the newest
    epoch checkpoint if one exists, then runs the gradient-accumulation loop with
    periodic weights-only and per-epoch full checkpoints.

    Every checkpoint save is executed on ALL ranks. Under FSDP, gathering a full
    state dict is a collective, so calling it from rank 0 alone hangs the job.
    """
    # ------------------------------------------------------------------
    # Distributed init
    # ------------------------------------------------------------------
    ddp = int(os.environ.get("RANK", -1)) != -1
    if ddp:
        dist.init_process_group("nccl")
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(device)
    else:
        rank = local_rank = 0
        world_size = 1
        device = "cuda" if torch.cuda.is_available() else "cpu"
    master = rank == 0
    if master:
        log(
            f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
        )
    # ------------------------------------------------------------------
    # Tokenizer
    # ------------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size
    if master:
        log(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    seq_len = 2048
    micro_batch = 32  # Increased — 96GB VRAM can handle this
    target_tokens = 20_000_000_000
    grad_accum = 2
    global_batch_tok = world_size * micro_batch * grad_accum * seq_len
    total_steps = target_tokens // global_batch_tok
    warmup_steps = 200
    lr = 3e-4
    wd = 0.1
    log_every = 10
    ckpt_every = 500
    ckpt_dir = "checkpoints-dense"
    dataset_subset = "sample-10BT"
    if master:
        log(
            f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
            f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
        )
        log(
            f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | "
            f"Engram: layers [1,4] | Context: 32k | "
            f"Gradient checkpointing: enabled"
        )
    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    cfg = SpiderPortalConfig(
        hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
        num_key_value_heads=4, intermediate_size=8192,
        num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
        router_aux_loss_coef=0.05, max_loop_iters=4,
        prelude_layers=2, coda_layers=2, lora_rank=128,
        rope_theta=10000000.0,
        rope_scaling=None,
        max_position_embeddings=32768, sliding_window=4096,
        tie_word_embeddings=True,
        kv_lora_rank=128, q_lora_rank=256,
        qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
        engram_layers=[1, 4],
        engram_ngram_orders=(2, 3),
        engram_hash_heads=4,
        engram_table_size=65537,
        engram_dim=128,
    )
    cfg.vocab_size = vocab_size
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # NOTE(review): the float16 fallback runs without a GradScaler; confirm bf16 is
    # always available on the target hardware or add loss scaling.
    amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
    model = SpiderPortalForConditionalGeneration(cfg)
    # Count parameters on the unwrapped module. After FULL_SHARD wrapping, rank 0
    # only holds its local shard, so counting afterwards under-reports.
    if master:
        n_params = sum(p.numel() for p in model.parameters())
        engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
        mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
        embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
        ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
        other_params = n_params - engram_params - mla_params - embed_params - ffn_params
        log(
            f"Parameters: {n_params:,} (all active) | "
            f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
            f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
            f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
        )
    if ddp:
        mp_policy = MixedPrecision(
            param_dtype=amp_dtype,
            reduce_dtype=amp_dtype,
            buffer_dtype=amp_dtype,
        )
        wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mp_policy,
            auto_wrap_policy=wrap_policy,
            device_id=local_rank,
            # Keep original parameter names and objects. This is needed for the
            # name-based Engram/backbone optimizer split below (otherwise only
            # flattened `_flat_param`s are visible) and for torch.compile.
            use_orig_params=True,
        )
        amp_ctx = nullcontext()
    else:
        model = model.to(device)
        amp_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype) if torch.cuda.is_available() else nullcontext()
    # Enable torch.compile for 20-30% speedup.
    # Compilation is lazy: most compile errors surface at the first forward call,
    # not inside this try block.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        if master:
            log("torch.compile: enabled (reduce-overhead)")
    except Exception as e:
        if master:
            log(f"torch.compile failed ({e}), using eager mode")
    # ------------------------------------------------------------------
    # Optimizer — dual optimizer for Engram embeddings
    # ------------------------------------------------------------------
    # Parameters whose name contains both 'engram' and 'embed_tables' get their own
    # Adam (5x LR, Adam's default weight_decay=0); everything else goes to AdamW.
    engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed_tables' in n]
    backbone_params = [p for n, p in model.named_parameters() if 'engram' not in n or 'embed_tables' not in n]
    optimizer = torch.optim.AdamW(
        backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95),
        # Fused kernels expect CUDA tensors (CPU support only exists in newer PyTorch).
        fused=torch.cuda.is_available(),
    )
    if engram_params_list:
        engram_optimizer = torch.optim.Adam(
            engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
        )
    else:
        engram_optimizer = None
    # TODO(review): engram_optimizer state is neither saved nor restored by the
    # checkpoint helpers; on resume its Adam moments restart from zero.
    # ------------------------------------------------------------------
    # Resume from latest checkpoint
    # ------------------------------------------------------------------
    start_step = 0
    start_epoch = 1
    best_loss = float("inf")
    latest = _latest_epoch_checkpoint(ckpt_dir)
    if latest:
        if master:
            log(f"Resuming from checkpoint: {latest}")
        start_step, saved_epoch = load_checkpoint(model, optimizer, latest, ddp)
        # Epoch checkpoints are written when `saved_epoch` finishes, so training
        # continues with the next epoch (tokens_in_epoch restarts at 0 below).
        start_epoch = saved_epoch + 1
        if master:
            log(f"Resumed at step {start_step}, epoch {start_epoch}")
    # ------------------------------------------------------------------
    # Dataset + DataLoader
    # ------------------------------------------------------------------
    # Check for pre-tokenized binary file
    # TODO(review): the data stream restarts from the beginning on resume; already
    # consumed samples are not skipped.
    local_token_file = os.environ.get("TOKEN_FILE", "data/fineweb-edu-sample-10BT.bin")
    dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file)
    num_workers = 16 if dataset.use_local else 4
    prefetch = 8 if dataset.use_local else 2
    loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch)
    if master:
        log(f"DataLoader: num_workers={num_workers}, prefetch={prefetch}, use_local={dataset.use_local}")
    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    if master:
        os.makedirs(ckpt_dir, exist_ok=True)
    model.train()
    data_iter = iter(loader)
    t0 = time.perf_counter()  # reset at every log interval (throughput window)
    epoch_t0 = t0  # reset at every epoch boundary (epoch wall time)
    step = start_step
    epoch = start_epoch
    step_ckpt_files = []
    tokens_in_epoch = 0
    # One "epoch" is the token budget actually trained, rounded down to whole steps.
    # target_tokens is not a multiple of global_batch_tok, so comparing against it
    # directly meant the epoch-end branch (and its checkpoints) never ran.
    tokens_per_epoch = total_steps * global_batch_tok
    while step < total_steps:
        cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
        for g in optimizer.param_groups:
            g["lr"] = cur_lr
        if engram_optimizer:
            for g in engram_optimizer.param_groups:
                g["lr"] = cur_lr * 5
        optimizer.zero_grad()
        if engram_optimizer:
            engram_optimizer.zero_grad()
        loss_accum = 0.0
        for micro_step in range(grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Dataset exhausted — reshuffle and restart
                if master:
                    log(f"Dataset exhausted at step {step}, restarting DataLoader")
                dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file)
                loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch)
                data_iter = iter(loader)
                x, y = next(data_iter)
            # Under ddp, `device` is already f"cuda:{local_rank}".
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            # Skip FSDP gradient reduction on all but the last micro-step.
            sync = (
                nullcontext()
                if (not ddp or micro_step == grad_accum - 1)
                else model.no_sync()
            )
            with sync, amp_ctx:
                output = model(x)
                if isinstance(output, dict):
                    logits = output["logits"]
                else:
                    logits = output
                loss = nn.functional.cross_entropy(
                    logits.view(-1, vocab_size), y.view(-1)
                )
                loss = loss / grad_accum
                loss.backward()
            loss_accum += loss.item()
        if ddp:
            grad_norm = model.clip_grad_norm_(1.0)
        else:
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if engram_optimizer:
            engram_optimizer.step()
        step += 1
        tokens_in_epoch += global_batch_tok
        if master and step % log_every == 0:
            dt = time.perf_counter() - t0
            tok_per_sec = global_batch_tok * log_every / dt
            tokens_seen = step * global_batch_tok
            log(
                f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
                f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
                f"| {tok_per_sec / 1e6:.2f}M tok/s "
                f"| {tokens_seen / 1e9:.2f}B tokens seen"
            )
            t0 = time.perf_counter()
        if step % ckpt_every == 0:
            # All ranks participate: the FSDP state-dict gather is a collective.
            ckpt_path, size_mb = _save_step_weights(model, step, epoch, ckpt_dir, ddp, master)
            if master:
                step_ckpt_files.append(ckpt_path)
                log(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
        if tokens_in_epoch >= tokens_per_epoch:
            # Averaged across ranks so every rank makes the same "best" decision
            # and therefore issues the same collective saves below.
            epoch_loss = _global_mean(loss_accum, device, ddp)
            if master:
                epoch_time = (time.perf_counter() - epoch_t0) / 60
                log(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
                for old_path in step_ckpt_files:
                    if os.path.exists(old_path):
                        os.remove(old_path)
                        log(f"  Deleted step checkpoint: {os.path.basename(old_path)}")
                step_ckpt_files.clear()
            # Called on every rank; only the master gets a path back and logs.
            ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
            if ckpt_path:
                log(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
                if ckpt_path:
                    log(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
            epoch += 1
            tokens_in_epoch = 0
            epoch_t0 = time.perf_counter()
    if step > start_step:
        # Called on every rank (collective under FSDP); only the master writes.
        ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
        if ckpt_path:
            log(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
    if ddp:
        dist.barrier()
        dist.destroy_process_group()
    if master:
        log("Training complete.")
if __name__ == "__main__":
    # Entry point when run directly (python / torchrun); importing the module does not start training.
    main()