File size: 13,187 Bytes
56e82ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 | """
Shared building blocks for Circuit Transformer architectures.
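Components:
- RMSNorm: Root Mean Square Layer Normalization
- build_word_start_table / compute_word_positions: word-boundary detection and position-within-word indices
- WordPositionRoPE: RoPE subspace that encodes a token's position within its word
- RotaryEmbedding: Rotary Position Embedding (RoPE)
- CausalAttention: Multi-head causal attention with RoPE, optional GQA, optional sliding window, and KV cache
- SwiGLU: Gated feed-forward network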
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import lru_cache
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps)`` in float32, casts the
    result back to the input dtype, then scales by a learned per-feature
    ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Do the reduction in float32 for numerical stability under half precision.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (x_fp32 * inv_rms).type_as(x)
        return normed * self.weight
def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
    """Build a boolean table marking which token IDs start a new word.

    Word boundaries are detected from each token's string representation.
    A token starts a word when it begins with:
      - 'Ġ' (GPT-2 / byte-level BPE space marker)
      - '▁' (SentencePiece space marker)
      - a literal space (as produced by the ``decode`` fallback below)
      - '<' (special tokens such as <s>, <pad>, <0x0A>)
      - a newline, carriage return, or tab, either as the raw control
        character or in its byte-level BPE rendering ('Ċ', 'č', 'ĉ')

    Token strings come from ``convert_ids_to_tokens`` (HF tokenizers),
    ``sp.IdToPiece`` (SentencePiece wrappers), or ``decode([i])`` otherwise.
    Tokens reported as ``None`` are left as non-starters. ID 0 is always
    marked as a starter.

    Args:
        tokenizer: tokenizer exposing one of the interfaces listed above.
        vocab_size: number of token IDs to classify (length of the table).

    Returns:
        [vocab_size] bool tensor, True where the token begins a word.
    """
    # The decode() fallback yields plain text, so word-initial tokens start
    # with ' ' rather than a marker. Byte-level BPE vocabularies never contain
    # raw control characters, so their whitespace renderings are listed too.
    word_start_prefixes = (
        'Ġ', '▁', ' ', '<',
        '\n', '\r', '\t',
        'Ċ', 'č', 'ĉ',
    )

    table = torch.zeros(vocab_size, dtype=torch.bool)
    ids = list(range(vocab_size))
    # Get all token strings — handle both HF and SentencePiece tokenizers
    if hasattr(tokenizer, 'convert_ids_to_tokens'):
        tokens = tokenizer.convert_ids_to_tokens(ids)
    elif hasattr(tokenizer, 'sp'):
        tokens = [tokenizer.sp.IdToPiece(i) for i in ids]
    else:
        tokens = [tokenizer.decode([i]) for i in ids]

    for idx, tok in enumerate(tokens):
        if tok is not None and tok.startswith(word_start_prefixes):
            table[idx] = True

    # Token 0 is always a word starter (BOS/padding); guard the empty vocab.
    if vocab_size > 0:
        table[0] = True
    return table
def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
    """Compute position-within-word for each token. Vectorized, no loops.

    The first token of every row is treated as a word start, regardless of
    its table entry.

    Args:
        input_ids: [B, L] token IDs
        word_start_table: [vocab_size] table from build_word_start_table
            (moved to ``input_ids``'s device and read as bool)

    Returns:
        [B, L] float32 tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word
        boundary). An empty [B, 0] tensor when L == 0.
    """
    B, L = input_ids.shape
    device = input_ids.device
    if L == 0:
        return torch.zeros(B, 0, device=device, dtype=torch.float32)

    # Advanced indexing returns a copy, so forcing column 0 never mutates the table.
    is_word_start = word_start_table.to(device=device, dtype=torch.bool)[input_ids]  # [B, L]
    is_word_start[:, 0] = True  # First token always starts a word

    positions = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1)
    # Word-start positions keep their index; all others become -1.
    fill = torch.where(is_word_start, positions, torch.tensor(-1.0, device=device))
    # cummax propagates the most recent word-start index forward along each row.
    running_start, _ = fill.cummax(dim=1)
    # Position within word = distance from the most recent word start.
    return positions - running_start
class WordPositionRoPE(nn.Module):
    """RoPE encoding for position-within-word.

    Dedicates a small subspace of head dimensions to word-internal position,
    using a separate (lower) frequency base. ``word_dims`` head dimensions are
    overridden in total: the last ``word_dims // 2`` entries of *each* half of
    the standard RoPE cos/sin tensors.

    Both halves must be overridden together because ``rotate_half`` pairs dim j
    with dim ``j + head_dim // 2``. Giving both members of a pair the same angle
    keeps the transform a true rotation (norm-preserving, relative-position
    property intact). Overriding a contiguous tail instead would mix a standard
    angle and a word angle inside one pair.
    """

    def __init__(self, word_dims: int, word_base: float = 10.0):
        super().__init__()
        self.word_dims = word_dims
        word_inv_freq = 1.0 / (word_base ** (torch.arange(0, word_dims, 2).float() / word_dims))
        self.register_buffer("word_inv_freq", word_inv_freq)

    def forward(
        self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override the word subspace of cos/sin with word-position-derived values.

        Args:
            cos, sin: [L, head_dim] from standard RotaryEmbedding
            word_positions: [B, L] float tensor (position within word)

        Returns:
            cos, sin: [B, L, head_dim] copies with the word dims overridden
        """
        B, L = word_positions.shape
        half = cos.size(-1) // 2
        n_freq = self.word_inv_freq.numel()  # == word_dims // 2 for even word_dims
        assert n_freq <= half, \
            f"word_dims ({self.word_dims}) must be <= head_dim ({cos.size(-1)})"

        # Expand standard cos/sin to batch dimension: [L, D] -> [B, L, D] (own storage).
        cos = cos.unsqueeze(0).expand(B, -1, -1).clone()
        sin = sin.unsqueeze(0).expand(B, -1, -1).clone()
        if n_freq == 0:
            # A [-0:] slice would select every dim, so bail out explicitly.
            return cos, sin

        # Word angles: [B, L, n_freq]
        angles = word_positions.unsqueeze(-1) * self.word_inv_freq
        word_cos = angles.cos()
        word_sin = angles.sin()

        # Tail of the first half and the matching tail of the second half.
        for start in (half - n_freq, 2 * half - n_freq):
            cos[:, :, start:start + n_freq] = word_cos
            sin[:, :, start:start + n_freq] = word_sin
        return cos, sin
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with a lazily grown cos/sin table.

    Args:
        dim: rotary dimension; each frequency is shared by dims j and j + dim // 2.
        max_seq_len: number of positions precomputed at construction.
        base: frequency base.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)register ``cos_cached`` / ``sin_cached`` as [seq_len, dim] non-persistent buffers."""
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin), each [seq_len, dim]; ``x`` is not read.

        The cache is rebuilt to exactly ``seq_len`` rows when it is too short.
        """
        cached_len = self.cos_cached.size(0)
        if cached_len < seq_len:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dim halves [a, b] to [-b, a] (the pairing used by RoPE)."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the given RoPE cos/sin tables.

    q and k are [B, H, L, D]. cos/sin may be shared across the batch ([L, D])
    or per-sample ([B, L, D], e.g. from WordPositionRoPE). In the batched case
    a head axis is inserted so they broadcast over H.
    """
    if cos.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # [B, 1, L, D]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class CausalAttention(nn.Module):
    """Multi-head attention with causal mask, RoPE, and optional GQA.

    Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
    Each KV head serves (num_heads // num_kv_heads) query heads.
    KV cache stored at kv_heads granularity for memory efficiency.

    Args:
        hidden_size: model width; must be divisible by num_heads.
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads (defaults to num_heads).
        max_seq_len: size of the precomputed RoPE table and mask buffer.
            Longer sequences still work: both are extended or rebuilt on demand.
        dropout: attention dropout probability, applied only in training mode.
        window_size: if set, query i attends only to keys in [i - window_size + 1, i].
        word_rope_dims / word_rope_base: if word_rope_dims > 0, that many head
            dims encode position-within-word via WordPositionRoPE whenever
            ``word_positions`` is passed to forward().
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0, \
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_size // num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.dropout = dropout
        self.window_size = window_size
        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        if word_rope_dims > 0:
            assert word_rope_dims <= self.head_dim, \
                f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
            assert word_rope_dims % 2 == 0, \
                f"word_rope_dims ({word_rope_dims}) must be even"

        self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)
        # Word-position RoPE (optional)
        self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None

        # Precomputed causal mask (optionally windowed); 1 = may attend.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        if window_size is not None:
            # Band mask: position i attends to [max(0, i-window+1), i]
            band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
            mask = mask * band
        self.register_buffer(
            "causal_mask",
            mask.view(1, 1, max_seq_len, max_seq_len),
            persistent=False,
        )

    def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
        if self.num_kv_groups == 1:
            return kv
        B, H_kv, L, D = kv.shape
        return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)

    def _attention_mask(
        self, offset: int, q_len: int, kv_len: int, device: torch.device
    ) -> torch.Tensor:
        """Return a bool [1, 1, q_len, kv_len] mask, True where a query may attend to a key.

        Queries sit at absolute positions offset .. offset + q_len - 1 and keys
        at 0 .. kv_len - 1. The precomputed buffer is used when it is large
        enough. Otherwise the same (optionally windowed) causal pattern is built
        on the fly, so sequences longer than max_seq_len stay causal.
        """
        buf = self.causal_mask
        if offset + q_len <= buf.size(-2) and kv_len <= buf.size(-1):
            return buf[:, :, offset:offset + q_len, :kv_len].bool()
        q_pos = torch.arange(offset, offset + q_len, device=device).unsqueeze(-1)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        allowed = k_pos <= q_pos
        if self.window_size is not None:
            allowed = allowed & (k_pos > q_pos - self.window_size)
        return allowed.view(1, 1, q_len, kv_len)

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Attend over ``x`` ([B, L, hidden_size]), optionally extending ``past_kv``.

        Returns (output [B, L, hidden_size], (k, v) if use_cache else None).
        ``word_positions`` ([B, L], positions of the new tokens within their
        words) is only used when word RoPE is enabled.
        """
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: use correct position offset for KV-cached generation
        offset = past_kv[0].size(2) if past_kv is not None else 0
        cos, sin = self.rotary(x, offset + L)
        cos = cos[offset:offset + L]
        sin = sin[offset:offset + L]
        # Override word-position dims if enabled
        if self.word_rope is not None and word_positions is not None:
            cos, sin = self.word_rope(cos, sin, word_positions)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # KV cache at kv_heads granularity (memory efficient for GQA)
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        new_kv = (k, v) if use_cache else None

        dropout_p = self.dropout if self.training else 0.0
        kv_len = k.size(2)

        if self.window_size is not None:
            # Windowed attention: manual path (SDPA FlashAttention doesn't support arbitrary masks)
            k_expanded = self._expand_kv(k)
            v_expanded = self._expand_kv(v)
            attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Always masked, even beyond max_seq_len (the diagonal keeps every row non-empty).
            mask = self._attention_mask(offset, L, kv_len, x.device)
            attn = attn.masked_fill(~mask, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            if dropout_p > 0:
                attn = F.dropout(attn, p=dropout_p)
            out = torch.matmul(attn, v_expanded)
        else:
            # SDPA: auto-dispatches to FlashAttention2 / memory-efficient / math backend.
            # Native GQA support avoids expanding KV heads.
            if past_kv is not None and L > 1:
                # Chunked prefill over a cache: SDPA's is_causal is top-left aligned,
                # which is wrong when L != kv_len, so pass an explicit offset mask.
                attn_mask = self._attention_mask(offset, L, kv_len, x.device)
                is_causal = False
            else:
                # Single-token decode sees the whole cache; full prefill uses is_causal.
                attn_mask = None
                is_causal = past_kv is None and L > 1
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                enable_gqa=self.num_kv_groups > 1,
            )

        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
        return self.o_proj(out), new_kv
class SwiGLU(nn.Module):
"""SwiGLU feed-forward network."""
def __init__(self, hidden_size: int, intermediate_size: int | None = None):
super().__init__()
intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
intermediate_size = ((intermediate_size + 63) // 64) * 64
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|