# PebbleLM-117M — src/model/attention.py
# Commit 27871e7 ("Add model architecture code", author: nameissakthi)
"""
Multi-Head Attention with explicit KV cache for SLM.
Qualcomm-safe: No FlashAttention, no fused ops, clean ONNX export.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from .config import SLMConfig
from .rope import RotaryEmbedding
from .kv_cache import KVCache
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and an
    explicit, externally managed KV cache.

    Kept deliberately simple for Qualcomm/ONNX export:
      - plain matmul + softmax attention (no FlashAttention, no fused ops)
      - one K/V head per query head (GQA planned for v1.1)
      - cache reads/writes are explicit tensor operations
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Build the projections and rotary embedding for one attention layer.

        Args:
            config: Model configuration.
            layer_idx: Index of this layer, used to address the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.dropout = config.attention_dropout

        proj_dim = self.num_heads * self.head_dim
        # Separate bias-free projections for queries, keys and values.
        self.q_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, proj_dim, bias=False)
        # Final projection back to the model width.
        self.o_proj = nn.Linear(proj_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run scaled dot-product attention over ``hidden_states``.

        Args:
            hidden_states: Input activations [batch, seq_len, hidden_size].
            position_ids: Absolute positions for RoPE [batch, seq_len].
            attention_mask: Additive mask [batch, 1, seq_len, kv_seq_len]
                (0 where attending is allowed, -inf where it is not).
            kv_cache: Optional cache holding past keys/values.
            use_cache: When True (and a cache is given), write this step's
                K/V into the cache and attend over the full history.

        Returns:
            Tuple of (output [batch, seq_len, hidden_size], kv_cache).
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project and split into heads: [batch, heads, seq, head_dim].
        q = self.q_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        k = self.k_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        v = self.v_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Rotate Q and K according to their absolute positions.
        q, k = self.rotary_emb(q, k, position_ids)

        if use_cache and kv_cache is not None:
            # The first position id marks where this chunk starts in the cache.
            start = position_ids[0, 0].item()
            k, v = kv_cache.update(
                layer_idx=self.layer_idx,
                key=k,
                value=v,
                position=start,
            )

        # Scaled dot-product scores: [batch, heads, seq, kv_seq].
        scaling = 1.0 / (self.head_dim ** 0.5)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scaling

        # Mask is additive: 0 keeps a position, -inf removes it.
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax in fp32 for numerical stability, then cast back.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        if self.dropout > 0 and self.training:
            probs = F.dropout(probs, p=self.dropout)

        # Weighted sum of values -> [batch, heads, seq, head_dim],
        # then merge heads back into [batch, seq, hidden_size].
        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.hidden_size)

        return self.o_proj(context), kv_cache
def create_causal_mask(
    seq_len: int,
    kv_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Create an additive causal attention mask.

    Entries are 0.0 where a query may attend and -inf where it may not.
    When ``kv_seq_len > seq_len`` (KV cache in use), query row ``i``
    corresponds to absolute position ``kv_seq_len - seq_len + i`` and may
    attend to every position up to and including its own.

    Args:
        seq_len: Query sequence length.
        kv_seq_len: Key/value sequence length (>= seq_len when a cache
            holds earlier tokens).
        dtype: Data type for the mask.
        device: Device to allocate the mask on.

    Returns:
        Causal mask tensor [1, 1, seq_len, kv_seq_len].
    """
    # Decode step (seq_len == 1): the single query attends to all cached
    # tokens. Allocate the zeros directly instead of building a full -inf
    # mask first and discarding it (the original wasted that allocation).
    if seq_len == 1:
        return torch.zeros((1, 1, 1, kv_seq_len), dtype=dtype, device=device)

    # Prefill: one vectorized comparison replaces the per-row Python loop
    # (which also recomputed the loop-invariant offset every iteration).
    # Same result, ONNX-friendly ops, O(1) Python overhead.
    offset = kv_seq_len - seq_len  # number of already-cached positions
    q_pos = torch.arange(seq_len, device=device).unsqueeze(1)     # [seq, 1]
    k_pos = torch.arange(kv_seq_len, device=device).unsqueeze(0)  # [1, kv]
    mask = torch.zeros((seq_len, kv_seq_len), dtype=dtype, device=device)
    # Key position j is hidden from query i when j > i + offset (future token).
    mask.masked_fill_(k_pos > q_pos + offset, float("-inf"))
    return mask.unsqueeze(0).unsqueeze(0)