Prisma / layers.py
y3i12's picture
Initial commit
56e82ec
"""
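Shared building blocks for Circuit Transformer architectures.
Components:
- RMSNorm: Root Mean Square Layer Normalization
- build_word_start_table / compute_word_positions: word-boundary lookup table and position-within-word indices
- WordPositionRoPE: RoPE over position-within-word, overriding the last head dimensions
- RotaryEmbedding: Rotary Position Embedding (RoPE)
- rotate_half / apply_rotary_pos_emb: helpers that apply RoPE to queries and keys
- CausalAttention: Multi-head causal attention with RoPE, optional GQA / sliding window, and KV cache
- SwiGLU: Gated feed-forward network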
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import lru_cache
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Each vector along the last dimension is divided by its root-mean-square
    (with ``eps`` added under the square root) and then multiplied by a
    learned per-feature gain.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Learned gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the gain."""
        # Statistics are computed in float32 regardless of the input dtype.
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        # Cast back to the input dtype before applying the gain.
        normalized = (x_fp32 * inv_rms).type_as(x)
        return normalized * self.weight
def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
    """Build a boolean table marking which token IDs start a new word.

    Token strings are obtained from the first available source:

    1. ``tokenizer.convert_ids_to_tokens`` (HF-style vocabulary strings)
    2. ``tokenizer.sp.IdToPiece`` (SentencePiece pieces)
    3. ``tokenizer.decode([i])`` (plain decoded text)

    A token is a word start when its string begins with:

    - ``Ġ`` (GPT-2/BPE style)
    - ``▁`` (SentencePiece style)
    - ``<`` (special tokens)
    - a newline, carriage return, or tab
    - a literal space, but only for strings obtained through ``decode``,
      where the boundary shows up as real whitespace rather than a marker

    Token 0 is always marked as a word start.

    Args:
        tokenizer: Object exposing one of the three accessors listed above.
        vocab_size: Number of token IDs to cover.

    Returns:
        Bool tensor of shape ``[vocab_size]``.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    if vocab_size == 0:
        # Nothing to mark; also avoids indexing token 0 of an empty table.
        return table

    markers = ('Ġ', '▁', '<', '\n', '\r', '\t')

    # Get all token strings — handle both HF and SentencePiece tokenizers
    if hasattr(tokenizer, 'convert_ids_to_tokens'):
        tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    elif hasattr(tokenizer, 'sp'):
        tokens = [tokenizer.sp.IdToPiece(i) for i in range(vocab_size)]
    else:
        tokens = [tokenizer.decode([i]) for i in range(vocab_size)]
        # Decoded text has no Ġ/▁ marker; a leading space is the boundary.
        markers += (' ',)

    # None entries and empty strings never start a word.
    flags = [bool(tok) and tok.startswith(markers) for tok in tokens]

    # Write everything in one assignment instead of one tensor write per token.
    count = min(len(flags), vocab_size)
    table[:count] = torch.tensor(flags[:count], dtype=torch.bool)

    # Token 0 is always a word starter (BOS/padding)
    table[0] = True
    return table
def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
    """Compute position-within-word for each token. Vectorized, no loops.

    Args:
        input_ids: [B, L] token IDs
        word_start_table: [vocab_size] bool tensor from build_word_start_table

    Returns:
        [B, L] float tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary).
        An empty sequence (L == 0) yields an empty [B, 0] tensor.
    """
    B, L = input_ids.shape
    device = input_ids.device

    if L == 0:
        # No first token to force to a word start; return an empty result.
        return torch.zeros((B, 0), dtype=torch.float32, device=device)

    # Move the table to the ids' device (no-op when already there), then look
    # up each token. Indexing with a tensor yields a fresh tensor, so the
    # in-place write below never touches word_start_table itself.
    is_word_start = word_start_table.to(device)[input_ids]  # [B, L]
    is_word_start[:, 0] = True  # First token always starts a word

    positions = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1)

    # Fill non-word-start positions with -1, word-start positions with their index
    fill = torch.where(is_word_start, positions, torch.tensor(-1.0, device=device))

    # cummax propagates the most recent word-start position forward
    running_start, _ = fill.cummax(dim=1)

    # Position within word = distance from the most recent word start
    return positions - running_start  # [B, L] float: 0, 1, 2, 0, 1, 0, ...
class WordPositionRoPE(nn.Module):
    """Rotary encoding driven by position-within-word.

    The last ``word_dims`` entries of the standard RoPE cos/sin tables are
    replaced by values computed from each token's position inside its word,
    using a separate frequency base (``word_base``).

    Args:
        word_dims: Number of trailing head dimensions to override.
        word_base: Base of the geometric frequency schedule for those dims.
    """

    def __init__(self, word_dims: int, word_base: float = 10.0):
        super().__init__()
        self.word_dims = word_dims
        # One inverse frequency per pair of dimensions.
        exponents = torch.arange(0, word_dims, 2).float() / word_dims
        self.register_buffer("word_inv_freq", 1.0 / (word_base ** exponents))

    def forward(
        self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Replace the last word_dims of cos/sin with word-position values.

        Args:
            cos, sin: [L, head_dim] from standard RotaryEmbedding
            word_positions: [B, L] float tensor (position within word)

        Returns:
            cos, sin: [B, L, head_dim] with word dims overridden. The inputs
            are not modified; batched copies are returned.
        """
        batch, _ = word_positions.shape

        # Angles per frequency: [B, L, word_dims/2], then duplicated to
        # [B, L, word_dims] to match the rotate_half layout.
        theta = word_positions.unsqueeze(-1) * self.word_inv_freq
        theta = torch.cat([theta, theta], dim=-1)

        # Batched, writable copies of the standard tables: [L, D] -> [B, L, D]
        cos_out = cos.unsqueeze(0).expand(batch, -1, -1).clone()
        sin_out = sin.unsqueeze(0).expand(batch, -1, -1).clone()

        # Overwrite the trailing word_dims slots.
        cos_out[:, :, -self.word_dims:] = theta.cos()
        sin_out[:, :, -self.word_dims:] = theta.sin()
        return cos_out, sin_out
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Holds cos/sin tables of shape ``[seq_len, dim]`` and hands out their
    leading rows. The tables are rebuilt when a longer sequence is requested.

    Args:
        dim: Rotary dimension (the width of the returned tables).
        max_seq_len: Number of positions precomputed at construction.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # One inverse frequency per pair of dimensions.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angles so the table spans the full rotary width.
        table = torch.cat((angles, angles), dim=-1)
        # Non-persistent: the tables are kept out of the state dict.
        for name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(name, values, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of the cos and sin tables.

        ``x`` is accepted but not read.
        """
        if self.cos_cached.size(0) < seq_len:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second.

    For ``x = [a, b]`` (split evenly along the last dim) returns ``[-b, a]``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the given cos/sin tables.

    Q and K are [B, H, L, D]. cos/sin may be the standard [L, D] tables or
    batched [B, L, D] tables (as produced by WordPositionRoPE); batched
    tables get a singleton head axis so they broadcast over H.

    Returns:
        The rotated (q, k) pair.
    """
    if cos.dim() == 3:
        # [B, L, D] -> [B, 1, L, D] so the tables broadcast over heads.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class CausalAttention(nn.Module):
    """Multi-head attention with causal mask, RoPE, and optional GQA.

    Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
    Each KV head serves (num_heads // num_kv_heads) query heads.
    KV cache stored at kv_heads granularity for memory efficiency.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads (defaults to ``num_heads``).
        max_seq_len: Size of the precomputed RoPE tables and causal mask.
            Longer sequences still work: the RoPE tables are rebuilt and the
            mask is computed on the fly.
        dropout: Attention dropout probability, applied only in training mode.
        window_size: If set, each position attends only to the most recent
            ``window_size`` positions (itself included).
        word_rope_dims: Number of trailing head dims driven by word-internal
            position; 0 disables the feature.
        word_rope_base: Frequency base for the word-position dims.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_size // num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.dropout = dropout
        self.window_size = window_size

        # Without this, the mismatch only shows up later as a shape error in
        # forward's final reshape to hidden_size.
        assert hidden_size % num_heads == 0, \
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        if word_rope_dims > 0:
            assert word_rope_dims <= self.head_dim, \
                f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
            assert word_rope_dims % 2 == 0, \
                f"word_rope_dims ({word_rope_dims}) must be even"

        self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)

        # Word-position RoPE (optional)
        self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None

        # Build causal mask (optionally windowed)
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        if window_size is not None:
            # Band mask: position i attends to [max(0, i-window+1), i]
            band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
            mask = mask * band
        self.register_buffer(
            "causal_mask",
            mask.view(1, 1, max_seq_len, max_seq_len),
            persistent=False,
        )

    def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
        if self.num_kv_groups == 1:
            return kv
        B, H_kv, L, D = kv.shape
        return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)

    def _allowed_positions(
        self, offset: int, q_len: int, kv_len: int, device: torch.device
    ) -> torch.Tensor:
        """Boolean [q_len, kv_len] mask, True where a query may attend to a key.

        Query row ``i`` sits at absolute position ``offset + i``. It may see
        keys at positions ``<=`` its own and, when ``window_size`` is set,
        only the most recent ``window_size`` of them. This is the same rule
        the precomputed ``causal_mask`` buffer encodes.
        """
        q_pos = torch.arange(offset, offset + q_len, device=device).unsqueeze(1)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        allowed = k_pos <= q_pos
        if self.window_size is not None:
            allowed = allowed & (k_pos > q_pos - self.window_size)
        return allowed

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Run causal self-attention over ``x``.

        Args:
            x: [B, L, hidden_size] input.
            use_cache: When True, also return the (k, v) pair covering all
                positions seen so far.
            past_kv: Optional (k, v) from earlier calls, each
                [B, num_kv_heads, L_past, head_dim]. New tokens are placed at
                positions L_past .. L_past + L - 1.
            word_positions: Optional [B, L] position-within-word values for
                the new tokens; used only when word-position RoPE is enabled.

        Returns:
            (output, new_kv): output is [B, L, hidden_size]; new_kv is the
            updated (k, v) pair at kv_heads granularity, or None when
            ``use_cache`` is False.
        """
        B, L, _ = x.shape

        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: use correct position offset for KV-cached generation
        offset = past_kv[0].size(2) if past_kv is not None else 0
        cos, sin = self.rotary(x, offset + L)
        cos = cos[offset:offset + L]
        sin = sin[offset:offset + L]

        # Override word-position dims if enabled
        if self.word_rope is not None and word_positions is not None:
            cos, sin = self.word_rope(cos, sin, word_positions)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # KV cache at kv_heads granularity (memory efficient for GQA)
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        new_kv = (k, v) if use_cache else None

        dropout_p = self.dropout if self.training else 0.0
        use_gqa = self.num_kv_groups > 1

        if self.window_size is not None:
            # Windowed attention: manual path (SDPA FlashAttention doesn't support arbitrary masks)
            k_expanded = self._expand_kv(k)
            v_expanded = self._expand_kv(v)
            seq_len = k.size(2)
            attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if seq_len <= self.causal_mask.size(-1):
                blocked = self.causal_mask[:, :, offset:offset + L, :seq_len] == 0
            else:
                # Past the precomputed mask: compute it on the fly rather than
                # skipping masking, which would drop causality and the window.
                blocked = ~self._allowed_positions(offset, L, seq_len, x.device)
            attn = attn.masked_fill(blocked, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            if dropout_p > 0:
                attn = F.dropout(attn, p=dropout_p)
            out = torch.matmul(attn, v_expanded)
        else:
            # SDPA: auto-dispatches to FlashAttention2 / memory-efficient / math backend
            # Native GQA support avoids expanding KV heads (saves memory + enables FlashAttention GQA kernel)
            attn_mask = None
            is_causal = False
            if L > 1:
                if past_kv is None:
                    # Plain prefill: square causal mask, handled by SDPA itself.
                    is_causal = True
                else:
                    # Several new tokens on top of a cache: they may see the
                    # whole cache but not later tokens in the same chunk, so
                    # an explicit offset mask is needed.
                    attn_mask = self._allowed_positions(offset, L, k.size(2), x.device)
            # A single new token (L == 1) may attend to everything; no mask.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                enable_gqa=use_gqa,
            )

        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
        return self.o_proj(out), new_kv
class SwiGLU(nn.Module):
"""SwiGLU feed-forward network."""
def __init__(self, hidden_size: int, intermediate_size: int | None = None):
super().__init__()
intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
intermediate_size = ((intermediate_size + 63) // 64) * 64
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))