| | """ |
| | Multi-Head Attention with explicit KV cache for SLM. |
| | Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Tuple |
| |
|
| | from .config import SLMConfig |
| | from .rope import RotaryEmbedding |
| | from .kv_cache import KVCache |
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with RoPE and explicit KV cache.

    Design choices for Qualcomm compatibility:
    - Standard attention (no FlashAttention)
    - No grouped/multi-query attention (simpler, v1.1 will add GQA)
    - Explicit KV cache management
    - Clean tensor operations for ONNX export
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Initialize attention layer.

        Args:
            config: Model configuration
            layer_idx: Index of this layer (for KV cache)
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.dropout = config.attention_dropout

        # Separate (non-fused) Q/K/V projections keep the exported graph simple.
        # Projection width is num_heads * head_dim, which the config may choose
        # independently of hidden_size.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Output projection maps the concatenated heads back to hidden_size.
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass for attention.

        Args:
            hidden_states: Input tensor [batch, seq_len, hidden_size]
            position_ids: Position indices [batch, seq_len]
            attention_mask: Causal mask [batch, 1, seq_len, kv_seq_len],
                additive (0 where attending is allowed, -inf where masked)
            kv_cache: Optional KV cache for inference
            use_cache: Whether to use/update KV cache

        Returns:
            Tuple of (output, kv_cache). `output` is
            [batch, seq_len, hidden_size]; `kv_cache` is returned unchanged
            (possibly updated in place via `kv_cache.update`).
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Linear projections: [batch, seq_len, num_heads * head_dim]
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Split heads: [batch, seq_len, num_heads, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Move heads before sequence: [batch, num_heads, seq_len, head_dim]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Apply rotary position embeddings to Q and K.
        query, key = self.rotary_emb(query, key, position_ids)

        if use_cache and kv_cache is not None:
            # NOTE(review): .item() forces a host sync and bakes the position
            # in as a constant under torch.onnx tracing — confirm this is
            # acceptable for the export path, or keep it as a tensor.
            # Assumes all batch rows share the same start position.
            cache_position = position_ids[0, 0].item()

            # Append the new K/V at `position`; returns the full cached
            # key/value tensors covering all positions seen so far.
            key, value = kv_cache.update(
                layer_idx=self.layer_idx,
                key=key,
                value=value,
                position=cache_position,
            )

        # Scaled dot-product attention scores:
        # [batch, num_heads, seq_len, kv_seq_len]
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale

        # Additive mask: -inf entries zero out after softmax.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

        if self.training and self.dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        # Weighted sum of values: [batch, num_heads, seq_len, head_dim]
        attn_output = torch.matmul(attn_weights, value)

        # Merge heads back: [batch, seq_len, num_heads * head_dim].
        # BUGFIX: reshape to num_heads * head_dim (what o_proj consumes), not
        # hidden_size — the two differ if the config decouples them, and the
        # old view would fail or misfold in that case.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.num_heads * self.head_dim)

        # Project back to hidden_size.
        output = self.o_proj(attn_output)

        return output, kv_cache
| |
|
| |
|
def create_causal_mask(
    seq_len: int,
    kv_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create an additive causal attention mask.

    Query position i sits at global position `offset + i` where
    `offset = kv_seq_len - seq_len` (cached keys precede the new queries),
    and may attend to every key position j with j <= offset + i. Allowed
    entries are 0.0; disallowed entries are -inf.

    Built with a broadcast comparison instead of a Python loop: O(1) kernel
    launches rather than O(seq_len) row slices, and no throwaway full -inf
    allocation for the seq_len == 1 decode case. The vectorized form also
    exports cleanly to ONNX.

    Args:
        seq_len: Query sequence length
        kv_seq_len: Key/value sequence length (>= seq_len)
        dtype: Data type for mask
        device: Device for mask

    Returns:
        Causal mask tensor [1, 1, seq_len, kv_seq_len]
    """
    offset = kv_seq_len - seq_len

    # q_pos: [seq_len, 1], k_pos: [1, kv_seq_len] -> broadcast comparison
    # yields the [seq_len, kv_seq_len] "key is in the future" boolean grid.
    q_pos = torch.arange(seq_len, device=device).unsqueeze(1)
    k_pos = torch.arange(kv_seq_len, device=device).unsqueeze(0)

    mask = torch.zeros((seq_len, kv_seq_len), dtype=dtype, device=device)
    mask = mask.masked_fill(k_pos > q_pos + offset, float("-inf"))

    # Add broadcast dims for [batch, heads]: [1, 1, seq_len, kv_seq_len]
    return mask.unsqueeze(0).unsqueeze(0)
| |
|