| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), split-half variant.

    Precomputes cos/sin tables for every position up to ``max_seq_len`` and
    rotates the first and second halves of the head dimension by a
    position-dependent angle (the GPT-NeoX / "rotate_half" layout, not the
    interleaved-pairs layout).

    Args:
        head_dim: Per-head embedding size. Must be even, since the rotation
            pairs element ``i`` of the first half with element ``i`` of the
            second half.
        max_seq_len: Maximum sequence length the tables cover.
        theta: Base of the geometric frequency progression (default 1e6).
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 1000000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even, got {head_dim}")
        # Inverse frequency per rotation pair: theta^(-2i/head_dim), i in [0, head_dim/2).
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, freqs)  # (max_seq_len, head_dim // 2) angles
        # persistent=False: these tables are fully determined by the constructor
        # arguments, so keep them out of state_dict — smaller checkpoints, and a
        # checkpoint stays loadable if max_seq_len is later changed.
        self.register_buffer("cos", freqs.cos(), persistent=False)
        self.register_buffer("sin", freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the rotation to ``x`` of shape (batch, heads, seq, head_dim).

        Returns a tensor of the same shape and dtype as ``x``.
        """
        T = x.shape[2]
        if T > self.cos.shape[0]:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.cos.shape[0]}"
            )
        # Cast the tables to the input dtype so half-precision inputs are not
        # silently upcast to float32 by broadcasting rules.
        cos = self.cos[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)  # (1, 1, T, half)
        sin = self.sin[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        # 2-D rotation applied pairwise across the two halves.
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
|
|
|
|
class AttentionLayer(nn.Module):
    """Grouped Query Attention with RoPE and pre-norm residual block.

    Computes ``x + wo(attention(norm(x)))`` where attention is causal
    scaled-dot-product attention with ``n_heads`` query heads sharing
    ``n_kv_heads`` key/value heads.

    Args:
        cfg: Configuration dict with keys ``"dim"``, ``"n_heads"``,
            ``"n_kv_heads"``, ``"head_dim"``, ``"seq_len"``, and optionally
            ``"rope_theta"`` (default 1e6).

    Raises:
        ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.n_heads = cfg["n_heads"]
        self.n_kv_heads = cfg["n_kv_heads"]
        self.head_dim = cfg["head_dim"]
        self.dim = cfg["dim"]
        # Validate up front: integer division below would otherwise silently
        # truncate and surface later as an opaque shape mismatch in forward().
        if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        # How many query heads share each KV head.
        self.n_rep = self.n_heads // self.n_kv_heads

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        # Pre-norm: applied to the input before the projections in forward().
        self.norm = nn.RMSNorm(self.dim)
        self.rope = RotaryEmbedding(self.head_dim, cfg["seq_len"], cfg.get("rope_theta", 1000000.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention block on ``x`` of shape (batch, seq, dim).

        Returns a tensor of the same shape, with the residual already added.
        """
        residual = x
        x = self.norm(x)
        B, T, _ = x.shape

        # Project and reshape to (B, heads, T, head_dim) for attention.
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position encoding is applied to queries and keys only.
        q = self.rope(q)
        k = self.rope(k)

        # Expand KV heads so each group of n_rep query heads sees its shared
        # key/value head (GQA -> effectively MHA for the SDPA call).
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = attn.transpose(1, 2).contiguous().view(B, T, -1)
        return residual + self.wo(out)
|
|