# src/council/attention.py
# Rotary positional embedding and grouped-query attention layer.
# (From commit 73400c8: "Restructure to src/ layout with attention,
# per-layer MoE, and working chat".)
import torch
import torch.nn as nn
import torch.nn.functional as F
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) using the half-split pairing.

    Angle tables for ``max_seq_len`` positions are precomputed once; at call
    time each feature pair ``(x[i], x[i + head_dim//2])`` is rotated by the
    angle ``t * theta**(-2i/head_dim)`` for its sequence position ``t``.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 1000000.0):
        super().__init__()
        # Inverse frequencies theta^(-2i/head_dim), one per feature pair.
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (theta ** exponents)
        # angles[t, i] = t * inv_freq[i]  ->  (max_seq_len, head_dim//2)
        angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        # Buffers (not parameters) so they follow .to(device) but aren't trained.
        self.register_buffer("cos", angles.cos())
        self.register_buffer("sin", angles.sin())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` of shape (B, n_heads, T, head_dim); same shape out."""
        seq_len = x.shape[2]
        half = x.shape[-1] // 2
        # Broadcast the (T, head_dim//2) tables over batch and head dims.
        cos_t = self.cos[None, None, :seq_len]
        sin_t = self.sin[None, None, :seq_len]
        lo, hi = x[..., :half], x[..., half:]
        rotated_lo = lo * cos_t - hi * sin_t
        rotated_hi = hi * cos_t + lo * sin_t
        return torch.cat([rotated_lo, rotated_hi], dim=-1)
class AttentionLayer(nn.Module):
    """Grouped Query Attention with RoPE and pre-norm residual block.

    Config keys used: ``dim``, ``n_heads``, ``n_kv_heads``, ``head_dim``,
    ``seq_len``, and optionally ``rope_theta`` (default 1e6).

    Raises:
        ValueError: if ``n_heads`` is not an exact multiple of ``n_kv_heads``.
    """
    def __init__(self, cfg: dict):
        super().__init__()
        self.n_heads = cfg["n_heads"]
        self.n_kv_heads = cfg["n_kv_heads"]
        self.head_dim = cfg["head_dim"]
        self.dim = cfg["dim"]
        # GQA requires query heads to be an exact multiple of KV heads;
        # fail fast here instead of silently truncating in the // below.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        self.n_rep = self.n_heads // self.n_kv_heads
        # Q projects to n_heads * head_dim; K/V to the smaller KV width.
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
        self.norm = nn.RMSNorm(self.dim)
        self.rope = RotaryEmbedding(self.head_dim, cfg["seq_len"], cfg.get("rope_theta", 1000000.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-norm causal attention: returns ``x + Attn(RMSNorm(x))``.

        Args:
            x: activations of shape (B, T, dim); T must not exceed the
               ``seq_len`` the RoPE tables were built for.

        Returns:
            Tensor of shape (B, T, dim).
        """
        residual = x
        x = self.norm(x)
        B, T, _ = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Positions are encoded by rotating q and k (v is left unrotated).
        q = self.rope(q)
        k = self.rope(k)
        # Expand KV heads to match Q heads (GQA); skip the copy for plain MHA.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, H, T, hd) -> (B, T, H*hd), then project back to model dim.
        out = attn.transpose(1, 2).contiguous().view(B, T, -1)
        return residual + self.wo(out)