# plapre-pico-coreml / scripts / attention.py
# Author: Daniel Rothmann — "Performance improvements on iOS" (commit ffa94c8)
"""
CoreML-compatible replacements for HuggingFace LlamaForCausalLM building blocks.
HF's SDPA attention and dynamic RoPE are not traceable by torch.jit.trace / coremltools.
This module provides static, explicit implementations that produce identical outputs.
The decode attention processes 1 token per step and writes to the KV cache using a
broadcast one-hot mask: k_cache * (1 - mask) + k * mask.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaRMSNorm(nn.Module):
    """Root-mean-square LayerNorm with an overflow-safe path for fp16.

    Matches HF Llama's RMSNorm contract: normalize over the last dim,
    then apply a learned elementwise gain.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # eps lives in a buffer so it follows the module's .to(device/dtype) moves.
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fp16 overflow guard: fp16 tops out at 65504, so squaring any value
        # above ~256 overflows. Divide by the per-row peak first (clamped to 1
        # so small inputs pass through untouched); the factor cancels out of
        # (x/s) / sqrt(mean((x/s)^2)) = x / sqrt(mean(x^2)).
        peak = x.abs().amax(dim=-1, keepdim=True).clamp(min=1.0)
        reduced = x / peak
        mean_sq = reduced.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * (reduced * torch.rsqrt(mean_sq + self.eps))
def precompute_rope_frequencies(
    head_dim: int, max_positions: int, theta: float = 100000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build static cos/sin lookup tables for rotary position embeddings.

    Args:
        head_dim: per-head dimension; frequencies cover the even indices
            and are duplicated across both halves (HF split-half layout).
        max_positions: number of positions to tabulate.
        theta: RoPE base (default matches the target checkpoint's config).

    Returns:
        (cos, sin), each of shape (1, 1, max_positions, head_dim).
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / torch.pow(theta, exponents)
    positions = torch.arange(max_positions, dtype=torch.float32)
    # Outer product via broadcasting: (max_pos, 1) * (1, head_dim/2).
    angles = positions[:, None] * inv_freq[None, :]
    # Duplicate so each half of the head gets the same frequency schedule.
    table = torch.cat((angles, angles), dim=-1)
    # Leading singleton dims make the tables broadcast over (batch, heads).
    return table.cos()[None, None], table.sin()[None, None]
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Split-half rotation matching HF Llama convention.
head_dim is always 64, so we hardcode the split at 32 to avoid
dynamic size ops that coremltools cannot convert.
"""
x1 = x[..., :32]
x2 = x[..., 32:]
return torch.cat((-x2, x1), dim=-1)
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block: down( silu(gate(x)) * up(x) ), no biases."""

    def __init__(self, hidden_size: int = 576, intermediate_size: int = 1536):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))
        lift = self.up_proj(x)
        return self.down_proj(gate * lift)
# ── Decode variant (1 token, broadcast-mask cache write) ──────────────────
class LlamaAttentionDecode(nn.Module):
    """Decode-step attention: one token in, KV cache updated in-graph.

    The cache write avoids scatter / index_put (poorly supported by
    coremltools) by blending with a broadcast one-hot mask:
    cache * (1 - mask) + new * mask.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        max_context: int = 2048,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # GQA: each KV head serves num_heads / num_kv_heads query heads.
        self.num_kv_groups = num_heads // num_kv_heads
        self.scale = head_dim ** -0.5
        self.max_context = max_context
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        causal_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        update_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (1, 1, hidden_size) — the single token being decoded.
            cos: (1, 1, 1, head_dim) — RoPE table pre-sliced for current_pos.
            sin: (1, 1, 1, head_dim)
            causal_mask: (1, 1, 1, max_ctx) — additive mask over cache slots.
            k_cache, v_cache: (1, num_kv_heads, max_ctx, head_dim)
            update_mask: (1, 1, max_ctx, 1) — one-hot float mask for current_pos

        Returns:
            (projected attention output, updated k_cache, updated v_cache).
        """

        def rope(t: torch.Tensor) -> torch.Tensor:
            # Split-half rotation, hardcoded at 32 (head_dim == 64) so the
            # traced graph has no dynamic-shape ops.
            rotated = torch.cat((-t[..., 32:], t[..., :32]), dim=-1)
            return t * cos + rotated * sin

        queries = self.q_proj(hidden_states).view(1, 1, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)

        queries = rope(queries)
        keys = rope(keys)

        # Blend the new key/value into the cache at the one-hot slot
        # (update_mask is 1.0 at current_pos, 0.0 everywhere else).
        keep = 1.0 - update_mask
        k_cache = k_cache * keep + keys * update_mask
        v_cache = v_cache * keep + values * update_mask

        # Expand KV heads across their query groups, then attend over the
        # whole cache.
        keys_all = k_cache.repeat_interleave(self.num_kv_groups, dim=1)
        values_all = v_cache.repeat_interleave(self.num_kv_groups, dim=1)
        scores = torch.matmul(queries, keys_all.transpose(2, 3)) * self.scale
        probs = F.softmax(scores + causal_mask, dim=-1, dtype=torch.float32)
        # Fully-masked rows (all -inf) softmax to NaN; zero them out
        # (same fix as the prefill attention variant).
        probs = probs.nan_to_num(0.0).to(queries.dtype)

        context = torch.matmul(probs, values_all)
        context = context.transpose(1, 2).contiguous().reshape(1, 1, self.num_heads * self.head_dim)
        return self.o_proj(context), k_cache, v_cache
class LlamaDecoderLayerDecode(nn.Module):
    """One decode-step transformer layer: pre-norm attention with KV-cache
    passthrough, then pre-norm SwiGLU MLP, each wrapped in a residual add."""

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        intermediate_size: int = 1536,
        rms_norm_eps: float = 1e-5,
        max_context: int = 2048,
    ):
        super().__init__()
        self.self_attn = LlamaAttentionDecode(
            hidden_size, num_heads, num_kv_heads, head_dim, max_context,
        )
        self.mlp = LlamaMLP(hidden_size, intermediate_size)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, hidden_states, cos, sin, causal_mask, k_cache, v_cache, update_mask):
        # Attention sub-block (pre-norm + residual); the updated caches are
        # threaded back out to the caller.
        attn_out, k_cache, v_cache = self.self_attn(
            self.input_layernorm(hidden_states),
            cos, sin, causal_mask, k_cache, v_cache, update_mask,
        )
        hidden_states = hidden_states + attn_out
        # MLP sub-block (pre-norm + residual).
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, k_cache, v_cache