| """ |
| Multi-Head (and Grouped-Query) Attention with optional FlashAttention-2 backend. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .config import LMConfig |
|
|
| |
| |
| |
# Optional FlashAttention-2 backend: use the fused CUDA kernel when the
# `flash-attn` package is installed; MultiHeadAttention falls back to the
# vanilla scaled-dot-product path when HAS_FLASH_ATTN is False.
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
|
|
| |
| |
| |
# Optional NVIDIA Transformer Engine: provides FP8-capable Linear layers.
# `te` is set to None on ImportError so the symbol always exists; callers
# must check HAS_TE before touching it.
try:
    import transformer_engine.pytorch as te
    HAS_TE = True
except ImportError:
    te = None
    HAS_TE = False
|
|
|
|
| |
| |
| |
|
|
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate *x* by the rotary positional embedding given by *cos*/*sin*.

    Uses the "rotate-half" formulation: the last dimension is split into
    two halves which are treated as the real/imaginary parts of a complex
    rotation.

    Args:
        x: (B, T, H, D_head) query or key tensor.
        cos: (T, D_head // 2) — from RotaryEmbedding.forward
        sin: (T, D_head // 2) — from RotaryEmbedding.forward

    Returns:
        Tensor with the same shape and dtype as *x*, rotated.
    """
    half = x.shape[-1] // 2
    lo, hi = x[..., :half], x[..., half:]

    # Lift the (T, half) tables to (1, T, 1, half) so they broadcast over
    # the batch and head dimensions of (B, T, H, half) activations.
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]

    out = torch.cat(
        (lo * cos_b - hi * sin_b, lo * sin_b + hi * cos_b),
        dim=-1,
    )
    # cos/sin may be higher precision than x; return in the caller's dtype.
    return out.to(x.dtype)
|
|
|
|
|
|
| |
| |
| |
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head (or grouped-query) causal self-attention.

    Supports:
        - Standard MHA: n_kv_heads == n_heads
        - GQA / MQA: n_kv_heads < n_heads (must evenly divide n_heads)

    Attention backend:
        - FlashAttention-2 when available, on CUDA, and config.use_flash_attn
          is True. The FA2 kernel only accepts fp16/bf16 inputs, so
          activations are down-cast for the kernel and the result is cast
          back to the original dtype.
        - Vanilla scaled dot-product otherwise (causal mask via
          upper-triangular bool mask; softmax computed in float32).
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()

        # Fail fast on configs that would silently mis-shape the head views
        # in forward().
        if config.d_model % config.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if config.n_heads % config.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")

        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.d_model // config.n_heads
        self.d_model = config.d_model
        self.dropout = config.dropout
        self.use_flash = config.use_flash_attn

        # Each KV head serves n_rep query heads (n_rep == 1 => plain MHA).
        self.n_rep = self.n_heads // self.n_kv_heads

        # Use Transformer Engine Linear when FP8 execution is requested and
        # the library imported successfully; plain nn.Linear otherwise.
        _Linear = te.Linear if (config.use_fp8 and HAS_TE) else nn.Linear

        # Fused QKV projection: one matmul produces queries (n_heads heads)
        # plus keys and values (n_kv_heads heads each).
        self._q_dim = self.n_heads * self.head_dim
        self._kv_dim = self.n_kv_heads * self.head_dim
        self.qkv_proj = _Linear(
            config.d_model,
            self._q_dim + 2 * self._kv_dim,
            bias=config.bias,
        )
        self.out_proj = _Linear(
            config.d_model,
            config.d_model,
            bias=config.bias,
        )

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Expand KV heads to match the number of query heads.

        Args:
            x: (B, T, n_kv_heads, head_dim)
            n_rep: repetition factor (n_heads // n_kv_heads)

        Returns:
            (B, T, n_kv_heads * n_rep, head_dim)
        """
        if n_rep == 1:
            return x
        # repeat_interleave keeps the copies of each KV head adjacent, so
        # query head i attends through kv head i // n_rep.
        return x.repeat_interleave(n_rep, dim=2)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run causal self-attention over *x*.

        Args:
            x: (B, T, C)
            cos: (T, head_dim // 2) — from RotaryEmbedding
            sin: (T, head_dim // 2) — from RotaryEmbedding

        Returns:
            (B, T, C), same dtype as *x*.
        """
        B, T, C = x.shape

        # One fused projection, then split into per-role tensors.
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split([self._q_dim, self._kv_dim, self._kv_dim], dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        # Rotary position encoding is applied to queries and keys only.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        if self.use_flash and HAS_FLASH_ATTN and x.is_cuda:
            # BUGFIX: the down-cast used to happen unconditionally, so the
            # non-flash fp32 fallback produced a bf16 attn_out that crashed
            # against fp32 out_proj weights. FlashAttention-2 requires
            # fp16/bf16, so cast only here and restore the activation dtype
            # before the output projection.
            orig_dtype = q.dtype
            if orig_dtype not in (torch.float16, torch.bfloat16):
                q = q.to(torch.bfloat16)
                k = k.to(torch.bfloat16)
                v = v.to(torch.bfloat16)
            attn_out = self._flash_attention(q, k, v, B, T).to(orig_dtype)
        else:
            attn_out = self._standard_attention(q, k, v, B, T)

        return self.out_proj(attn_out)

    def _flash_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Run FlashAttention-2.

        flash_attn_func expects inputs in (B, T, H, D) layout and returns
        (B, T, H, D). FlashAttention-2 natively supports GQA via head count
        mismatch (q has n_heads, k/v have n_kv_heads) — no KV expansion needed.

        Returns:
            (B, T, n_heads * head_dim)
        """
        # Dropout is applied inside the kernel, only during training.
        dropout_p = self.dropout if self.training else 0.0

        out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)

        # Merge the head dimension back into the model dimension.
        return out.reshape(B, T, self.n_heads * self.head_dim)

    def _standard_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Vanilla scaled dot-product causal attention.

        Softmax is computed in float32 for numerical stability, then cast
        back to the working dtype.

        Returns:
            (B, T, d_model)
        """
        # This backend has no native GQA support: replicate KV heads first.
        k = self._repeat_kv(k, self.n_rep)
        v = self._repeat_kv(v, self.n_rep)

        # (B, T, H, D) -> (B, H, T, D) for batched matmul over heads.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = math.sqrt(self.head_dim)

        # (B, H, T, T) attention logits.
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Strict upper triangle == future positions; mask them out.
        causal_mask = torch.triu(
            torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))

        # float32 softmax avoids overflow/underflow in half precision.
        attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)

        if self.training and self.dropout > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        out = torch.matmul(attn_weights, v)

        # (B, H, T, D) -> (B, T, H*D) == (B, T, d_model).
        out = out.transpose(1, 2).contiguous().reshape(B, T, self.d_model)
        return out
|
|