"""Static, trace-friendly building blocks mirroring HuggingFace LlamaForCausalLM.

HF's SDPA attention and dynamic RoPE paths cannot be captured by
torch.jit.trace / coremltools. The modules here re-express the same math with
fixed shapes and explicit ops, producing identical outputs. Decode-side
attention handles exactly one token per call and commits it to the KV cache
with a broadcast one-hot blend: cache * (1 - mask) + new * mask.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRMSNorm(nn.Module):
    """RMSNorm with an overflow guard so it stays correct under fp16."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Kept as a tensor buffer (not a Python float) so it is baked into
        # the traced graph.
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fp16 tops out at 65504, so any activation above 256 would overflow
        # once squared. Divide by the per-row peak first; the factor cancels:
        #   (x/s) / sqrt(mean((x/s)^2)) == x / sqrt(mean(x^2)).
        peak = x.abs().amax(-1, keepdim=True).clamp(min=1.0)
        shrunk = x / peak
        mean_sq = shrunk.pow(2).mean(-1, keepdim=True)
        normalized = shrunk * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normalized


def precompute_rope_frequencies(
    head_dim: int, max_positions: int, theta: float = 100000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build static RoPE lookup tables.

    Returns:
        (cos, sin), each of shape (1, 1, max_positions, head_dim), ready to
        be sliced by position and broadcast over batch and heads.
    """
    exponent = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponent)
    positions = torch.arange(max_positions, dtype=torch.float32)
    # (max_positions, head_dim // 2) angle grid, duplicated to full head_dim
    # to match the HF split-half rotation layout.
    angles = positions[:, None] * inv_freq[None, :]
    table = torch.cat((angles, angles), dim=-1)
    cos = table.cos().view(1, 1, max_positions, head_dim)
    sin = table.sin().view(1, 1, max_positions, head_dim)
    return cos, sin


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """HF-style split-half rotation: (x1, x2) -> (-x2, x1).

    The split point is hardcoded at 32 (head_dim is fixed at 64) because
    coremltools cannot convert dynamic-size slicing ops.
    """
    front = x[..., :32]
    back = x[..., 32:]
    return torch.cat((-back, front), dim=-1)


class LlamaMLP(nn.Module):
    """Gated SiLU feed-forward block (SwiGLU) as used by Llama."""

    def __init__(self, hidden_size: int = 576, intermediate_size: int = 1536):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


# ── Decode variant (1 token, broadcast-mask cache write) ──────────────────


class LlamaAttentionDecode(nn.Module):
    """Single-token decode attention with a traceable KV-cache update.

    The current position's key/value are blended into the caches through a
    one-hot float mask instead of an index write, which keeps the graph
    convertible by coremltools.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        max_context: int = 2048,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_heads // num_kv_heads
        self.scale = head_dim ** -0.5
        self.max_context = max_context

        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        causal_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        update_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decode step.

        Args:
            hidden_states: (1, 1, hidden_size) — the single new token.
            cos, sin: (1, 1, 1, head_dim) RoPE tables pre-sliced at
                current_pos.
            causal_mask: (1, 1, 1, max_ctx) additive mask (0 / -inf).
            k_cache, v_cache: (1, num_kv_heads, max_ctx, head_dim).
            update_mask: (1, 1, max_ctx, 1) one-hot float mask marking
                current_pos.

        Returns:
            (attn_output, new_k_cache, new_v_cache).
        """
        query = self.q_proj(hidden_states).view(1, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding, HF split-half convention.
        query = (query * cos) + (_rotate_half(query) * sin)
        key = (key * cos) + (_rotate_half(key) * sin)

        # Commit this step's K/V into the caches: wherever the one-hot mask
        # is 0 the old entry survives; at current_pos the new entry wins.
        keep = 1.0 - update_mask
        k_cache = k_cache * keep + key * update_mask
        v_cache = v_cache * keep + value * update_mask

        # GQA: replicate each KV head across its query-head group, then attend.
        keys_all = k_cache.repeat_interleave(self.num_kv_groups, dim=1)
        values_all = v_cache.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(query, keys_all.transpose(2, 3)) * self.scale
        scores = scores + causal_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)
        # Fully-masked rows (-inf everywhere) softmax to NaN; zero them out
        # (same fix as the prefill attention).
        probs = probs.nan_to_num(0.0).to(query.dtype)

        context = torch.matmul(probs, values_all)
        context = context.transpose(1, 2).contiguous().reshape(1, 1, self.num_heads * self.head_dim)
        return self.o_proj(context), k_cache, v_cache


class LlamaDecoderLayerDecode(nn.Module):
    """One decode-path transformer layer: pre-norm attention + pre-norm MLP."""

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        intermediate_size: int = 1536,
        rms_norm_eps: float = 1e-5,
        max_context: int = 2048,
    ):
        super().__init__()
        self.self_attn = LlamaAttentionDecode(
            hidden_size, num_heads, num_kv_heads, head_dim, max_context,
        )
        self.mlp = LlamaMLP(hidden_size, intermediate_size)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, hidden_states, cos, sin, causal_mask, k_cache, v_cache, update_mask):
        # Attention sub-block with residual connection.
        skip = hidden_states
        attn_in = self.input_layernorm(hidden_states)
        attn_out, k_cache, v_cache = self.self_attn(
            attn_in, cos, sin, causal_mask, k_cache, v_cache, update_mask,
        )
        hidden_states = skip + attn_out

        # MLP sub-block with residual connection.
        skip = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = skip + mlp_out
        return hidden_states, k_cache, v_cache