""" Shared building blocks for Circuit Transformer architectures. Components: - RMSNorm: Root Mean Square Layer Normalization - RotaryEmbedding: Rotary Position Embedding (RoPE) - CausalAttention: Multi-head causal attention with RoPE + KV cache - SwiGLU: Gated feed-forward network """ import torch import torch.nn as nn import torch.nn.functional as F import math from functools import lru_cache class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * norm).type_as(x) * self.weight def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor: """Build a boolean table marking which token IDs start a new word. Detects word boundaries from tokenizer's token representations: - Ġ prefix (GPT-2/BPE style) - ▁ prefix (SentencePiece style) - Special tokens (starting with <) """ table = torch.zeros(vocab_size, dtype=torch.bool) # Get all token strings — handle both HF and SentencePiece tokenizers if hasattr(tokenizer, 'convert_ids_to_tokens'): tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size))) elif hasattr(tokenizer, 'sp'): tokens = [tokenizer.sp.IdToPiece(i) for i in range(vocab_size)] else: tokens = [tokenizer.decode([i]) for i in range(vocab_size)] for idx, tok in enumerate(tokens): if tok is None: continue if tok.startswith('Ġ') or tok.startswith('▁') or tok.startswith('<'): table[idx] = True # Punctuation and newlines that start new "words" elif len(tok) > 0 and tok[0] in '\n\r\t': table[idx] = True # Token 0 is always a word starter (BOS/padding) table[0] = True return table def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor: """Compute position-within-word for each token. Vectorized, no loops. Args: input_ids: [B, L] token IDs word_start_table: [vocab_size] bool tensor from build_word_start_table Returns: [B, L] float tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary) """ is_word_start = word_start_table[input_ids] # [B, L] is_word_start[:, 0] = True # First token always starts a word B, L = input_ids.shape positions = torch.arange(L, device=input_ids.device, dtype=torch.float32).unsqueeze(0).expand(B, -1) # Fill non-word-start positions with -1, word-start positions with their index fill = torch.where(is_word_start, positions, torch.tensor(-1.0, device=input_ids.device)) # cummax propagates the most recent word-start position forward running_start, _ = fill.cummax(dim=1) # Position within word = distance from the most recent word start word_pos = positions - running_start # [B, L] float: 0, 1, 2, 0, 1, 0, ... return word_pos class WordPositionRoPE(nn.Module): """RoPE encoding for position-within-word. Dedicates a small subspace of head dimensions to word-internal position, using separate (lower) frequency bases. Overrides the last `word_dims` of the standard RoPE cos/sin tensors. """ def __init__(self, word_dims: int, word_base: float = 10.0): super().__init__() self.word_dims = word_dims word_inv_freq = 1.0 / (word_base ** (torch.arange(0, word_dims, 2).float() / word_dims)) self.register_buffer("word_inv_freq", word_inv_freq) def forward( self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Override last word_dims of cos/sin with word-position-derived values. Args: cos, sin: [L, head_dim] from standard RotaryEmbedding word_positions: [B, L] float tensor (position within word) Returns: cos, sin: [B, L, head_dim] with word dims overridden """ B, L = word_positions.shape # Compute word angles: [B, L, word_dims/2] angles = word_positions.unsqueeze(-1) * self.word_inv_freq # Duplicate for rotate_half pattern: [B, L, word_dims] word_emb = torch.cat([angles, angles], dim=-1) word_cos = word_emb.cos() word_sin = word_emb.sin() # Expand standard cos/sin to batch dimension: [L, D] -> [B, L, D] cos = cos.unsqueeze(0).expand(B, -1, -1).clone() sin = sin.unsqueeze(0).expand(B, -1, -1).clone() # Override last word_dims with word-position values cos[:, :, -self.word_dims:] = word_cos sin[:, :, -self.word_dims:] = word_sin return cos, sin class RotaryEmbedding(nn.Module): """Rotary Position Embedding (RoPE).""" def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0): super().__init__() self.dim = dim self.max_seq_len = max_seq_len inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) self._build_cache(max_seq_len) def _build_cache(self, seq_len: int): t = torch.arange(seq_len, device=self.inv_freq.device) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: if seq_len > self.cos_cached.size(0): self._build_cache(seq_len) return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims.""" x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary position embedding to queries and keys. Handles both standard [L, D] and batched [B, L, D] cos/sin. Q, K shape: [B, H, L, D]. For batched cos/sin, unsqueeze dim 1 for head broadcast. """ if cos.dim() == 3: # [B, L, D] from WordPositionRoPE cos = cos.unsqueeze(1) # [B, 1, L, D] — broadcast over heads sin = sin.unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class CausalAttention(nn.Module): """Multi-head attention with causal mask, RoPE, and optional GQA. Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads. Each KV head serves (num_heads // num_kv_heads) query heads. KV cache stored at kv_heads granularity for memory efficiency. """ def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None, max_seq_len: int = 2048, dropout: float = 0.0, window_size: int | None = None, word_rope_dims: int = 0, word_rope_base: float = 10.0, ): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.num_kv_heads = num_kv_heads or num_heads self.head_dim = hidden_size // num_heads self.num_kv_groups = self.num_heads // self.num_kv_heads self.dropout = dropout self.window_size = window_size assert self.num_heads % self.num_kv_heads == 0, \ f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})" if word_rope_dims > 0: assert word_rope_dims <= self.head_dim, \ f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})" assert word_rope_dims % 2 == 0, \ f"word_rope_dims ({word_rope_dims}) must be even" self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.rotary = RotaryEmbedding(self.head_dim, max_seq_len) # Word-position RoPE (optional) self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None # Build causal mask (optionally windowed) mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) if window_size is not None: # Band mask: position i attends to [max(0, i-window+1), i] band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1)) mask = mask * band self.register_buffer( "causal_mask", mask.view(1, 1, max_seq_len, max_seq_len), persistent=False, ) def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor: """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads.""" if self.num_kv_groups == 1: return kv B, H_kv, L, D = kv.shape return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D) def forward( self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None, word_positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, tuple | None]: B, L, _ = x.shape q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2) # RoPE: use correct position offset for KV-cached generation offset = past_kv[0].size(2) if past_kv is not None else 0 cos, sin = self.rotary(x, offset + L) cos = cos[offset:offset + L] sin = sin[offset:offset + L] # Override word-position dims if enabled if self.word_rope is not None and word_positions is not None: cos, sin = self.word_rope(cos, sin, word_positions) q, k = apply_rotary_pos_emb(q, k, cos, sin) # KV cache at kv_heads granularity (memory efficient for GQA) if past_kv is not None: past_k, past_v = past_kv k = torch.cat([past_k, k], dim=2) v = torch.cat([past_v, v], dim=2) new_kv = (k, v) if use_cache else None dropout_p = self.dropout if self.training else 0.0 use_gqa = self.num_kv_groups > 1 if self.window_size is not None: # Windowed attention: manual path (SDPA FlashAttention doesn't support arbitrary masks) k_expanded = self._expand_kv(k) v_expanded = self._expand_kv(v) seq_len = k.size(2) attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim) if seq_len <= self.causal_mask.size(-1): mask = self.causal_mask[:, :, offset:offset + L, :seq_len] attn = attn.masked_fill(mask == 0, float("-inf")) attn = F.softmax(attn, dim=-1) if dropout_p > 0: attn = F.dropout(attn, p=dropout_p) out = torch.matmul(attn, v_expanded) else: # SDPA: auto-dispatches to FlashAttention2 / memory-efficient / math backend # Native GQA support avoids expanding KV heads (saves memory + enables FlashAttention GQA kernel) is_causal = past_kv is None and L > 1 out = F.scaled_dot_product_attention( q, k, v, dropout_p=dropout_p, is_causal=is_causal, enable_gqa=use_gqa, ) out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size) return self.o_proj(out), new_kv class SwiGLU(nn.Module): """SwiGLU feed-forward network.""" def __init__(self, hidden_size: int, intermediate_size: int | None = None): super().__init__() intermediate_size = intermediate_size or int(hidden_size * 8 / 3) intermediate_size = ((intermediate_size + 63) // 64) * 64 self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False) self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False) self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x))