"""
Multi-Head Attention with explicit KV cache for SLM.
Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from .config import SLMConfig
from .rope import RotaryEmbedding
from .kv_cache import KVCache
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and an
    explicit, caller-managed KV cache.

    Kept deliberately plain for Qualcomm / ONNX friendliness:
    - unfused scaled-dot-product attention (no FlashAttention kernels)
    - one KV head per query head (GQA is deferred to v1.1)
    - cache reads/writes are explicit rather than hidden inside the module
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Build the Q/K/V/O projections and rotary embedding for one layer.

        Args:
            config: Model configuration.
            layer_idx: Position of this layer in the stack; used to address
                this layer's slot in the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.dropout = config.attention_dropout

        proj_dim = self.num_heads * self.head_dim
        # Separate bias-free projections; a fused QKV matmul would be harder
        # to export cleanly.
        self.q_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.o_proj = nn.Linear(proj_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run one attention layer.

        Args:
            hidden_states: Input tensor [batch, seq_len, hidden_size].
            position_ids: Position indices [batch, seq_len].
            attention_mask: Additive mask [batch, 1, seq_len, kv_seq_len]
                (0 for allowed positions, -inf for blocked ones).
            kv_cache: Optional KV cache used during inference.
            use_cache: Whether to read/update the KV cache.

        Returns:
            Tuple of (output [batch, seq_len, hidden_size], kv_cache).
        """
        bsz, q_len, _ = hidden_states.shape

        # Project, then fold out heads: [b, s, hidden] -> [b, heads, s, head_dim].
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotate Q and K by their absolute positions.
        q, k = self.rotary_emb(q, k, position_ids)

        if use_cache and kv_cache is not None:
            # First position of this chunk tells the cache where to write.
            # NOTE(review): .item() forces a host sync and bakes a constant
            # into a traced graph — confirm this is acceptable for export.
            write_pos = position_ids[0, 0].item()
            # After the update K/V cover the full cached sequence.
            k, v = kv_cache.update(
                layer_idx=self.layer_idx,
                key=k,
                value=v,
                position=write_pos,
            )

        # Scaled dot-product scores: [b, heads, q_len, kv_len].
        scale = 1.0 / (self.head_dim ** 0.5)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if attention_mask is not None:
            # Additive mask: blocked positions carry -inf.
            scores = scores + attention_mask

        # Softmax in fp32 for numerical stability, then cast back.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        if self.training and self.dropout > 0:
            probs = F.dropout(probs, p=self.dropout)

        # Weighted sum over values: [b, heads, q_len, head_dim].
        context = torch.matmul(probs, v)
        # Fold heads back in: [b, q_len, hidden].
        context = context.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        return self.o_proj(context), kv_cache
def create_causal_mask(
    seq_len: int,
    kv_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create an additive causal attention mask.

    The queries are the trailing ``seq_len`` positions of a context of
    length ``kv_seq_len``; query row i may attend to key columns
    0 .. (kv_seq_len - seq_len) + i. Allowed positions are 0.0 and blocked
    positions are -inf (the mask is added to attention logits before
    softmax). For decode (seq_len == 1) every column is allowed, so the
    mask is all zeros — this falls out of the general rule and needs no
    special case.

    Args:
        seq_len: Query sequence length.
        kv_seq_len: Key/value sequence length (assumed >= seq_len).
        dtype: Data type for the mask.
        device: Device to allocate the mask on.

    Returns:
        Causal mask tensor [1, 1, seq_len, kv_seq_len].
    """
    # How far the query window is shifted into the KV sequence.
    offset = kv_seq_len - seq_len
    query_pos = torch.arange(seq_len, device=device).unsqueeze(1)  # [seq_len, 1]
    key_pos = torch.arange(kv_seq_len, device=device)              # [kv_seq_len]
    # One vectorized comparison replaces the per-row Python loop (and the
    # throwaway -inf allocation on the old decode branch) — better for
    # tracing/export since there is no data-dependent control flow.
    allowed = key_pos <= (query_pos + offset)                      # [seq_len, kv_seq_len]
    mask = torch.full((seq_len, kv_seq_len), float("-inf"), dtype=dtype, device=device)
    mask.masked_fill_(allowed, 0.0)
    return mask.unsqueeze(0).unsqueeze(0)
|