| """ |
| Prefill and Decode wrapper models for CoreML conversion. |
| |
| Key design: all dynamic indexing is eliminated. Positions are encoded into |
| cos/sin/mask inputs by the caller, not computed inside the model. This makes |
| the models fully traceable by torch.jit.trace and convertible by coremltools. |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F

from attention import (
    LlamaDecoderLayerDecode,
    LlamaRMSNorm,
    precompute_rope_frequencies,
)

NUM_LAYERS = 30
HIDDEN_SIZE = 576
NUM_HEADS = 9
NUM_KV_HEADS = 3
HEAD_DIM = 64
INTERMEDIATE_SIZE = 1536
VOCAB_SIZE = 20802
RMS_NORM_EPS = 1e-5
ROPE_THETA = 100000.0
MAX_CONTEXT = 512
PREFILL_SEQ_LEN = 512
SPEAKER_DIM = 128

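
# Illustrative sketch (not part of the original module): one way a caller
# could construct the per-step cos/sin/mask inputs described in the class
# docstring below. Assumes the rotate-half RoPE layout (angles duplicated
# across both halves of head_dim); the real caller may instead derive
# cos/sin from attention.precompute_rope_frequencies.
def _example_step_inputs(position: int):
    inv_freq = 1.0 / (
        ROPE_THETA ** (torch.arange(0, HEAD_DIM, 2, dtype=torch.float32) / HEAD_DIM)
    )
    angles = position * inv_freq          # (HEAD_DIM // 2,)
    angles = torch.cat([angles, angles])  # (HEAD_DIM,)
    cos = angles.cos().to(torch.float16).view(1, 1, 1, HEAD_DIM)
    sin = angles.sin().to(torch.float16).view(1, 1, 1, HEAD_DIM)

    # 0 at positions the current token may attend to, -inf at the rest.
    causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
    causal_mask[..., : position + 1] = 0.0

    # One-hot over the sequence axis: selects the cache row to overwrite.
    update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
    update_mask[0, 0, position, 0] = 1.0
    return cos, sin, causal_mask, update_mask
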
class PlaprePico(nn.Module):
    """Generates one token at a time using the KV cache.

    Also used for token-by-token prefill. Speaker conditioning is handled
    internally: at position 0, pass is_speaker_step=1.0 and the raw
    speaker_embedding. The model projects it and replaces the token embedding.

    Inputs:
        input_ids: (1, 1) int32
        causal_mask: (1, 1, 1, 512) float16 — 0 or -inf
        cos: (1, 1, 1, 64) float16 — RoPE cos for current position
        sin: (1, 1, 1, 64) float16 — RoPE sin for current position
        update_mask: (1, 1, 512, 1) float16 — one-hot at current position
        speaker_embedding: (1, 128) float16 — raw speaker embedding (used at position 0)
        is_speaker_step: (1,) float16 — 1.0 at position 0, 0.0 otherwise

    State buffers:
        k_cache_0..29, v_cache_0..29: (1, 3, 512, 64) float16

    Output:
        logits: (1, 1, 20802) float16
    """

    def __init__(self):
        super().__init__()

        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.speaker_proj = nn.Linear(SPEAKER_DIM, HIDDEN_SIZE, bias=True)

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayerDecode(
                    hidden_size=HIDDEN_SIZE,
                    num_heads=NUM_HEADS,
                    num_kv_heads=NUM_KV_HEADS,
                    head_dim=HEAD_DIM,
                    intermediate_size=INTERMEDIATE_SIZE,
                    rms_norm_eps=RMS_NORM_EPS,
                    max_context=MAX_CONTEXT,
                )
                for _ in range(NUM_LAYERS)
            ]
        )
        self.norm = LlamaRMSNorm(HIDDEN_SIZE, eps=RMS_NORM_EPS)

        # Per-layer KV caches, registered as buffers so the converter can
        # expose them as mutable model state.
        for i in range(NUM_LAYERS):
            self.register_buffer(
                f"k_cache_{i}",
                torch.zeros(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM, dtype=torch.float16),
            )
            self.register_buffer(
                f"v_cache_{i}",
                torch.zeros(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM, dtype=torch.float16),
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        causal_mask: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        update_mask: torch.Tensor,
        speaker_embedding: torch.Tensor,
        is_speaker_step: torch.Tensor,
    ) -> torch.Tensor:
        hidden = self.embed_tokens(input_ids)

        # Branchless speaker conditioning: blend by the flag instead of an
        # `if`, so torch.jit.trace never sees data-dependent control flow.
        # At position 0 the projected speaker embedding replaces the token
        # embedding; at every other step the blend is a no-op.
        spk_proj = self.speaker_proj(speaker_embedding).unsqueeze(1)
        flag = is_speaker_step.view(1, 1, 1)
        hidden = hidden * (1.0 - flag) + spk_proj * flag

        # torch.jit.trace unrolls this Python loop, so the traced graph is
        # identical to writing out all 30 layer calls by hand. The
        # `cache += new - cache` form updates each cache buffer in place,
        # which the tracer records as a mutation of the registered buffer
        # rather than a reassignment.
        for i, layer in enumerate(self.layers):
            k_cache = getattr(self, f"k_cache_{i}")
            v_cache = getattr(self, f"v_cache_{i}")
            hidden, new_k, new_v = layer(
                hidden, cos, sin, causal_mask, k_cache, v_cache, update_mask
            )
            k_cache += new_k - k_cache
            v_cache += new_v - v_cache

        hidden = self.norm(hidden)
        # The output head is tied to the input embedding: reuse
        # embed_tokens.weight instead of a separate lm_head matrix.
        logits = F.linear(hidden, self.embed_tokens.weight)
        return logits
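
# --- Illustrative usage sketch (not part of the original module). ---
# Shows one way to trace the model and convert it with stateful KV caches.
# The input names and shapes follow the class docstring; everything else
# (float32 tracing, deployment target, precision settings, output filename)
# is an assumption about the surrounding pipeline, not the project's actual
# conversion script.
if __name__ == "__main__":
    import numpy as np
    import coremltools as ct

    # Trace in float32 to sidestep half-precision gaps in CPU ops; fp16 is
    # reintroduced below via compute_precision during conversion.
    model = PlaprePico().eval().float()
    cos, sin, causal_mask, update_mask = (
        t.float() for t in _example_step_inputs(position=0)
    )
    example_inputs = (
        torch.zeros(1, 1, dtype=torch.int32),  # input_ids
        causal_mask,
        cos,
        sin,
        update_mask,
        torch.zeros(1, SPEAKER_DIM),           # speaker_embedding
        torch.ones(1),                         # is_speaker_step at position 0
    )
    traced = torch.jit.trace(model, example_inputs)

    # Each registered k/v buffer becomes a CoreML state; the state names
    # must match the buffer names registered in __init__.
    states = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)),
            name=f"{prefix}_{i}",
        )
        for i in range(NUM_LAYERS)
        for prefix in ("k_cache", "v_cache")
    ]
    float_names = ["causal_mask", "cos", "sin", "update_mask",
                   "speaker_embedding", "is_speaker_step"]
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32)]
        + [ct.TensorType(name=n, shape=tuple(t.shape))
           for n, t in zip(float_names, example_inputs[1:])],
        states=states,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS18,
    )
    mlmodel.save("plapre_pico_decode.mlpackage")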