"""
Multi-Head Attention with explicit KV cache for SLM.

Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import SLMConfig
from .rope import RotaryEmbedding
from .kv_cache import KVCache


class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with RoPE and explicit KV cache.

    Design choices for Qualcomm compatibility:
    - Standard attention (no FlashAttention)
    - No grouped/multi-query attention (simpler, v1.1 will add GQA)
    - Explicit KV cache management
    - Clean tensor operations for ONNX export
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Initialize attention layer.

        Args:
            config: Model configuration. The attributes read here are
                ``hidden_size``, ``num_heads``, ``head_dim``,
                ``attention_dropout``, ``max_position_embeddings`` and
                ``rope_theta``.
            layer_idx: Index of this layer (forwarded to the KV cache so it
                can address this layer's entries).
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.dropout = config.attention_dropout

        # Q, K, V projections (bias-free). Each maps hidden_size to
        # num_heads * head_dim, i.e. there is one K/V head per Q head.
        proj_dim = self.num_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)

        # Output projection back to the model width.
        self.o_proj = nn.Linear(proj_dim, self.hidden_size, bias=False)

        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass for attention.

        Args:
            hidden_states: Input tensor [batch, seq_len, hidden_size]
            position_ids: Position indices [batch, seq_len]
            attention_mask: Additive mask [batch, 1, seq_len, kv_seq_len];
                it is summed onto the raw scores before the softmax.
            kv_cache: Optional KV cache for inference
            use_cache: Whether to use/update KV cache. The cache is only
                touched when this is True *and* ``kv_cache`` is not None.

        Returns:
            Tuple of (output, kv_cache). ``output`` is
            [batch, seq_len, hidden_size]; ``kv_cache`` is the same object
            that was passed in (or None).
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape: [batch, seq, hidden] -> [batch, seq, heads, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: [batch, heads, seq, head_dim]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Apply rotary embeddings to Q and K
        query, key = self.rotary_emb(query, key, position_ids)

        # Handle KV cache
        if use_cache and kv_cache is not None:
            # Only the first position of the first batch row is read, so
            # every sequence in the batch is written at the same offset.
            # NOTE(review): .item() produces a Python scalar; under
            # tracing-based export this is presumably captured as a
            # constant - confirm this matches the export strategy.
            cache_position = position_ids[0, 0].item()

            # Update cache and get full K, V (the returned tensors are
            # expected to include previously cached positions - see
            # KVCache.update).
            key, value = kv_cache.update(
                layer_idx=self.layer_idx,
                key=key,
                value=value,
                position=cache_position,
            )

        # Compute attention scores
        # [batch, heads, seq, head_dim] @ [batch, heads, head_dim, kv_seq]
        # -> [batch, heads, seq, kv_seq]
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale

        # Apply causal mask
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax is evaluated in float32 and then cast back to the query
        # dtype; dropout is applied only in training mode.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        if self.training and self.dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        # Apply attention to values
        # [batch, heads, seq, kv_seq] @ [batch, heads, kv_seq, head_dim]
        # -> [batch, heads, seq, head_dim]
        attn_output = torch.matmul(attn_weights, value)

        # Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, hidden]
        # (.contiguous() is required before .view() after the transpose).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)

        # Output projection
        output = self.o_proj(attn_output)

        return output, kv_cache


def create_causal_mask(
    seq_len: int,
    kv_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create a causal attention mask.

    Query row ``i`` sits at absolute position ``(kv_seq_len - seq_len) + i``
    and may attend to every key position up to and including its own.
    Visible entries are ``0.0`` and hidden entries are ``-inf``, so the mask
    can be added directly to the attention scores.

    For decode (``seq_len == 1``) this yields an all-zero row, i.e. the
    single query attends to all ``kv_seq_len`` keys.

    Args:
        seq_len: Query sequence length
        kv_seq_len: Key/value sequence length
        dtype: Data type for mask
        device: Device for mask

    Returns:
        Causal mask tensor [1, 1, seq_len, kv_seq_len]
    """
    # Number of key positions that precede the first query (KV cache offset).
    offset = kv_seq_len - seq_len

    # The mask is built from broadcast position comparisons rather than
    # per-row slice assignment: a slice stop of ``offset + i + 1`` becomes a
    # wrap-around index when kv_seq_len < seq_len and would unmask keys that
    # must stay hidden. Comparing positions has no such failure mode and
    # needs no Python-level loop.
    query_pos = torch.arange(seq_len, device=device).unsqueeze(1) + offset
    key_pos = torch.arange(kv_seq_len, device=device).unsqueeze(0)
    hidden = key_pos > query_pos  # [seq_len, kv_seq_len], True = masked out

    mask = torch.zeros((seq_len, kv_seq_len), dtype=dtype, device=device)
    mask = mask.masked_fill(hidden, float("-inf"))

    return mask.unsqueeze(0).unsqueeze(0)