# sage/model/attention.py
# Author: sage002
# Commit ef18673 (verified): "feat: rewrite SAGE 1B architecture and replace legacy repo contents"
"""Grouped-query attention with SDPA and KV-cache support."""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from model.config import ModelConfig
from model.rope import apply_rope
def repeat_kv(x: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Tile each key/value head so the head count matches the query heads.

    ``x`` is unpacked as ``(batch, kv_heads, seq_len, head_dim)``. The result
    has ``kv_heads * num_groups`` heads. Output heads ``i * num_groups``
    through ``i * num_groups + num_groups - 1`` are all copies of input
    head ``i``. When ``num_groups`` is 1, the input tensor itself is returned.
    """
    if num_groups != 1:
        b, h_kv, s, d = x.shape
        # Insert a group axis right after the head axis and broadcast it, then
        # fold it into the head axis so copies of a head end up adjacent.
        grouped = x.unsqueeze(2).expand(b, h_kv, num_groups, s, d)
        x = grouped.reshape(b, h_kv * num_groups, s, d)
    return x
class GQAAttention(nn.Module):
    """Fused-QKV grouped-query attention.

    A single linear layer projects the input to queries for ``num_attn_heads``
    heads and to keys/values for ``num_kv_heads`` heads. Each KV head is
    shared by ``num_attn_heads // num_kv_heads`` consecutive query heads.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.num_attn_heads % config.num_kv_heads != 0:
            raise ValueError(
                "num_attn_heads must be divisible by num_kv_heads, got "
                f"{config.num_attn_heads} and {config.num_kv_heads}"
            )
        self.config = config
        self.num_heads = config.num_attn_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        # Number of query heads that share one KV head.
        self.num_groups = self.num_heads // self.num_kv_heads
        # Widths of the query slice and of each key/value slice of the fused projection.
        self.q_dim = self.num_heads * self.head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_proj = nn.Linear(config.d_model, self.q_dim + 2 * self.kv_dim, bias=False)
        # The output projection consumes the concatenated head outputs, which
        # are num_heads * head_dim wide (equal to d_model in the usual config).
        self.out_proj = nn.Linear(self.q_dim, config.d_model, bias=False)
        self.dropout = config.dropout

    @staticmethod
    def _causal_mask(
        query_len: int, key_len: int, device: torch.device
    ) -> tuple[Optional[torch.Tensor], bool]:
        """Return ``(attn_mask, is_causal)`` for SDPA when the queries are the
        last ``query_len`` positions of the ``key_len`` keys."""
        if query_len == key_len:
            # Square case (no cache, or an empty cache): SDPA's built-in mask is correct.
            return None, True
        if query_len == 1:
            # A single new token may attend to every cached key and to itself.
            return None, False
        # SDPA's is_causal=True is top-left aligned for non-square masks, which
        # is wrong here. Build a bottom-right aligned mask instead: query i
        # (absolute position key_len - query_len + i) sees keys 0..that position.
        mask = torch.ones(query_len, key_len, dtype=torch.bool, device=device)
        return mask.tril(diagonal=key_len - query_len), False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Compute causal self-attention and return an updated KV cache.

        Args:
            hidden_states: ``(batch, seq_len, d_model)`` input activations.
            cos, sin: Rotary tables forwarded unchanged to ``apply_rope``.
                Presumably already sliced to the positions of the new tokens;
                TODO confirm against the caller.
            past_key_value: Optional ``(key, value)`` cache as returned by a
                previous call. Each tensor is
                ``(batch, num_kv_heads, past_len, head_dim)`` and the keys are
                already rotated.

        Returns:
            ``(output, (key, value))``. ``output`` is
            ``(batch, seq_len, d_model)``. The cache tensors are
            ``(batch, num_kv_heads, past_len + seq_len, head_dim)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        q, k, v = self.qkv_proj(hidden_states).split(
            (self.q_dim, self.kv_dim, self.kv_dim), dim=-1
        )
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # apply_rope receives keys expanded to the query head count. Head
        # i * num_groups of the expanded tensor is a copy of KV head i, so
        # strided slicing recovers exactly one rotated copy per KV head.
        q, k_rope = apply_rope(q, repeat_kv(k, self.num_groups), cos, sin)
        k = k_rope[:, :: self.num_groups, :, :]

        if past_key_value is not None:
            past_key, past_value = past_key_value
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        attn_mask, is_causal = self._causal_mask(seq_len, k.shape[-2], q.device)
        attn_output = F.scaled_dot_product_attention(
            q,
            repeat_kv(k, self.num_groups),
            repeat_kv(v, self.num_groups),
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.q_dim)
        # The cache keeps the un-expanded KV heads so memory scales with num_kv_heads.
        return self.out_proj(attn_output), (k, v)