# Source: gpuburnout-models / models/s2_model.py
# ("feat: add app code, configs, and tokenizers", commit 36bc78f)
"""
Llama-Style Transformer Model
=============================
Modern transformer architecture with all Tier 1 and Tier 2 optimizations:
Architecture (Tier 1):
- RMSNorm (faster than LayerNorm, no mean calculation)
- RoPE (Rotary Position Embedding, better length generalization)
- SwiGLU activation (gated FFN, consistently outperforms GELU)
- Pre-norm (apply norm before attention/FFN, more stable training)
Optimizations (Tier 2):
- GQA (Grouped Query Attention, fewer KV heads = faster + less memory)
- Weight tying (share embedding and output projection)
- Flash Attention via F.scaled_dot_product_attention
- Gradient checkpointing support (trade compute for memory)
Compatible with:
- liger-kernel (fused RMSNorm, SwiGLU, RoPE, cross-entropy)
- bf16/fp16 mixed precision training
- torch.compile for additional speedups
Model Sizes:
- tiny: ~15M params (for testing)
- small: ~125M params
- medium: ~350M params
- large: ~760M params
- 1B: ~1.1B params (Llama 3.2 1B style)
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================================
# Model Configuration
# ============================================================================
@dataclass
class ModelConfig:
    """Configuration for Llama-style transformer model.

    Derived attributes computed in ``__post_init__`` (not init fields):
        n_kv_groups: Query heads served by each KV head (n_heads // n_kv_heads).
        head_dim:    Per-head dimension (d_model // n_heads).
    """
    # Model architecture
    vocab_size: int = 32000
    d_model: int = 2048             # Hidden dimension
    n_layers: int = 16              # Number of transformer blocks
    n_heads: int = 32               # Number of attention heads
    n_kv_heads: int = 8             # Number of KV heads (for GQA)
    d_ff: Optional[int] = None      # FFN intermediate dim (default: 8/3 * d_model)
    # Sequence
    max_seq_len: int = 2048         # Maximum sequence length
    # RoPE
    rope_theta: float = 500000.0    # RoPE base frequency
    # Regularization
    dropout: float = 0.0            # Dropout (0 for pretraining)
    # Options
    tie_weights: bool = True        # Tie embedding and output weights
    use_flash_attn: bool = True     # Use Flash Attention (SDPA)

    def __post_init__(self):
        # SwiGLU uses 8/3 * d_model for FFN, rounded up to a multiple of 256
        # (keeps GEMM shapes hardware-friendly).
        if self.d_ff is None:
            self.d_ff = int(8 / 3 * self.d_model)
            self.d_ff = ((self.d_ff + 255) // 256) * 256
        # Validate head configuration: d_model must split evenly across heads
        # (otherwise head_dim would silently truncate), and every KV head must
        # serve a whole number of query heads.
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.n_heads % self.n_kv_heads == 0, \
            f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        # Derived values
        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.head_dim = self.d_model // self.n_heads
# Predefined model configurations
# Size presets; parameter counts below are the estimates from the module
# docstring. Fields not listed here keep the ModelConfig defaults
# (vocab_size=32000, rope_theta=500000.0, tie_weights=True, ...).
MODEL_CONFIGS = {
    # ~15M params — for quick tests; shorter context than the other presets.
    "tiny": ModelConfig(
        d_model=256,
        n_layers=6,
        n_heads=8,
        n_kv_heads=4,
        max_seq_len=1024,
    ),
    # ~125M params
    "small": ModelConfig(
        d_model=768,
        n_layers=12,
        n_heads=12,
        n_kv_heads=4,
        max_seq_len=2048,
    ),
    # ~350M params
    "medium": ModelConfig(
        d_model=1024,
        n_layers=16,
        n_heads=16,
        n_kv_heads=4,
        max_seq_len=2048,
    ),
    # ~760M params
    "large": ModelConfig(
        d_model=1536,
        n_layers=20,
        n_heads=24,
        n_kv_heads=8,
        max_seq_len=2048,
    ),
    # ~1.1B params (Llama 3.2 1B style)
    "1B": ModelConfig(
        d_model=2048,
        n_layers=16,
        n_heads=32,
        n_kv_heads=8,
        d_ff=8192,  # Llama 3.2 1B uses 4x hidden, not 8/3x
        max_seq_len=2048,
    ),
}
def get_model_config(size: str, **overrides) -> ModelConfig:
    """Get a predefined model configuration with optional overrides.

    Returns a NEW ModelConfig instance. The previous implementation wrote the
    overrides directly into the shared instance stored in MODEL_CONFIGS, so a
    second call for the same size silently inherited the first call's
    overrides; using ``dataclasses.replace`` fixes that.

    Args:
        size: One of the MODEL_CONFIGS keys ("tiny", "small", "medium",
            "large", "1B").
        **overrides: ModelConfig field names to replace (e.g. vocab_size).

    Returns:
        A fresh ModelConfig with overrides applied and derived values
        (n_kv_groups, head_dim) recomputed via __post_init__.

    Raises:
        ValueError: If size is unknown or an override is not a config field.
    """
    import dataclasses

    if size not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model size: {size}. Choose from: {list(MODEL_CONFIGS.keys())}")
    base = MODEL_CONFIGS[size]
    # Validate override names against the dataclass fields before replace(),
    # so callers keep getting ValueError (replace() would raise TypeError).
    valid_fields = {f.name for f in dataclasses.fields(base)}
    for key in overrides:
        if key not in valid_fields:
            raise ValueError(f"Unknown config parameter: {key}")
    # replace() builds a new instance and re-runs __post_init__, which
    # recomputes the derived values. The shared template is left untouched.
    return dataclasses.replace(base, **overrides)
# ============================================================================
# RMSNorm (Tier 1)
# ============================================================================
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (Llama-style).

    Normalizes by the RMS of the last dimension only — no mean subtraction
    and no bias — which makes it cheaper than LayerNorm. The normalization
    is computed in fp32 for numerical stability and cast back to the input
    dtype before the learned per-channel scale is applied.

    Can be swapped for liger_kernel.transformers.LigerRMSNorm for a fused
    variant.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / RMS(x): divide each vector by sqrt(mean(x^2) + eps).
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32, restore the input dtype, then scale.
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
# ============================================================================
# Rotary Position Embedding (RoPE) (Tier 1)
# ============================================================================
def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the RoPE cosine/sine lookup tables.

    Args:
        dim: Head dimension (d_model // n_heads); must be even.
        max_seq_len: Number of positions to precompute.
        theta: Base frequency (Llama 3 uses 500000).
        device: Target device for the tables.

    Returns:
        (cos, sin) tensors, each of shape (max_seq_len, dim). Each frequency
        column appears twice so the tables line up with the interleaved
        (even, odd) channel pairs used by apply_rotary_emb.
    """
    # Inverse frequency per channel pair: 1 / theta^(2i/dim).
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Angle table: positions (rows) times inverse frequencies (columns),
    # shape (max_seq_len, dim/2).
    positions = torch.arange(max_seq_len, device=device)
    angles = torch.outer(positions, inv_freq)
    # Duplicate every column: (seq_len, dim/2) -> (seq_len, dim).
    cos = torch.cos(angles).repeat_interleave(2, dim=-1)
    sin = torch.sin(angles).repeat_interleave(2, dim=-1)
    return cos, sin
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate each (even, odd) channel pair of x by its positional angle.

    Args:
        x: Tensor of shape (batch, n_heads, seq_len, head_dim).
        cos: Cosine table of shape (>= seq_len, head_dim).
        sin: Sine table of shape (>= seq_len, head_dim).

    Returns:
        Tensor of the same shape as x with rotary embedding applied.
    """
    seq_len = x.size(2)
    # Trim the tables to this sequence length and add broadcast axes:
    # (seq_len, head_dim) -> (1, 1, seq_len, head_dim).
    cos_t = cos[:seq_len].unsqueeze(0).unsqueeze(0)
    sin_t = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    # 90-degree-rotated companion of each pair:
    # [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]
    odd = x[..., 1::2]
    even = x[..., 0::2]
    rotated = torch.stack((-odd, even), dim=-1).reshape(x.shape)
    # Complex rotation written out in real arithmetic.
    return x * cos_t + rotated * sin_t
# ============================================================================
# Grouped Query Attention (GQA) with Flash Attention (Tier 1 + Tier 2)
# ============================================================================
class Attention(nn.Module):
    """
    Multi-head attention with Grouped Query Attention (GQA) and Flash Attention.

    GQA uses fewer key-value heads than query heads, reducing memory and
    compute while maintaining quality. For example, with 32 query heads and
    8 KV heads, each KV head is shared by 4 query heads.

    Flash Attention is used via PyTorch's scaled_dot_product_attention,
    which provides O(N) memory complexity instead of O(N^2).
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.head_dim = config.head_dim
        # Query projection: full heads
        self.wq = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        # Key and Value projections: fewer heads for GQA
        self.wk = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        # Output projection
        self.wo = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.use_flash_attn = config.use_flash_attn
    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input of shape (batch, seq_len, d_model).
            cos: RoPE cosine table, (>= seq_len, head_dim).
            sin: RoPE sine table, (>= seq_len, head_dim).
            mask: Optional attention mask. When provided, the SDPA path
                forwards it as ``attn_mask`` and disables its built-in causal
                masking, so the mask must itself encode causality; the manual
                path ADDS it to the attention logits, so a float additive
                mask is assumed there. When None, both paths apply a standard
                causal mask.

        Returns:
            Output of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape
        # Project to Q, K, V
        q = self.wq(x)  # (B, T, n_heads * head_dim)
        k = self.wk(x)  # (B, T, n_kv_heads * head_dim)
        v = self.wv(x)  # (B, T, n_kv_heads * head_dim)
        # Reshape to (B, n_heads, T, head_dim)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE to Q and K (values carry no positional signal)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Expand KV heads for GQA: (B, n_kv_heads, T, head_dim) -> (B, n_heads, T, head_dim)
        # so that each KV head is repeated for its group of query heads.
        if self.n_kv_groups > 1:
            k = k.repeat_interleave(self.n_kv_groups, dim=1)
            v = v.repeat_interleave(self.n_kv_groups, dim=1)
        # Attention
        if self.use_flash_attn:
            # Use PyTorch's optimized SDPA (Flash Attention when available).
            # Attention dropout is only active in training mode.
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is None,  # Use causal mask if no explicit mask
            )
        else:
            # Manual attention (for debugging or when SDPA unavailable)
            scale = 1.0 / math.sqrt(self.head_dim)
            attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
            if mask is not None:
                # Additive mask: masked positions are expected to hold -inf
                attn_weights = attn_weights + mask
            else:
                # Causal mask: -inf strictly above the diagonal
                causal_mask = torch.triu(
                    torch.full((seq_len, seq_len), float('-inf'), device=x.device),
                    diagonal=1
                )
                attn_weights = attn_weights + causal_mask
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)
        # Reshape back: (B, n_heads, T, head_dim) -> (B, T, d_model)
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(attn_out)
# ============================================================================
# SwiGLU Feed-Forward Network (Tier 1)
# ============================================================================
class FeedForward(nn.Module):
    """SwiGLU feed-forward network.

    A gated linear unit with SiLU activation replacing the standard GELU FFN:

        SwiGLU(x) = W_down( SiLU(W_gate x) * (W_up x) )

    Uses three weight matrices (gate, up, down) instead of two; consistently
    outperforms GELU at the same compute budget. Can be replaced with
    liger_kernel.transformers.LigerSwiGLUMLP for kernel fusion.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        d_in = config.d_model
        d_hidden = config.d_ff
        # Gate and up projections (could be fused into one matmul)
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        # Down projection back to the model dimension
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate with SiLU, modulate the up-projection, project back down.
        gated = F.silu(self.w_gate(x)) * self.w_up(x)
        return self.dropout(self.w_down(gated))
# ============================================================================
# Transformer Block (Pre-norm)
# ============================================================================
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    Normalization is applied BEFORE the attention / FFN sub-layers (not
    after), which gives more stable gradients at scale:

        x = x + Attention(RMSNorm(x))
        x = x + FFN(RMSNorm(x))
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # One pre-norm per sub-layer
        self.attn_norm = RMSNorm(config.d_model)
        self.ffn_norm = RMSNorm(config.d_model)
        # Sub-layers
        self.attn = Attention(config)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, add residual.
        attn_out = self.attn(self.attn_norm(x), cos, sin, mask)
        x = x + attn_out
        # FFN sub-layer: normalize, transform, add residual.
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
# ============================================================================
# Complete Llama Model
# ============================================================================
class LlamaModel(nn.Module):
    """
    Complete Llama-style transformer model for language modeling.

    Features:
    - RMSNorm, RoPE, SwiGLU, GQA (Tier 1)
    - Weight tying, Flash Attention (Tier 2)
    - Gradient checkpointing support
    - Compatible with liger-kernel fused ops

    Usage:
        config = get_model_config("1B", vocab_size=32000)
        model = LlamaModel(config)
        # Enable gradient checkpointing for memory savings
        model.gradient_checkpointing_enable()
        # Forward pass
        logits = model(input_ids)
        loss = model(input_ids, targets=targets)
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        # Token embedding
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.n_layers)
        ])
        # Final normalization
        self.norm = RMSNorm(config.d_model)
        # Output projection (language model head)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: share embedding and output weights
        if config.tie_weights:
            self.lm_head.weight = self.tok_emb.weight
        # RoPE tables live in non-persistent buffers so they follow the model
        # across .to(device) calls but stay out of the state_dict. They are
        # filled lazily on the first forward pass (see _init_rope).
        self.register_buffer(
            "rope_cos",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False
        )
        self.register_buffer(
            "rope_sin",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False
        )
        # Tracks whether the RoPE buffers hold real frequencies yet.
        # Replaces the previous `rope_cos.sum() == 0` probe in forward(),
        # which forced a host-device synchronization on every step.
        self._rope_initialized = False
        # Gradient checkpointing flag
        self._gradient_checkpointing = False
        # Initialize weights
        self.apply(self._init_weights)
        # Apply special initialization for output projection
        self._init_output_weights()

    def _init_weights(self, module: nn.Module):
        """Initialize weights using Llama-style N(0, 0.02) initialization."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _init_output_weights(self):
        """Apply scaled initialization to output projections for stability."""
        # Scale down the residual-stream projections by 1/sqrt(2*n_layers)
        # so residual variance stays bounded as depth grows.
        scale = (2 * self.config.n_layers) ** -0.5
        for layer in self.layers:
            torch.nn.init.normal_(layer.attn.wo.weight, mean=0.0, std=0.02 * scale)
            torch.nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)

    def _init_rope(self, device: torch.device):
        """Fill the RoPE buffers with real frequencies on the given device."""
        cos, sin = precompute_rope_freqs(
            dim=self.config.head_dim,
            max_seq_len=self.config.max_seq_len,
            theta=self.config.rope_theta,
            device=device,
        )
        # Assigning to a registered buffer name replaces the buffer tensor.
        self.rope_cos = cos
        self.rope_sin = sin
        self._rope_initialized = True

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory-efficient training."""
        self._gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self._gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len)
            targets: Optional target IDs for loss computation
            mask: Optional attention mask (passed through to every layer)

        Returns:
            If targets provided: scalar loss
            Otherwise: logits of shape (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Build the RoPE tables lazily: on the very first forward pass, or
        # again if the model has since moved to a different device. The flag
        # check is free; the old `rope_cos.sum() == 0` check synced the
        # device on every step.
        if not self._rope_initialized or self.rope_cos.device != device:
            self._init_rope(device)
        # Token embeddings
        x = self.tok_emb(input_ids)
        # Slice the RoPE tables to this sequence length
        cos = self.rope_cos[:seq_len]
        sin = self.rope_sin[:seq_len]
        # Transformer blocks (optionally recomputing activations in backward)
        for layer in self.layers:
            if self._gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, cos, sin, mask,
                    use_reentrant=False
                )
            else:
                x = layer(x, cos, sin, mask)
        # Final norm
        x = self.norm(x)
        # Compute logits
        logits = self.lm_head(x)
        # Compute loss if targets provided
        if targets is not None:
            # NOTE: No shift here — the DataLoader already provides
            # pre-shifted targets (x = tokens[:-1], y = tokens[1:]),
            # so logits[k] should predict targets[k] directly.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                ignore_index=-100,  # Ignore padding
            )
            return loss
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Generate tokens autoregressively.

        NOTE: switches the module to eval mode and leaves it there; callers
        that resume training must call .train() themselves.

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (1.0 = neutral)
            top_k: If set, only sample from top k tokens
            top_p: If set, use nucleus sampling with this probability mass

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last max_seq_len tokens if needed
            idx_cond = input_ids if input_ids.size(1) <= self.config.max_seq_len else \
                input_ids[:, -self.config.max_seq_len:]
            # Full forward pass each step (no KV cache)
            logits = self(idx_cond)
            # Logits for the last position, temperature-scaled
            logits = logits[:, -1, :] / temperature
            # Top-k filtering: mask everything below the k-th largest logit
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            # Top-p (nucleus) filtering: drop the tail beyond mass top_p
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Remove tokens with cumulative probability above threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold stays
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                # Map the removal flags back to vocabulary order
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')
            # Sample from the filtered distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Append the new token and continue
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (tied weights are counted once)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def estimate_flops(self, seq_len: int, batch_size: int = 1) -> int:
        """
        Estimate FLOPs for a forward pass.

        Uses the approximation: FLOPs ≈ 2 * params * tokens
        (multiply-add counts as 2 ops)
        """
        params = self.count_parameters(trainable_only=False)
        tokens = batch_size * seq_len
        return 2 * params * tokens
# ============================================================================
# Utility Functions
# ============================================================================
def create_model(
    size: str = "1B",
    vocab_size: int = 32000,
    max_seq_len: int = 2048,
    **kwargs
) -> LlamaModel:
    """Build an initialized LlamaModel for a named size preset.

    Args:
        size: Model size ("tiny", "small", "medium", "large", "1B").
        vocab_size: Vocabulary size.
        max_seq_len: Maximum sequence length.
        **kwargs: Additional ModelConfig overrides.

    Returns:
        Initialized LlamaModel.
    """
    # Resolve the preset with overrides, then let LlamaModel.__init__
    # handle all weight initialization.
    return LlamaModel(
        get_model_config(size, vocab_size=vocab_size, max_seq_len=max_seq_len, **kwargs)
    )
def print_model_summary(model: LlamaModel):
    """Print a human-readable summary of the model architecture."""
    config = model.config
    params = model.count_parameters()
    rule = "=" * 60
    # Header
    print("\n" + rule)
    print("LLAMA MODEL SUMMARY")
    print(rule)
    # Architecture section
    print(f"\nArchitecture:")
    print(f" Hidden dim: {config.d_model}")
    print(f" Layers: {config.n_layers}")
    print(f" Attention heads: {config.n_heads}")
    print(f" KV heads (GQA): {config.n_kv_heads}")
    print(f" Head dim: {config.head_dim}")
    print(f" FFN dim: {config.d_ff}")
    print(f" Vocab size: {config.vocab_size}")
    print(f" Max seq len: {config.max_seq_len}")
    # Optimizations section
    print(f"\nOptimizations:")
    print(f" RMSNorm: Yes")
    print(f" RoPE: Yes (theta={config.rope_theta})")
    print(f" SwiGLU: Yes")
    print(f" GQA: Yes ({config.n_heads}/{config.n_kv_heads} = {config.n_kv_groups}x)")
    print(f" Weight tying: {config.tie_weights}")
    print(f" Flash Attention: {config.use_flash_attn}")
    # Parameter counts and rough weight-memory footprint
    print(f"\nParameters:")
    print(f" Total: {params:,}")
    if params > 1e9:
        print(f" Size: ~{params / 1e9:.2f}B")
    else:
        print(f" Size: ~{params / 1e6:.0f}M")
    param_bytes = params * 4  # 4 bytes per parameter in fp32
    print(f" FP32 memory: ~{param_bytes / 1e9:.2f} GB")
    print(f" BF16 memory: ~{param_bytes / 2 / 1e9:.2f} GB")
    print(rule + "\n")
# ============================================================================
# Main (for testing)
# ============================================================================
if __name__ == "__main__":
    # Smoke-test: build every predefined size and report parameter counts.
    print("Testing Llama model creation...\n")
    for size in ("tiny", "small", "medium", "large", "1B"):
        params = create_model(size).count_parameters()
        print(f"{size:8s}: {params:>12,} parameters ({params/1e6:>7.1f}M)")
    print("\n" + "-" * 60)

    # Detailed summary for the 1B configuration.
    model = create_model("1B")
    print_model_summary(model)

    # Forward-pass smoke test on random tokens.
    print("Testing forward pass...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    n_batch, n_ctx = 2, 128
    input_ids = torch.randint(0, 32000, (n_batch, n_ctx), device=device)
    logits = model(input_ids)  # no targets -> logits
    print(f"Logits shape: {logits.shape}")
    targets = torch.randint(0, 32000, (n_batch, n_ctx), device=device)
    loss = model(input_ids, targets=targets)  # targets -> scalar loss
    print(f"Loss: {loss.item():.4f}")

    # Backward pass with gradient checkpointing enabled.
    print("\nTesting gradient checkpointing...")
    model.gradient_checkpointing_enable()
    loss = model(input_ids, targets=targets)
    loss.backward()
    print(f"Gradient checkpointing loss: {loss.item():.4f}")
    print("\nAll tests passed!")