"""Grouped-query attention with SDPA and KV-cache support.""" from __future__ import annotations from typing import Optional import torch import torch.nn.functional as F from torch import nn from model.config import ModelConfig from model.rope import apply_rope def repeat_kv(x: torch.Tensor, num_groups: int) -> torch.Tensor: """Expand KV heads to match the number of query heads.""" if num_groups == 1: return x batch, kv_heads, seq_len, head_dim = x.shape x = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim) return x.reshape(batch, kv_heads * num_groups, seq_len, head_dim) class GQAAttention(nn.Module): """Fused-QKV grouped-query attention.""" def __init__(self, config: ModelConfig): super().__init__() self.config = config self.num_heads = config.num_attn_heads self.num_kv_heads = config.num_kv_heads self.head_dim = config.head_dim self.num_groups = self.num_heads // self.num_kv_heads qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim self.qkv_proj = nn.Linear(config.d_model, qkv_dim, bias=False) self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.dropout = config.dropout def forward( self, hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Compute causal self-attention and return an updated KV cache.""" batch_size, seq_len, _ = hidden_states.shape qkv = self.qkv_proj(hidden_states) q_end = self.num_heads * self.head_dim k_end = q_end + self.num_kv_heads * self.head_dim q, k, v = qkv.split((q_end, self.num_kv_heads * self.head_dim, self.num_kv_heads * self.head_dim), dim=-1) q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) q_rope, k_rope = apply_rope(q, repeat_kv(k, self.num_groups), cos, sin) k = k_rope[:, :: self.num_groups, :, :] if past_key_value is not None: past_key, past_value = past_key_value k = torch.cat([past_key, k], dim=-2) v = torch.cat([past_value, v], dim=-2) expanded_k = repeat_kv(k, self.num_groups) expanded_v = repeat_kv(v, self.num_groups) attn_output = F.scaled_dot_product_attention( q_rope, expanded_k, expanded_v, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=past_key_value is None, ) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.config.d_model) return self.out_proj(attn_output), (k, v)