"""
CoreML-compatible replacements for HuggingFace LlamaForCausalLM building blocks.
HF's SDPA attention and dynamic RoPE are not traceable by torch.jit.trace / coremltools.
This module provides static, explicit implementations that produce identical outputs.
The decode attention processes 1 token per step and writes to the KV cache using a
broadcast one-hot mask: k_cache * (1 - mask) + k * mask.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm with an fp16-overflow-safe path.

    ``eps`` is registered as a buffer rather than kept as a Python float so it
    is carried in the module state as a traceable tensor constant.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the per-row peak |x| before squaring so x**2 stays inside
        # fp16 range (fp16 max is 65504; |x| > 256 would overflow squared).
        # The divisor cancels: (x/s) / sqrt(mean((x/s)^2)) == x / sqrt(mean(x^2)).
        peak = x.abs().amax(-1, keepdim=True).clamp(min=1.0)
        reduced = x / peak
        mean_sq = reduced.pow(2).mean(-1, keepdim=True)
        normalized = reduced * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normalized
def precompute_rope_frequencies(
    head_dim: int, max_positions: int, theta: float = 100000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build static cos/sin lookup tables for rotary position embeddings.

    Args:
        head_dim: Per-head dimension; the tables duplicate the half-dim
            frequencies so they cover the full ``head_dim``.
        max_positions: Number of positions to tabulate.
        theta: RoPE base frequency.

    Returns:
        ``(cos, sin)``, each shaped ``(1, 1, max_positions, head_dim)``.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_positions, dtype=torch.float32)
    # Broadcasted product is the same outer product as torch.outer.
    angles = positions[:, None] * inv_freq[None, :]
    doubled = torch.cat([angles, angles], dim=-1)
    # Prepend two singleton axes: (1, 1, max_positions, head_dim).
    return doubled.cos()[None, None], doubled.sin()[None, None]
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Split-half rotation matching HF Llama convention.
head_dim is always 64, so we hardcode the split at 32 to avoid
dynamic size ops that coremltools cannot convert.
"""
x1 = x[..., :32]
x2 = x[..., 32:]
return torch.cat((-x2, x1), dim=-1)
class LlamaMLP(nn.Module):
    """Gated SiLU feed-forward (SwiGLU): ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, hidden_size: int = 576, intermediate_size: int = 1536):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
# ββ Decode variant (1 token, broadcast-mask cache write) ββββββββββββββββββ
class LlamaAttentionDecode(nn.Module):
    """Single-token decode attention with an in-graph KV-cache write.

    The cache update avoids scatter/index_put (not traceable by
    coremltools) by blending with a one-hot position mask:
    ``cache * (1 - mask) + new * mask``.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        max_context: int = 2048,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_heads // num_kv_heads
        self.scale = head_dim ** -0.5
        self.max_context = max_context
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        causal_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        update_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Attend the current token against the full cache.

        Args:
            hidden_states: (1, 1, hidden_size) — the single decode token.
            cos: (1, 1, 1, head_dim) — RoPE table pre-sliced for current_pos.
            sin: (1, 1, 1, head_dim)
            causal_mask: (1, 1, 1, max_ctx) additive mask.
            k_cache, v_cache: (1, num_kv_heads, max_ctx, head_dim)
            update_mask: (1, 1, max_ctx, 1) float one-hot at current_pos.

        Returns:
            (attn_output, new_k_cache, new_v_cache)
        """
        # Project the token and lay heads out as (1, heads, 1, head_dim).
        query = self.q_proj(hidden_states).view(1, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Rotary embedding; cos/sin arrive already sliced to the current position.
        query = (query * cos) + (_rotate_half(query) * sin)
        key = (key * cos) + (_rotate_half(key) * sin)
        # One-hot blend writes the new key/value at current_pos, leaving the
        # rest of the cache untouched.
        k_cache = k_cache * (1.0 - update_mask) + key * update_mask
        v_cache = v_cache * (1.0 - update_mask) + value * update_mask
        # Duplicate each KV head across its query group (GQA), then score
        # the single query against every cached position.
        keys_full = k_cache.repeat_interleave(self.num_kv_groups, dim=1)
        values_full = v_cache.repeat_interleave(self.num_kv_groups, dim=1)
        scores = torch.matmul(query, keys_full.transpose(2, 3)) * self.scale
        scores = scores + causal_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)
        # Fully-masked rows softmax to NaN; scrub them to zero (same fix as
        # the prefill attention).
        probs = probs.nan_to_num(0.0).to(query.dtype)
        context = torch.matmul(probs, values_full)
        context = context.transpose(1, 2).contiguous().reshape(1, 1, self.num_heads * self.head_dim)
        return self.o_proj(context), k_cache, v_cache
class LlamaDecoderLayerDecode(nn.Module):
    """One pre-norm transformer layer for decode.

    Attention and MLP sub-blocks each get a residual connection; the
    per-layer KV cache tensors are threaded through and returned updated.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        intermediate_size: int = 1536,
        rms_norm_eps: float = 1e-5,
        max_context: int = 2048,
    ):
        super().__init__()
        self.self_attn = LlamaAttentionDecode(
            hidden_size, num_heads, num_kv_heads, head_dim, max_context,
        )
        self.mlp = LlamaMLP(hidden_size, intermediate_size)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, hidden_states, cos, sin, causal_mask, k_cache, v_cache, update_mask):
        # Attention sub-block: pre-norm, attend, add residual.
        attn_in = self.input_layernorm(hidden_states)
        attn_out, k_cache, v_cache = self.self_attn(
            attn_in, cos, sin, causal_mask, k_cache, v_cache, update_mask,
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-block: pre-norm, MLP, add residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states, k_cache, v_cache