File size: 7,343 Bytes
cb20bed
 
 
 
 
 
95c6137
 
cb20bed
 
 
 
 
 
 
 
 
 
 
 
d1bfb8c
cb20bed
 
ffa94c8
 
 
 
 
 
 
 
 
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c6137
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1bfb8c
cb20bed
 
 
 
 
 
 
 
 
 
d1bfb8c
 
 
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
CoreML-compatible replacements for HuggingFace LlamaForCausalLM building blocks.

HF's SDPA attention and dynamic RoPE are not traceable by torch.jit.trace / coremltools.
This module provides static, explicit implementations that produce identical outputs.

The decode attention processes 1 token per step and writes to the KV cache using a
broadcast one-hot mask: k_cache * (1 - mask) + k * mask.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRMSNorm(nn.Module):
    """RMS normalization with a learned per-channel gain.

    eps is registered as a buffer (rather than kept as a Python float) so it
    remains a constant tensor during torch.jit.trace.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.register_buffer("eps", torch.tensor(eps, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Overflow-safe RMSNorm for fp16: fp16 tops out at 65504, so any value
        # above ~256 would overflow when squared. Divide each row by its max
        # magnitude first; the divisor cancels algebraically:
        #   (x/s) / sqrt(mean((x/s)^2)) == x / sqrt(mean(x^2))
        safe_div = torch.clamp(x.abs().amax(dim=-1, keepdim=True), min=1.0)
        reduced = x / safe_div
        mean_sq = reduced.pow(2).mean(dim=-1, keepdim=True)
        normalized = reduced * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normalized


def precompute_rope_frequencies(
    head_dim: int, max_positions: int, theta: float = 100000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build static cos/sin lookup tables for rotary position embeddings.

    Args:
        head_dim: per-head channel count; one frequency is computed for every
            second channel and the set is duplicated (HF split-half layout).
        max_positions: number of positions to tabulate.
        theta: RoPE base frequency.

    Returns:
        (cos, sin), each of shape (1, 1, max_positions, head_dim).
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_positions, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Duplicate the half-dim angle table so each half of the head gets the
    # same frequencies, matching _rotate_half's split convention.
    table = torch.cat((angles, angles), dim=-1)[None, None]
    return table.cos(), table.sin()


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split-half rotation matching HF Llama convention.
    head_dim is always 64, so we hardcode the split at 32 to avoid
    dynamic size ops that coremltools cannot convert.
    """
    x1 = x[..., :32]
    x2 = x[..., 32:]
    return torch.cat((-x2, x1), dim=-1)


class LlamaMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int = 576, intermediate_size: int = 1536):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


# ── Decode variant (1 token, broadcast-mask cache write) ──────────────────


class LlamaAttentionDecode(nn.Module):
    """Attention for decode: 1 token per step, cache write via one-hot blend.

    Scatter/index_put is not convertible by coremltools, so the KV cache is
    updated with a broadcast one-hot float mask over the context axis:
    cache * (1 - mask) + new * mask.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        max_context: int = 2048,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_heads // num_kv_heads
        self.scale = head_dim ** -0.5
        self.max_context = max_context

        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        causal_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        update_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (1, 1, hidden_size)
            cos: (1, 1, 1, head_dim) — pre-sliced for current_pos
            sin: (1, 1, 1, head_dim)
            causal_mask: (1, 1, 1, max_ctx)
            k_cache, v_cache: (1, num_kv_heads, max_ctx, head_dim)
            update_mask: (1, 1, max_ctx, 1) — one-hot float mask for current_pos

        Returns:
            (attn_output, updated_k_cache, updated_v_cache)
        """
        # Project the single token and reshape to (1, heads, 1, head_dim).
        query = self.q_proj(hidden_states).view(1, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(1, 1, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on query and key.
        query = (query * cos) + (_rotate_half(query) * sin)
        key = (key * cos) + (_rotate_half(key) * sin)

        # Blend the new key/value into the cache at current_pos only;
        # update_mask is 1.0 at that slot and 0.0 everywhere else.
        keep = 1.0 - update_mask
        k_cache = k_cache * keep + key * update_mask
        v_cache = v_cache * keep + value * update_mask

        # GQA: replicate each kv head across its query-head group.
        keys_all = k_cache.repeat_interleave(self.num_kv_groups, dim=1)
        values_all = v_cache.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(query, keys_all.transpose(2, 3)) * self.scale
        scores = scores + causal_mask
        # Softmax in fp32 for stability; rows whose mask is all -inf produce
        # NaN, which is zeroed (same fix as the prefill attention).
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)
        probs = probs.nan_to_num(0.0).to(query.dtype)
        context = torch.matmul(probs, values_all)

        context = context.transpose(1, 2).contiguous().reshape(1, 1, self.num_heads * self.head_dim)
        return self.o_proj(context), k_cache, v_cache


class LlamaDecoderLayerDecode(nn.Module):
    """One pre-norm transformer decoder layer for single-token decode.

    Threads the per-layer KV cache tensors through the attention sub-block
    and returns the updated caches alongside the hidden states.
    """

    def __init__(
        self,
        hidden_size: int = 576,
        num_heads: int = 9,
        num_kv_heads: int = 3,
        head_dim: int = 64,
        intermediate_size: int = 1536,
        rms_norm_eps: float = 1e-5,
        max_context: int = 2048,
    ):
        super().__init__()
        self.self_attn = LlamaAttentionDecode(
            hidden_size, num_heads, num_kv_heads, head_dim, max_context,
        )
        self.mlp = LlamaMLP(hidden_size, intermediate_size)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, hidden_states, cos, sin, causal_mask, k_cache, v_cache, update_mask):
        # Attention sub-block with residual connection.
        attn_in = self.input_layernorm(hidden_states)
        attn_out, k_cache, v_cache = self.self_attn(
            attn_in, cos, sin, causal_mask, k_cache, v_cache, update_mask,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states, k_cache, v_cache