| """Grouped Query Attention (GQA).""" |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from llm_lab.config import ModelConfig |
| from .rope import RotaryPositionalEmbedding |
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA): a memory-efficient attention variant.

    Head-sharing spectrum:
    - MHA (Multi-Head Attention): Q, K, V all use num_heads -> highest memory.
    - MQA (Multi-Query Attention): all Q heads share one K/V head -> biggest
      savings, but quality may degrade.
    - GQA (Grouped Query Attention): K/V use num_kv_heads, each shared by a
      group of Q heads -> a quality/efficiency middle ground between the two.

    Example (num_heads=16, num_kv_heads=4): query heads 0-3 share K/V head 0,
    heads 4-7 share K/V head 1, and so on (4 Q heads per K/V head).

    Computes: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.num_kv_groups = config.num_kv_groups

        hd = self.head_dim
        # Q projects to num_heads * head_dim; K/V project to the smaller
        # num_kv_heads * head_dim — that asymmetry is the GQA memory saving.
        self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * hd, bias=False)
        self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * hd, bias=False)
        self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * hd, bias=False)
        self.o_proj = nn.Linear(config.num_heads * hd, config.hidden_dim, bias=False)

        # Rotary embedding applied per head to Q and K inside forward().
        self.rope = RotaryPositionalEmbedding(
            dim=hd, max_seq_len=config.max_seq_len, theta=config.rope_theta
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """Apply grouped-query attention to a batch of sequences.

        Args:
            x: Input of shape (batch_size, seq_len, hidden_dim).
            mask: Optional additive (seq_len, seq_len) attention mask; when
                omitted, a causal mask is built on the fly.
            position_offset: RoPE position offset (used during inference).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        bsz, seq_len, _ = x.shape

        # Project, then split the last dim into heads: (bsz, heads, seq, head_dim).
        q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotate Q and K by position before any head duplication.
        q, k = self.rope(q, k, position_offset)

        # Duplicate each K/V head across its query group so shapes line up
        # with Q for the batched matmul below.
        if self.num_kv_groups > 1:
            k = self._repeat_kv(k)
            v = self._repeat_kv(v)

        # Scaled dot-product scores: (bsz, num_heads, seq_len, seq_len).
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        if mask is None:
            # Default to causal masking: position i may only attend to j <= i.
            # Upper triangle (strictly above the diagonal) is set to -inf.
            mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=q.device, dtype=q.dtype),
                diagonal=1,
            )
        scores = scores + mask

        weights = F.softmax(scores, dim=-1)

        # Attention dropout — active only in training mode.
        if self.training and self.config.dropout > 0.0:
            weights = F.dropout(weights, p=self.config.dropout)

        out = weights @ v

        # Merge heads back: (bsz, seq_len, num_heads * head_dim).
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.o_proj(out)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat each KV head num_kv_groups times so KV matches Q's head count.

        (batch_size, num_kv_heads, seq_len, head_dim)
            -> (batch_size, num_heads, seq_len, head_dim)

        Repeats are consecutive, e.g. with num_kv_groups=4:
        [kv0, kv1, ...] -> [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1, ...]
        """
        return x.repeat_interleave(self.num_kv_groups, dim=1)
|
|