| | |
| | """ |
| | model.py — Role SLM Transformer (~1B params) with RoPE + Gradient Checkpointing |
| | ================================================================================ |
| | Supports context lengths up to 5M tokens via: |
| | * RoPE (no fixed position embedding table) |
| | * RMSNorm (more efficient than LayerNorm) |
| | * SwiGLU activation (better training dynamics) |
| | * Flash Attention via PyTorch scaled_dot_product_attention |
| | * Gradient checkpointing for memory-efficient training on 24GB |
| | """ |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.checkpoint import checkpoint as grad_checkpoint |
| | from typing import Optional, Tuple |
| | from config import cfg |
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistics are computed in float32; the normalised tensor is cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer.

        Args:
            dim: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root, for numerical stability.
        """
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the reciprocal RMS of its last dimension."""
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = (x_fp32 * inv_rms).type_as(x)
        return normalised * self.weight
| |
|
| |
|
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0, device=None):
    """Build the cosine and sine lookup tables used for rotary embeddings.

    Args:
        dim: Head dimension; one frequency is produced per even index in
            ``range(0, dim, 2)``.
        max_seq_len: Number of positions (rows) to tabulate.
        theta: Base of the geometric frequency progression.
        device: Device on which the tables are created.

    Returns:
        A ``(cos, sin)`` tuple of float32 tensors, each with one row per
        position and one column per frequency.
    """
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, device=device).float()
    # angles[p, i] = position p * inverse frequency i
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
| |
|
| |
|
def apply_rope(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by position-dependent angles.

    Args:
        x: Tensor indexed as ``(..., seq_len, head_dim)`` on its last two of
            four dimensions (sequence length is read from ``x.shape[2]`` and
            head dimension from ``x.shape[3]``).
        cos: Cosine table with at least ``seq_len`` rows and ``head_dim // 2``
            columns.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    seq_len, head_dim = x.shape[2], x.shape[3]
    half = head_dim // 2
    # Trim the tables to the sequence length and add two leading broadcast axes.
    cos = cos[:seq_len].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    x1 = x[..., :half]
    x2 = x[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # The tables may be float32 while x is half precision; type promotion would
    # otherwise return float32 queries/keys that no longer match the values'
    # dtype. Cast back so the caller always gets x's dtype.
    return rotated.type_as(x)
| |
|
| |
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Sizes and the dropout rate are read from the module-level ``cfg``
    (``n_embd``, ``n_head``, ``dropout``). Uses
    ``F.scaled_dot_product_attention`` when the installed PyTorch provides it
    and an explicit masked-softmax implementation otherwise.
    """

    def __init__(self):
        super().__init__()
        assert cfg.n_embd % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.head_dim = cfg.n_embd // cfg.n_head
        self.q_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.k_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.v_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.out_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.resid_drop = nn.Dropout(cfg.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Attend over ``x`` with a causal mask.

        Args:
            x: Input of shape ``(B, T, C)``.
            rope_cos: Cosine table passed through to ``apply_rope``.
            rope_sin: Sine table passed through to ``apply_rope``.

        Returns:
            Tensor of shape ``(B, T, C)`` after the output projection and
            residual dropout.
        """
        B, T, C = x.shape
        # Project, then reshape to (B, n_head, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        q = apply_rope(q, rope_cos, rope_sin)
        k = apply_rope(k, rope_cos, rope_sin)

        # Attention dropout is only active in training mode.
        attn_dropout_p = cfg.dropout if self.training else 0.0
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(
                q, k, v, dropout_p=attn_dropout_p, is_causal=True)
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            att = (q @ k.transpose(-2, -1)) * scale
            # True above the diagonal marks future positions to be hidden.
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att = att.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            att = F.softmax(att, dim=-1)
            # Apply the same attention dropout as the fused path so that both
            # branches train identically.
            att = F.dropout(att, p=attn_dropout_p, training=self.training)
            y = att @ v
        # Merge heads back into (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.out_proj(y))
| |
|
| |
|
class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` followed by dropout.

    The hidden width is ``cfg.n_embd`` times ``cfg.ffn_multiplier`` (2.667 when
    the attribute is absent), truncated to an int and then rounded up to a
    multiple of 64.
    """

    def __init__(self):
        super().__init__()
        multiplier = getattr(cfg, 'ffn_multiplier', 2.667)
        raw_width = int(cfg.n_embd * multiplier)
        # Ceiling-divide by 64, then scale back: rounds up to a multiple of 64.
        width = -(-raw_width // 64) * 64
        self.gate_proj = nn.Linear(cfg.n_embd, width, bias=False)
        self.up_proj = nn.Linear(cfg.n_embd, width, bias=False)
        self.down_proj = nn.Linear(width, cfg.n_embd, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Apply the gated projection to ``x`` and return a tensor of the same width."""
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        projected = self.down_proj(gated)
        return self.dropout(projected)
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual add."""

    def __init__(self):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention()
        self.ffn_norm = RMSNorm(cfg.n_embd)
        self.ffn = SwiGLUFFN()

    def forward(self, x, rope_cos, rope_sin):
        """Run one layer.

        Args:
            x: Hidden states entering the layer.
            rope_cos: Cosine table forwarded to the attention module.
            rope_sin: Sine table forwarded to the attention module.

        Returns:
            Hidden states after both residual sub-layers.
        """
        attn_out = self.attn(self.attn_norm(x), rope_cos, rope_sin)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
| |
|
| |
|
class RoleSLM(nn.Module):
    """Role-Based Small Language Model — ~1B params, LLaMA-style with gradient checkpointing.

    All hyper-parameters are read from the module-level ``cfg``. The token
    embedding and the output projection share one weight tensor, and the rotary
    cos/sin tables are kept as non-persistent buffers that grow on demand.
    """

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the embedding table and the LM head share one parameter.
        self.tok_emb.weight = self.lm_head.weight

        self.use_checkpointing = getattr(cfg, 'gradient_checkpointing', True)

        head_dim = cfg.n_embd // cfg.n_head
        max_pos = getattr(cfg, 'max_position_embeddings', 1_000_000)
        rope_theta = getattr(cfg, 'rope_theta', 10000.0)
        # Only tabulate a modest prefix up front; _extend_rope grows it later.
        precompute_len = min(max_pos, cfg.block_size * 2)
        cos, sin = precompute_rope_freqs(head_dim, precompute_len, theta=rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._rope_max_len = precompute_len
        self._rope_theta = rope_theta
        self._head_dim = head_dim
        self.apply(self._init_weights)

        n_params = self.count_parameters()
        print(f"{cfg.domain_name}-SLM initialized: {n_params/1e6:.2f}M parameters ({n_params/1e9:.3f}B)")
        print(f" Architecture: {cfg.n_layer}L / {cfg.n_head}H / {cfg.n_embd}D")
        print(f" Gradient checkpointing: {self.use_checkpointing}")
        print(f" Max context: {max_pos:,} tokens (via RoPE)")
        print(f" Estimated model size: {n_params * 4 / 1e9:.2f} GB (fp32)")

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero any bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _extend_rope(self, seq_len, device):
        """Grow the rotary tables so they cover at least ``seq_len`` positions.

        The new length is at least double the current one, so repeated calls
        with slowly growing sequences do not rebuild the tables every step.
        """
        if seq_len <= self._rope_max_len:
            return
        new_len = max(seq_len, self._rope_max_len * 2)
        cos, sin = precompute_rope_freqs(self._head_dim, new_len,
                                         theta=self._rope_theta, device=device)
        # Keep the dtype of the existing buffers so a model that was cast
        # (e.g. to half precision) stays consistent after the tables are rebuilt.
        table_dtype = self.rope_cos.dtype
        self.rope_cos = cos.to(table_dtype)
        self.rope_sin = sin.to(table_dtype)
        self._rope_max_len = new_len

    def _block_forward(self, block, x, rope_cos, rope_sin):
        """Wrapper for gradient checkpointing."""
        return block(x, rope_cos, rope_sin)

    def forward(self, idx, targets=None):
        """Compute logits and, optionally, the cross-entropy loss.

        Args:
            idx: Token ids of shape ``(B, T)``.
            targets: Optional target ids; flattened and compared against the
                flattened logits, with entries equal to ``-1`` ignored.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` when ``targets`` is
            not given.
        """
        B, T = idx.shape
        device = idx.device
        self._extend_rope(T, device)
        x = self.drop(self.tok_emb(idx))
        rope_cos = self.rope_cos[:T].to(device)
        rope_sin = self.rope_sin[:T].to(device)
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Recompute block activations in backward instead of storing them.
                x = grad_checkpoint(self._block_forward, block, x, rope_cos, rope_sin,
                                    use_reentrant=False)
            else:
                x = block(x, rope_cos, rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, top_p=0.9,
                 eos_token_id=3):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: Prompt token ids of shape ``(B, T)``.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature; ``0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` highest logits (``0`` or ``None``
                disables the filter).
            top_p: Nucleus threshold (``1.0`` or above, or ``None``, disables
                the filter).
            eos_token_id: Generation stops once every sequence in the batch
                has just produced this id; ``None`` disables early stopping.

        Returns:
            The prompt with the generated tokens appended along dimension 1.
        """
        # Sample in eval mode so dropout is off; forward() also skips gradient
        # checkpointing outside training. The previous mode is restored even
        # if generation raises.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the last block_size tokens.
                idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature == 0:
                    idx_next = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')
                    if top_p is not None and top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold is kept.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)
                # .all() instead of .item() so batches larger than one do not raise.
                if eos_token_id is not None and bool((idx_next == eos_token_id).all()):
                    break
        finally:
            self.train(was_training)
        return idx

    def count_parameters(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the model and run one forward pass on random token
    # ids, using the same ids as both input and target.
    model = RoleSLM()
    sample_ids = torch.randint(0, cfg.vocab_size, (1, 32))
    logits, loss = model(sample_ids, targets=sample_ids)
    print(f"Test forward: logits={logits.shape}, loss={loss.item():.4f}")
| |
|