"""Standard Transformer language model implementation.""" import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from taoTrain.core import BaseModel from taoTrain.config import ModelConfig from .registry import register_architecture # ============================================================================ # Components # ============================================================================ class PositionalEmbedding(nn.Module): """Sinusoidal positional embeddings.""" def __init__(self, dim: int, max_seq_length: int = 2048): """Initialize positional embeddings.""" super().__init__() self.dim = dim self.max_seq_length = max_seq_length # Precompute positional embeddings pe = torch.zeros(max_seq_length, dim) pos = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) pe[:, 0::2] = torch.sin(pos * div_term) if dim % 2 == 1: pe[:, 1::2] = torch.cos(pos * div_term[:-1]) else: pe[:, 1::2] = torch.cos(pos * div_term) self.register_buffer("pe", pe, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Add positional embeddings to input. Args: x: Input tensor (batch, seq_len, hidden_dim) Returns: Input + positional embeddings """ seq_len = x.shape[1] return x + self.pe[:seq_len] class Attention(nn.Module): """Multi-head self-attention using scaled dot-product attention.""" def __init__(self, config: ModelConfig): """Initialize attention.""" super().__init__() self.hidden_dim = config.hidden_dim self.num_heads = config.num_heads self.head_dim = config.head_dim assert self.hidden_dim % self.num_heads == 0 # Linear projections self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim) self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim) self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim) self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim) self.dropout_p = config.dropout def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass using scaled_dot_product_attention. Args: x: Shape (batch, seq_len, hidden_dim) attention_mask: Shape (batch, seq_len) Returns: Output: Shape (batch, seq_len, hidden_dim) """ batch_size, seq_len, _ = x.shape # Project to Q, K, V q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim) # Transpose for attention: (batch, num_heads, seq_len, head_dim) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # NOTE: PyTorch's scaled_dot_product_attention does NOT support both # explicit attn_mask AND is_causal=True together. # When is_causal=True, PyTorch handles causal masking automatically. # Padding positions are handled separately via loss computation (labels=-100). # See: https://github.com/pytorch/pytorch/issues/96099 # Compute attention using scaled_dot_product_attention # is_causal=True automatically applies causal masking # We do NOT pass attn_mask when is_causal=True out = F.scaled_dot_product_attention( q, k, v, attn_mask=None, # Must be None when is_causal=True dropout_p=self.dropout_p if self.training else 0.0, is_causal=True, scale=None # Uses default scale of 1/sqrt(head_dim) ) # (batch, num_heads, seq_len, head_dim) # Transpose back and reshape out = out.transpose(1, 2).contiguous() # (batch, seq_len, num_heads, head_dim) out = out.reshape(batch_size, seq_len, self.hidden_dim) # Output projection out = self.out_proj(out) return out class SwiGLU(nn.Module): """Swish Gated Linear Unit activation.""" def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0): """ Initialize SwiGLU. Args: in_dim: Input dimension out_dim: Intermediate/hidden dimension dropout: Dropout rate """ super().__init__() # Project to 2x the intermediate dimension (for value and gate) self.fc1 = nn.Linear(in_dim, 2 * out_dim) self.fc2 = nn.Linear(out_dim, in_dim) # Project back to input dimension self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass with SwiGLU activation. Args: x: Input tensor Returns: Gated activation output (same dimension as input) """ # Project to 2x intermediate dimension x = self.fc1(x) # Split into value and gate x, gate = x.chunk(2, dim=-1) # SwiGLU: value * swish(gate) = value * gate * sigmoid(gate) x = x * F.silu(gate) # SiLU is Swish: x * sigmoid(x) x = self.dropout(x) x = self.fc2(x) # Project back to input dimension return x class FeedForward(nn.Module): """Feed-forward network with SwiGLU activation.""" def __init__(self, config: ModelConfig): """Initialize FFN with SwiGLU.""" super().__init__() self.swiglu = SwiGLU( in_dim=config.hidden_dim, out_dim=config.intermediate_dim, dropout=config.dropout ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass with SwiGLU activation.""" return self.swiglu(x) class TransformerBlock(nn.Module): """Single transformer block with attention and FFN.""" def __init__(self, config: ModelConfig): """Initialize transformer block.""" super().__init__() self.norm1 = nn.LayerNorm(config.hidden_dim) self.attn = Attention(config) self.norm2 = nn.LayerNorm(config.hidden_dim) self.ffn = FeedForward(config) def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with pre-norm residual connections.""" # Attention with residual x = x + self.attn(self.norm1(x), attention_mask=attention_mask) # FFN with residual x = x + self.ffn(self.norm2(x)) return x # ============================================================================ # Transformer LM # ============================================================================ @register_architecture("transformer") class TransformerLM(BaseModel): """Standard Transformer language model.""" def __init__(self, config: ModelConfig): """Initialize Transformer LM.""" super().__init__(config) # Embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim) self.pos_embed = PositionalEmbedding(config.hidden_dim, max_seq_length=config.max_seq_length) self.dropout = nn.Dropout(config.dropout) # Transformer blocks self.blocks = nn.ModuleList([ TransformerBlock(config) for _ in range(config.num_layers) ]) # Final layer norm self.final_norm = nn.LayerNorm(config.hidden_dim) # Output projection (shared with input embeddings for efficiency) self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False) # Weight tying (optional) self.lm_head.weight = self.embed_tokens.weight # Initialize weights self._init_weights() def _init_weights(self): """Initialize model weights.""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=self.config.init_std) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=self.config.init_std) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """ Forward pass. Args: input_ids: (batch_size, seq_len) attention_mask: (batch_size, seq_len) labels: (batch_size, seq_len) for loss computation Returns: Dict with 'logits' and optionally 'loss' """ batch_size, seq_len = input_ids.shape # Embedding x = self.embed_tokens(input_ids) # Add positional embeddings x = self.pos_embed(x) x = self.dropout(x) # Transformer blocks for block in self.blocks: x = block(x, attention_mask=attention_mask) # Final normalization x = self.final_norm(x) # LM head logits = self.lm_head(x) # (batch, seq_len, vocab_size) # Loss computation loss = None if labels is not None: # Flatten for loss computation logits_flat = logits.view(-1, logits.size(-1)) # (batch * seq_len, vocab_size) labels_flat = labels.view(-1) # Only compute loss on valid targets (ignore -100 tokens) loss = F.cross_entropy( logits_flat, labels_flat, reduction='mean', ignore_index=-100 ) return { 'logits': logits, 'loss': loss, }