#!/usr/bin/env python3
"""
model.py — Role SLM Transformer (~1B params) with RoPE + Gradient Checkpointing
================================================================================
Supports context lengths up to 1M tokens via:
  * RoPE (no fixed position embedding table)
  * RMSNorm (more efficient than LayerNorm)
  * SwiGLU activation (better training dynamics)
  * Flash Attention via PyTorch scaled_dot_product_attention
  * Gradient checkpointing for memory-efficient training on 24GB
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
from typing import Optional, Tuple

from config import cfg


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the RMS in fp32 for stability, then cast back to the input dtype.
        norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * norm).type_as(x) * self.weight


def precompute_rope_freqs(dim, max_seq_len, theta=10000.0, device=None):
    """Precompute the RoPE cosine/sine tables.

    Args:
        dim: Per-head dimension; one frequency is produced per pair of channels.
        max_seq_len: Number of positions to precompute.
        theta: RoPE base frequency.
        device: Device on which to build the tables.

    Returns:
        ``(cos, sin)``, each float32 of shape ``(max_seq_len, dim // 2)`` for even ``dim``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(max_seq_len, device=device).float()
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def apply_rope(x, cos, sin):
    """Rotate ``x`` by the RoPE angles using the split-half ("rotate half") layout.

    ``x`` is ``(batch, n_head, seq_len, head_dim)`` as built in
    ``CausalSelfAttention.forward``; ``cos``/``sin`` come from
    ``precompute_rope_freqs`` and must have at least ``seq_len`` rows.
    """
    seq_len = x.shape[2]
    half = x.shape[3] // 2
    # Cast the tables to x's dtype: otherwise fp32 tables (e.g. after
    # RoleSLM._extend_rope rebuilds them) would promote bf16/fp16 q/k to fp32,
    # leaving q/k/v with mismatched dtypes in the attention call.
    cos = cos[:seq_len].to(dtype=x.dtype).unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len].to(dtype=x.dtype).unsqueeze(0).unsqueeze(0)
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE applied to queries and keys."""

    def __init__(self):
        super().__init__()
        if cfg.n_embd % cfg.n_head != 0:
            raise ValueError(
                f"n_embd ({cfg.n_embd}) must be divisible by n_head ({cfg.n_head})"
            )
        self.n_head = cfg.n_head
        self.head_dim = cfg.n_embd // cfg.n_head
        self.q_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.k_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.v_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.out_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.resid_drop = nn.Dropout(cfg.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Attend over ``x`` of shape ``(B, T, n_embd)``; returns the same shape."""
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        q = apply_rope(q, rope_cos, rope_sin)
        k = apply_rope(k, rope_cos, rope_sin)
        dropout_p = cfg.dropout if self.training else 0.0
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
        else:
            # Manual fallback for PyTorch versions without SDPA.
            scale = 1.0 / math.sqrt(self.head_dim)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att = att.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            att = F.softmax(att, dim=-1)
            # Apply the same attention dropout the SDPA path uses (no-op when p == 0).
            att = F.dropout(att, p=dropout_p)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.out_proj(y))


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``, hidden size rounded up to a multiple of 64."""

    def __init__(self):
        super().__init__()
        hidden_dim = int(cfg.n_embd * getattr(cfg, 'ffn_multiplier', 2.667))
        hidden_dim = ((hidden_dim + 63) // 64) * 64
        self.gate_proj = nn.Linear(cfg.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(cfg.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, cfg.n_embd, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and FFN sub-layers, each with a residual connection."""

    def __init__(self):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention()
        self.ffn_norm = RMSNorm(cfg.n_embd)
        self.ffn = SwiGLUFFN()

    def forward(self, x, rope_cos, rope_sin):
        x = x + self.attn(self.attn_norm(x), rope_cos, rope_sin)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class RoleSLM(nn.Module):
    """Role-Based Small Language Model — ~1B params, LLaMA-style with gradient checkpointing."""

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.lm_head.weight  # Weight tying
        self.use_checkpointing = getattr(cfg, 'gradient_checkpointing', True)

        head_dim = cfg.n_embd // cfg.n_head
        max_pos = getattr(cfg, 'max_position_embeddings', 1_000_000)
        rope_theta = getattr(cfg, 'rope_theta', 10000.0)
        # Only cache 2x the training block size up front; _extend_rope grows it on demand.
        precompute_len = min(max_pos, cfg.block_size * 2)
        cos, sin = precompute_rope_freqs(head_dim, precompute_len, theta=rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._rope_max_len = precompute_len
        self._rope_theta = rope_theta
        self._head_dim = head_dim

        self.apply(self._init_weights)
        n_params = sum(p.numel() for p in self.parameters())
        print(f"{cfg.domain_name}-SLM initialized: {n_params/1e6:.2f}M parameters ({n_params/1e9:.3f}B)")
        print(f"  Architecture: {cfg.n_layer}L / {cfg.n_head}H / {cfg.n_embd}D")
        print(f"  Gradient checkpointing: {self.use_checkpointing}")
        print(f"  Max context: {max_pos:,} tokens (via RoPE)")
        print(f"  Estimated model size: {n_params * 4 / 1e9:.2f} GB (fp32)")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _extend_rope(self, seq_len, device):
        """Rebuild the RoPE tables on ``device`` if ``seq_len`` exceeds the cached length.

        The cache at least doubles each time, so growth is amortised.
        """
        if seq_len > self._rope_max_len:
            new_len = max(seq_len, self._rope_max_len * 2)
            cos, sin = precompute_rope_freqs(self._head_dim, new_len, theta=self._rope_theta, device=device)
            self.rope_cos = cos
            self.rope_sin = sin
            self._rope_max_len = new_len

    def _block_forward(self, block, x, rope_cos, rope_sin):
        """Wrapper for gradient checkpointing."""
        return block(x, rope_cos, rope_sin)

    def forward(
        self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute logits for ``idx`` of shape ``(B, T)``.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``(B, T, vocab_size)`` and ``loss``
            is the cross-entropy against ``targets`` (positions equal to -1 are
            ignored), or ``None`` when ``targets`` is not given.
        """
        B, T = idx.shape
        device = idx.device
        self._extend_rope(T, device)
        x = self.drop(self.tok_emb(idx))
        rope_cos = self.rope_cos[:T].to(device)
        rope_sin = self.rope_sin[:T].to(device)
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                x = grad_checkpoint(self._block_forward, block, x, rope_cos, rope_sin, use_reentrant=False)
            else:
                x = block(x, rope_cos, rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        eos_token_id: Optional[int] = 3,
    ) -> torch.Tensor:
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Args:
            idx: ``(batch, seq)`` tensor of prompt token ids.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature; ``0`` selects greedy argmax decoding
                (top-k / top-p filtering is then skipped).
            top_k: Keep only the ``top_k`` highest logits (``<= 0`` disables).
            top_p: Nucleus-sampling probability mass (``>= 1.0`` disables).
            eos_token_id: Stop once every sequence in the batch has produced this
                id; ``None`` disables early stopping. The default ``3`` is the id
                previously hard-coded here — presumably the tokenizer's
                end-of-sequence id (confirm against the tokenizer).

        Returns:
            ``idx`` with the sampled tokens appended along dim 1.

        Note:
            Dropout follows the module's current train/eval mode; call ``eval()``
            first to sample without dropout. Only the last ``cfg.block_size``
            tokens are fed to the model at each step.
        """
        # Turn checkpointing off while sampling and restore the caller's own
        # setting afterwards, even if sampling raises.
        prev_checkpointing = self.use_checkpointing
        self.use_checkpointing = False
        # Per-row EOS tracking, so batched generation works (.item() only
        # works for a single sequence).
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature == 0:
                    idx_next = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token that crosses top_p is kept.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)
                if eos_token_id is not None:
                    finished |= idx_next.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break
        finally:
            self.use_checkpointing = prev_checkpointing
        return idx

    def count_parameters(self):
        """Return the number of (deduplicated) parameters, so the tied embedding counts once."""
        return sum(p.numel() for p in self.parameters())


if __name__ == "__main__":
    model = RoleSLM()
    x = torch.randint(0, cfg.vocab_size, (1, 32))
    logits, loss = model(x, x)
    print(f"Test forward: logits={logits.shape}, loss={loss.item():.4f}")