| """ |
| CoreML-compatible replacements for HuggingFace LlamaForCausalLM building blocks. |
| |
| HF's SDPA attention and dynamic RoPE are not traceable by torch.jit.trace / coremltools. |
| This module provides static, explicit implementations that produce identical outputs. |
| |
| The decode attention processes 1 token per step and writes to the KV cache using a |
| broadcast one-hot mask: k_cache * (1 - mask) + k * mask. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class LlamaRMSNorm(nn.Module):
    """RMS normalization with a magnitude pre-scaling step.

    ``eps`` is registered as a float32 buffer rather than kept as a Python
    float, so it is baked into the traced graph as a constant tensor.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the per-row peak magnitude (floored at 1.0) before
        # squaring, so x**2 cannot overflow in half precision. The factor
        # cancels in x / rms(x) except that eps is effectively scaled by
        # peak**2, a tiny deviation from the plain formulation.
        peak = x.abs().amax(-1, keepdim=True).clamp(min=1.0)
        shrunk = x / peak
        mean_sq = shrunk.pow(2).mean(-1, keepdim=True)
        normed = shrunk * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normed
|
|
|
|
def precompute_rope_frequencies(
    head_dim: int, max_positions: int, theta: float = 100000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build static cos/sin lookup tables for rotary position embeddings.

    Args:
        head_dim: per-head dimension (must be even for the stride-2 range).
        max_positions: number of positions to tabulate.
        theta: RoPE frequency base.

    Returns:
        ``(cos, sin)``, each of shape (1, 1, max_positions, head_dim). The
        angle block is duplicated along the last axis to match the HF
        split-half rotation convention.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    pos = torch.arange(max_positions, dtype=torch.float32)
    angles = pos[:, None] * inv_freq[None, :]
    doubled = torch.cat((angles, angles), dim=-1)
    return doubled.cos()[None, None], doubled.sin()[None, None]
|
|
|
|
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last axis: ``(a, b) -> (-b, a)``.

    Matches the HF Llama rotation convention. The split point is the
    literal 32 (head_dim is fixed at 64 in this model) to avoid the
    dynamic last-dim size arithmetic that coremltools cannot convert.
    """
    front, back = x[..., :32], x[..., 32:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block (gate/up/down projections, no biases)."""

    def __init__(self, hidden_size: int = 576, intermediate_size: int = 1536):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down( silu(gate(x)) * up(x) ) — the standard Llama MLP.
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
| |
|
|
|
|
class LlamaAttentionDecode(nn.Module):
    """Single-token decode attention with GQA and an in-graph KV-cache write.

    The cache update is a broadcast one-hot blend,
    ``cache * (1 - mask) + new * mask``, instead of an index/scatter op,
    so the graph stays traceable for coremltools.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        max_context: int = 2048,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Query heads served by each KV head (grouped-query attention).
        self.num_kv_groups = num_heads // num_kv_heads
        self.scale = head_dim ** -0.5
        self.max_context = max_context

        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        causal_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        update_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decode step; return (attn_output, new_k_cache, new_v_cache).

        Args:
            hidden_states: (1, 1, hidden_size) — the single current token.
            cos, sin: (1, 1, 1, head_dim) RoPE tables pre-sliced for the
                current position.
            causal_mask: (1, 1, 1, max_ctx) additive attention mask.
            k_cache, v_cache: (1, num_kv_heads, max_ctx, head_dim).
            update_mask: (1, 1, max_ctx, 1) one-hot float marking the cache
                slot to overwrite.
        """
        # Project and reshape to (1, heads, 1, head_dim).
        query = self.q_proj(hidden_states).view(1, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE via the split-half identity.
        query = (query * cos) + (_rotate_half(query) * sin)
        key = (key * cos) + (_rotate_half(key) * sin)

        # Blend the new K/V into the cache at the one-hot position.
        keep = 1.0 - update_mask
        k_cache = k_cache * keep + key * update_mask
        v_cache = v_cache * keep + value * update_mask

        # Expand KV heads to match the query head count (GQA).
        keys_all = k_cache.repeat_interleave(self.num_kv_groups, dim=1)
        values_all = v_cache.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(query, keys_all.transpose(2, 3)) * self.scale
        scores = scores + causal_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)
        # Guard: rows masked everywhere softmax to NaN; zero them out.
        probs = probs.nan_to_num(0.0).to(query.dtype)
        context = torch.matmul(probs, values_all)

        context = context.transpose(1, 2).contiguous().reshape(1, 1, self.num_heads * self.head_dim)
        return self.o_proj(context), k_cache, v_cache
|
|
|
|
class LlamaDecoderLayerDecode(nn.Module):
    """One pre-norm transformer decoder layer for single-token decode.

    Threads the per-layer KV cache and one-hot update mask through the
    attention sub-block and returns the updated caches alongside the
    hidden states.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        intermediate_size: int = 1536,
        rms_norm_eps: float = 1e-5,
        max_context: int = 2048,
    ):
        super().__init__()
        self.self_attn = LlamaAttentionDecode(
            hidden_size, num_heads, num_kv_heads, head_dim, max_context,
        )
        self.mlp = LlamaMLP(hidden_size, intermediate_size)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, hidden_states, cos, sin, causal_mask, k_cache, v_cache, update_mask):
        """Pre-norm attention with residual, then pre-norm MLP with residual."""
        attn_out, k_cache, v_cache = self.self_attn(
            self.input_layernorm(hidden_states),
            cos, sin, causal_mask, k_cache, v_cache, update_mask,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states, k_cache, v_cache
|
|