import inspect
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from .config import ModelConfig


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - used in LLaMA, GPT-NeoX.

    Produces per-position cos/sin tables of width ``dim``. The frequency
    vector is duplicated (``cat(freqs, freqs)``) so the tables span the full
    head dimension, matching the "rotate half" form used by
    ``apply_rotary_pos_emb``.
    """

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # inv_freq[i] = base ** (-2i / dim) for i in [0, dim / 2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build cache for efficiency
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)compute cos/sin tables for positions ``0 .. seq_len - 1``.

        The tables are computed in float32 even if the buffers were converted
        to a lower precision (e.g. by ``model.half()``), so position indices
        are not rounded; ``apply_rotary_pos_emb`` casts to the activation
        dtype at use time.
        """
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.float())  # (seq_len, dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)
        self.cached_seq_len = seq_len

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)``, each of shape (seq_len, dim), growing the cache if needed."""
        if seq_len > self.cached_seq_len:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dimension halves ``(x1, x2)`` to ``(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to queries and keys.

    Computes ``x * cos + rotate_half(x) * sin``. With ``cos``/``sin`` holding
    the frequency vector duplicated across both halves (as produced by
    ``RotaryEmbedding``), this equals ``cat(x1*c - x2*s, x2*c + x1*s)`` for
    the halves ``(x1, x2)`` of ``x``.

    Args:
        q: (B, n_heads, T, d_head)
        k: (B, n_heads, T, d_head)
        cos: (T, d_head)
        sin: (T, d_head)

    Returns:
        Rotated ``(q, k)`` with the same shapes as the inputs; the tables are
        cast to ``q.dtype`` before use.
    """
    # Cast to the activation dtype and reshape for broadcasting: (1, 1, T, d_head)
    cos = cos.to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)
    sin = sin.to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)

    q_rot = q * cos + _rotate_half(q) * sin
    k_rot = k * cos + _rotate_half(k) * sin
    return q_rot, k_rot


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE and KV caching."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.use_rope = use_rope
        # Fall back to manual attention on PyTorch builds without SDPA.
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')

        # QKV projection
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

        # Dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Rotary embeddings
        if use_rope:
            self.rotary_emb = RotaryEmbedding(self.d_head, max_seq_len)

        # Causal mask (fallback for non-flash attention)
        if not self.use_flash:
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)),
                persistent=False
            )

    def _causal_mask(self, t_q: int, t_k: int, device: torch.device) -> torch.Tensor:
        """Return a bool (t_q, t_k) mask that is True where a query may attend to a key.

        The queries are the LAST ``t_q`` positions of a ``t_k``-long sequence
        (bottom-right alignment), which is what a KV cache produces: with
        ``t_q == 1`` the new token attends to every cached key.
        """
        offset = t_k - t_q
        cached = getattr(self, "causal_mask", None)
        if cached is not None and t_k <= cached.size(0):
            # Rows offset..t_k-1 of tril(ones) == tril(ones(t_q, t_k), diagonal=offset)
            return cached[offset:t_k, :t_k]
        return torch.ones(t_q, t_k, dtype=torch.bool, device=device).tril(diagonal=offset)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over ``x`` (B, T, d_model), optionally continuing from ``past_kv``.

        Args:
            x: (B, T, d_model) input activations.
            attn_mask: Optional ADDITIVE mask added to the pre-softmax scores
                on top of the causal mask; presumably broadcastable to
                (B, n_heads, T, T_past + T) - confirm with callers.
            past_kv: Cached ``(k, v)``, each (B, n_heads, T_past, d_head),
                already rotary-encoded at their own positions.
            use_cache: If True, also return the concatenated ``(k, v)``.

        Returns:
            ``(y, present_kv)`` where ``y`` is (B, T, d_model) and
            ``present_kv`` is ``(k, v)`` or None.
        """
        B, T, C = x.size()

        # Compute QKV
        qkv = self.qkv(x)  # (B, T, 3*C)
        q, k, v = qkv.split(self.d_model, dim=-1)

        # Reshape to (B, n_heads, T, d_head)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Number of positions already held in the cache; new tokens start here.
        past_len = past_kv[0].size(2) if past_kv is not None else 0

        # Apply rotary embeddings at absolute positions past_len .. past_len + T - 1
        if self.use_rope:
            cos, sin = self.rotary_emb(past_len + T)
            q, k = apply_rotary_pos_emb(q, k, cos[past_len:], sin[past_len:])

        # KV cache for inference
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        T_q, T_k = q.size(2), k.size(2)

        # Compute attention
        if self.use_flash:
            dropout_p = self.attn_dropout.p if self.training else 0.0
            if attn_mask is None and T_q == T_k:
                # Square causal case: let SDPA pick its fused causal kernel.
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True
                )
            else:
                # SDPA's is_causal is top-left aligned for non-square masks,
                # which is wrong with a KV cache; build the mask explicitly.
                causal = self._causal_mask(T_q, T_k, q.device)
                if attn_mask is None:
                    mask = causal  # bool: True = take part in attention
                else:
                    mask = torch.zeros(T_q, T_k, dtype=q.dtype, device=q.device)
                    mask = mask.masked_fill(~causal, float("-inf")) + attn_mask.to(q.dtype)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=False
                )
        else:
            # Fallback: manual attention computation
            att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)

            # Apply causal mask (bottom-right aligned for cached decoding)
            causal = self._causal_mask(T_q, T_k, att.device)
            att = att.masked_fill(~causal, float("-inf"))

            # Apply additional mask if provided
            if attn_mask is not None:
                att = att + attn_mask

            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, n_heads, T, d_head)

        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)

        return y, present_kv


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: x + Attn(LN(x)), then x + MLP(LN(x))."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model, n_heads, dropout, max_seq_len, use_rope, use_flash
        )
        self.ln2 = nn.LayerNorm(d_model)

        # MLP with GELU activation (SwiGLU would be even better)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_ratio * d_model, bias=True),
            nn.GELU(),
            nn.Linear(mlp_ratio * d_model, d_model, bias=True),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block; returns ``(x, present_kv)`` from the attention sublayer."""
        # Pre-LayerNorm architecture
        attn_out, present_kv = self.attn(self.ln1(x), attn_mask, past_kv, use_cache)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x, present_kv


class SupernovaModel(nn.Module):
    """
    Optimized Transformer Language Model with:
    - Flash Attention support
    - Rotary Position Embeddings (RoPE)
    - KV caching for efficient generation
    - Gradient checkpointing support
    - Mixed precision training compatibility
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        d = cfg.d_model
        V = cfg.vocab_size

        # Token embeddings (also reused as the output projection; see forward)
        self.tok_emb = nn.Embedding(V, d)

        # Optional learned positional embeddings (if not using RoPE)
        use_rope = getattr(cfg, 'use_rope', True)
        if not use_rope and cfg.use_positional_embedding:
            self.pos_emb = nn.Embedding(cfg.n_positions, d)
        else:
            self.pos_emb = None

        # Dropout
        self.drop = nn.Dropout(cfg.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d,
                cfg.n_heads,
                cfg.mlp_ratio,
                cfg.dropout,
                max_seq_len=getattr(cfg, 'n_positions', 8192),
                use_rope=use_rope,
                use_flash=getattr(cfg, 'use_flash', True)
            )
            for _ in range(cfg.n_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(d) if cfg.final_layer_norm else nn.Identity()

        # Gradient checkpointing flag (set during training)
        self.gradient_checkpointing = False

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights following GPT-2/3 initialization scheme"""
        if isinstance(module, nn.Linear):
            # Use normal distribution with std=0.02
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass with optional KV caching for efficient generation.

        Args:
            input_ids: (B, T) input token indices
            targets: (B, T) target token indices for loss computation. The
                loss compares logits[:, t] with targets[:, t + 1], i.e. the
                shift is done here; positions equal to -100 are ignored.
            past_key_values: List of (k, v) tuples for each layer (for caching)
            use_cache: Whether to return present key values. Gradient
                checkpointing is not applied when this is True.

        Returns:
            logits: (B, T, V) output logits
            loss: Optional loss value
            present_key_values: Optional list of present (k, v) for caching
        """
        T = input_ids.size(1)
        device = input_ids.device

        # Compute embeddings
        tok = self.tok_emb(input_ids)  # (B, T, d)

        # Add positional embeddings if using learned positions (not RoPE)
        if self.pos_emb is not None:
            if past_key_values is not None:
                # During generation with cache, only process new position
                pos_offset = past_key_values[0][0].size(2)
                pos = torch.arange(pos_offset, pos_offset + T, device=device)
            else:
                pos = torch.arange(0, T, device=device)

            assert pos.max() < self.cfg.n_positions, f"Position {pos.max()} exceeds n_positions {self.cfg.n_positions}"
            pos_emb = self.pos_emb(pos)[None, :, :]  # (1, T, d)
            x = tok + pos_emb
        else:
            x = tok

        x = self.drop(x)

        # Checkpointing discards per-layer caches, so it is only used when no
        # cache is requested (otherwise the returned list would hold None
        # placeholders that break the next cached call).
        use_checkpointing = self.gradient_checkpointing and self.training and not use_cache

        # Pass through transformer blocks
        present_key_values = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None

            if use_checkpointing:
                # Positional args map to (x, attn_mask, past_kv, use_cache)
                x, _ = torch.utils.checkpoint.checkpoint(
                    block, x, None, past_kv, False, use_reentrant=False
                )
            else:
                x, present_kv = block(x, attn_mask=None, past_kv=past_kv, use_cache=use_cache)
                if present_key_values is not None:
                    present_key_values.append(present_kv)

        x = self.ln_f(x)

        # Compute logits via tied embeddings
        logits = x @ self.tok_emb.weight.T  # (B, T, V)

        # Compute loss if targets provided
        loss = None
        if targets is not None:
            # Shift for next-token prediction
            logits_ = logits[:, :-1, :].contiguous()
            targets_ = targets[:, 1:].contiguous()
            loss = F.cross_entropy(
                logits_.view(-1, logits_.size(-1)),
                targets_.view(-1),
                ignore_index=-100,
            )

        return logits, loss, present_key_values

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
        use_cache: bool = True
    ) -> torch.Tensor:
        """
        Generate text autoregressively with various sampling strategies.

        Args:
            idx: (B, T) input token indices
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random);
                0 selects greedy (argmax) decoding
            top_k: Keep only top k logits (None or <= 0 = disabled)
            top_p: Nucleus sampling threshold (None = disabled)
            repetition_penalty: Penalty for tokens already in ``idx``
                (1.0 = no penalty). Positive logits are divided and negative
                logits multiplied by it, so values > 1 always discourage repeats.
            use_cache: Use KV caching for faster generation

        Returns:
            (B, T + max_new_tokens) generated token indices
        """
        max_len = getattr(self.cfg, 'n_positions', 8192)
        past_key_values = None

        for _ in range(max_new_tokens):
            # Once the cache fills the context window, drop it and re-encode
            # the cropped tail (same behavior as the uncached path).
            if past_key_values is not None and past_key_values[0][0].size(2) >= max_len:
                past_key_values = None

            if past_key_values is None:
                # Crop context to the model's maximum length
                idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:]
            else:
                # With cache, only process the last token
                idx_cond = idx[:, -1:]

            # Forward pass
            logits, _, past_key_values = self(
                idx_cond,
                past_key_values=past_key_values,
                use_cache=use_cache
            )
            # Work in float32 for stable sampling, on a copy of the model output
            logits = logits[:, -1, :].float()  # (B, V)

            # Apply repetition penalty (CTRL-style), once per distinct seen token
            if repetition_penalty != 1.0:
                seen = torch.gather(logits, 1, idx)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                logits = logits.scatter(1, idx, seen)

            if temperature == 0:
                # Greedy decoding
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            else:
                # Apply temperature
                logits = logits / temperature

                # Top-k filtering
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float('-inf')

                # Nucleus (top-p) sampling
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above threshold,
                    # shifted right so the token crossing the threshold is kept
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')

                # Sample next token
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append to sequence
            idx = torch.cat([idx, idx_next], dim=1)

        return idx

    def num_parameters(self, only_trainable: bool = True) -> int:
        """
        Count model parameters.

        Args:
            only_trainable: If True, count only trainable parameters

        Returns:
            Total number of parameters
        """
        if only_trainable:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def parameter_breakdown(self) -> dict:
        """
        Get detailed parameter count by component.

        Returns:
            Dictionary with parameter counts for each component, plus
            "total" (sum of components) and "total_trainable".
        """
        breakdown = {
            "token_embeddings": sum(p.numel() for p in self.tok_emb.parameters()),
            "positional_embeddings": (
                sum(p.numel() for p in self.pos_emb.parameters()) if self.pos_emb is not None else 0
            ),
            "attention": sum(
                p.numel() for block in self.blocks for p in block.attn.parameters()
            ),
            "mlp": sum(
                p.numel() for block in self.blocks for p in block.mlp.parameters()
            ),
            # ln_f may be nn.Identity, which has no parameters and adds 0.
            "layer_norm": sum(
                p.numel()
                for block in self.blocks
                for ln in (block.ln1, block.ln2)
                for p in ln.parameters()
            ) + sum(p.numel() for p in self.ln_f.parameters()),
        }
        breakdown["total"] = sum(breakdown.values())
        breakdown["total_trainable"] = self.num_parameters(only_trainable=True)
        return breakdown

    def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float:
        """
        Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS.

        Uses the PaLM-style estimate ``6N + 12*L*H*Q*T`` FLOPs per token, which
        already covers forward AND backward, at a sequence length of
        ``cfg.n_positions``.

        Args:
            fwdbwd_per_iter: Number of forward-backward passes per iteration
            dt: Time taken for iteration (seconds)

        Returns:
            MFU as a percentage (0-100)
        """
        N = self.num_parameters()
        cfg = self.cfg
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.d_model // cfg.n_heads, cfg.n_positions

        # FLOPs per token for forward + backward
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter

        flops_achieved = flops_per_iter / dt
        flops_promised = 312e12  # A100 GPU bfloat16 peak
        mfu = flops_achieved / flops_promised * 100

        return mfu

    def configure_optimizers(
        self,
        weight_decay: float,
        learning_rate: float,
        betas: Tuple[float, float],
        device_type: str
    ):
        """
        Configure optimizer with weight decay only on specific parameters.

        Linear weights are decayed; biases, LayerNorm and Embedding weights
        are not.

        Args:
            weight_decay: L2 regularization coefficient
            learning_rate: Learning rate
            betas: Adam beta parameters
            device_type: 'cuda' or 'cpu'

        Returns:
            Configured AdamW optimizer (fused when on CUDA and supported)
        """
        # Separate parameters that should and shouldn't have weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (nn.Linear,)
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f'{mn}.{pn}' if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # Validate that we've covered all parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay

        assert len(inter_params) == 0, f"Parameters in both decay/no_decay: {inter_params}"
        assert len(param_dict.keys() - union_params) == 0, f"Missing parameters: {param_dict.keys() - union_params}"

        # Create optimizer groups
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]

        # Use fused AdamW if on CUDA and this PyTorch build supports it
        extra_args = {}
        if device_type == 'cuda' and 'fused' in inspect.signature(torch.optim.AdamW).parameters:
            extra_args['fused'] = True
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

        return optimizer