| | """
|
| | Shared building blocks for Circuit Transformer architectures.
|
| |
|
| | Components:
|
| | - RMSNorm: Root Mean Square Layer Normalization
|
| | - RotaryEmbedding: Rotary Position Embedding (RoPE)
|
| | - CausalAttention: Multi-head causal attention with RoPE + KV cache
|
| | - SwiGLU: Gated feed-forward network
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import math
|
| | from functools import lru_cache
|
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal RMS of its last dimension (no mean
    subtraction), then applies a learned per-dimension gain. Statistics are
    computed in float32 for numerical stability and cast back to the input
    dtype before the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # eps keeps rsqrt finite for all-zero inputs.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_fp32 * inv_rms).type_as(x) * self.weight
|
| |
|
| |
|
def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
    """Build a boolean table marking which token IDs start a new word.

    Detects word boundaries from the tokenizer's token representations:
    - Ġ prefix (GPT-2/BPE style)
    - ▁ prefix (SentencePiece style)
    - Special tokens (starting with <)
    - Tokens beginning with a newline/carriage-return/tab character

    Token ID 0 is always marked as a word start.
    """
    # Pick whichever id->piece API the tokenizer exposes.
    if hasattr(tokenizer, 'convert_ids_to_tokens'):
        pieces = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    elif hasattr(tokenizer, 'sp'):
        pieces = [tokenizer.sp.IdToPiece(i) for i in range(vocab_size)]
    else:
        pieces = [tokenizer.decode([i]) for i in range(vocab_size)]

    table = torch.zeros(vocab_size, dtype=torch.bool)
    for i, piece in enumerate(pieces):
        if piece is None:
            continue
        if piece.startswith(('Ġ', '▁', '<')):
            table[i] = True
        elif piece and piece[0] in '\n\r\t':
            table[i] = True

    # First token always counts as a word start.
    table[0] = True
    return table
|
| |
|
| |
|
def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
    """Compute position-within-word for each token. Vectorized, no loops.

    Args:
        input_ids: [B, L] token IDs
        word_start_table: [vocab_size] bool tensor from build_word_start_table

    Returns:
        [B, L] float tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary)
    """
    starts = word_start_table[input_ids]
    # Force a boundary at position 0 so the cummax below never sees -1 alone.
    starts[:, 0] = True

    batch, seq_len = input_ids.shape
    idx = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32)
    idx = idx.unsqueeze(0).expand(batch, -1)

    # Absolute index at word starts, -1 elsewhere; the running max then gives
    # each token the index of the most recent word start.
    neg_one = torch.tensor(-1.0, device=input_ids.device)
    anchored = torch.where(starts, idx, neg_one)
    last_start = anchored.cummax(dim=1).values

    # Distance from the most recent word start = position within the word.
    return idx - last_start
|
| |
|
| |
|
class WordPositionRoPE(nn.Module):
    """RoPE encoding for position-within-word.

    Dedicates a small subspace of head dimensions to word-internal position,
    using separate (lower) frequency bases. Overrides the last `word_dims`
    of the standard RoPE cos/sin tensors.
    """

    def __init__(self, word_dims: int, word_base: float = 10.0):
        super().__init__()
        self.word_dims = word_dims
        # Half as many frequencies as dims; each angle is duplicated below.
        exponents = torch.arange(0, word_dims, 2).float() / word_dims
        self.register_buffer("word_inv_freq", 1.0 / (word_base ** exponents))

    def forward(
        self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override last word_dims of cos/sin with word-position-derived values.

        Args:
            cos, sin: [L, head_dim] from standard RotaryEmbedding
            word_positions: [B, L] float tensor (position within word)

        Returns:
            cos, sin: [B, L, head_dim] with word dims overridden
        """
        batch = word_positions.size(0)

        # [B, L, word_dims/2] angles, duplicated to [B, L, word_dims].
        theta = word_positions.unsqueeze(-1) * self.word_inv_freq
        theta = torch.cat((theta, theta), dim=-1)

        # Broadcast to batch and clone so the in-place override below is safe.
        cos = cos.unsqueeze(0).expand(batch, -1, -1).clone()
        sin = sin.unsqueeze(0).expand(batch, -1, -1).clone()

        d = self.word_dims
        cos[..., -d:] = theta.cos()
        sin[..., -d:] = theta.sin()

        return cos, sin
|
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes cos/sin tables up to `max_seq_len` and regrows the cache on
    demand when a longer sequence is requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # angles[t, i] = t * inv_freq[i], duplicated across the two halves.
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        angles = torch.outer(positions, self.inv_freq)
        angles = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Grow the cache lazily if a longer sequence shows up.
        if self.cos_cached.size(0) < seq_len:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
| |
|
| |
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims: (x1, x2) -> (-x2, x1)."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
| |
|
| |
|
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to queries and keys.

    Handles both standard [L, D] and batched [B, L, D] cos/sin.
    Q, K shape: [B, H, L, D]. For batched cos/sin, unsqueeze dim 1 for head broadcast.
    """
    if cos.dim() == 3:
        # Insert a singleton head axis so [B, L, D] broadcasts over [B, H, L, D].
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return (
        q * cos + rotate_half(q) * sin,
        k * cos + rotate_half(k) * sin,
    )
|
| |
|
| |
|
class CausalAttention(nn.Module):
    """Multi-head attention with causal mask, RoPE, and optional GQA.

    Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
    Each KV head serves (num_heads // num_kv_heads) query heads.
    KV cache stored at kv_heads granularity for memory efficiency.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # None (or 0) falls back to full multi-head attention.
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_size // num_heads
        # Number of query heads served by each KV head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.dropout = dropout
        self.window_size = window_size

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        if word_rope_dims > 0:
            assert word_rope_dims <= self.head_dim, \
                f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
            assert word_rope_dims % 2 == 0, \
                f"word_rope_dims ({word_rope_dims}) must be even"

        # K/V projections are sized by num_kv_heads (smaller under GQA).
        self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)

        # Optional word-internal-position override of the last word_rope_dims.
        self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None

        # Precomputed [1, 1, max_seq_len, max_seq_len] mask: lower triangle,
        # intersected with a band of width window_size when sliding-window
        # attention is enabled.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        if window_size is not None:
            band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
            mask = mask * band
        self.register_buffer(
            "causal_mask",
            mask.view(1, 1, max_seq_len, max_seq_len),
            persistent=False,
        )

    def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
        if self.num_kv_groups == 1:
            return kv
        B, H_kv, L, D = kv.shape
        # Repeat each KV head num_kv_groups times along the head axis.
        return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Run causal attention over `x`, optionally continuing from a KV cache.

        Args:
            x: [B, L, hidden_size] input hidden states.
            use_cache: if True, return the concatenated (k, v) for the next step.
            past_kv: optional (k, v) from a previous call, each
                [B, num_kv_heads, past_len, head_dim].
            word_positions: optional [B, L] position-within-word tensor; only
                used when word_rope is enabled.

        Returns:
            (output [B, L, hidden_size], new_kv tuple or None)
        """
        B, L, _ = x.shape

        # Project and reshape to [B, heads, L, head_dim]; K/V at kv-head count.
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Absolute position of this chunk = number of cached key positions;
        # slice the rotary tables so the current tokens get their true positions.
        offset = past_kv[0].size(2) if past_kv is not None else 0
        cos, sin = self.rotary(x, offset + L)
        cos = cos[offset:offset + L]
        sin = sin[offset:offset + L]

        # Override the last word_rope_dims of cos/sin with word-internal angles.
        if self.word_rope is not None and word_positions is not None:
            cos, sin = self.word_rope(cos, sin, word_positions)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Append freshly rotated K/V after the cached (already-rotated) ones.
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        new_kv = (k, v) if use_cache else None

        dropout_p = self.dropout if self.training else 0.0
        use_gqa = self.num_kv_groups > 1

        if self.window_size is not None:
            # Manual attention path so the precomputed banded mask can be applied.
            k_expanded = self._expand_kv(k)
            v_expanded = self._expand_kv(v)
            seq_len = k.size(2)
            attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if seq_len <= self.causal_mask.size(-1):
                # Rows = current query positions (offset..offset+L), columns =
                # all key positions including the cache.
                # NOTE(review): when seq_len exceeds the precomputed mask, no
                # mask is applied at all — confirm callers never exceed max_seq_len.
                mask = self.causal_mask[:, :, offset:offset + L, :seq_len]
                attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            if dropout_p > 0:
                attn = F.dropout(attn, p=dropout_p)
            out = torch.matmul(attn, v_expanded)
        else:
            # Fused SDPA path; is_causal only for the no-cache, multi-token case.
            # NOTE(review): with past_kv set and L > 1 (chunked prefill), no
            # causal mask is applied within the chunk — appears to assume
            # single-token decoding once a cache exists; confirm.
            # NOTE(review): enable_gqa requires a recent PyTorch — confirm the
            # project's minimum torch version supports it.
            is_causal = past_kv is None and L > 1
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                is_causal=is_causal,
                enable_gqa=use_gqa,
            )

        # [B, H, L, D] -> [B, L, hidden_size], then final output projection.
        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)

        return self.o_proj(out), new_kv
|
| |
|
| |
|
| | class SwiGLU(nn.Module):
|
| | """SwiGLU feed-forward network."""
|
| |
|
| | def __init__(self, hidden_size: int, intermediate_size: int | None = None):
|
| | super().__init__()
|
| | intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
|
| | intermediate_size = ((intermediate_size + 63) // 64) * 64
|
| |
|
| | self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| | self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| | self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| | return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| |
|