File size: 12,662 Bytes
cb20bed
 
 
95c6137
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e79fa0a
cb20bed
 
 
 
 
95c6137
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c6137
 
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c6137
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c6137
cb20bed
 
 
 
 
 
 
 
 
 
 
 
e79fa0a
cb20bed
d12b4ea
 
fad9fad
cb20bed
d12b4ea
 
 
 
cb20bed
 
 
d12b4ea
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
fad9fad
d12b4ea
 
 
 
 
 
 
fad9fad
 
cb20bed
 
 
10bad22
cb20bed
 
 
e79fa0a
 
95c6137
cb20bed
 
e79fa0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb20bed
 
 
a2c97d7
cb20bed
 
 
a2c97d7
 
cb20bed
 
 
 
 
a2c97d7
 
 
cb20bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e79fa0a
 
 
 
cb20bed
 
 
 
 
 
95c6137
 
cb20bed
e79fa0a
cb20bed
 
 
 
e79fa0a
 
 
 
 
 
 
 
 
 
 
 
 
cb20bed
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
#!/usr/bin/env python3
"""
Main conversion script: downloads Plapre Pico weights from HuggingFace,
builds the decode wrapper model, traces, and converts to CoreML.

Usage:
    python convert.py [--model-dir PATH] [--output-dir PATH]

If --model-dir is not provided, downloads from syvai/plapre-pico.
"""

import argparse
import json
import shutil
from pathlib import Path

import numpy as np
import torch
import coremltools as ct
from coremltools.converters.mil.mil import Builder as mb
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from attention import precompute_rope_frequencies
from model_wrapper import (
    PlaprePico,
    NUM_LAYERS,
    NUM_KV_HEADS,
    MAX_CONTEXT,
    HEAD_DIM,
    PREFILL_SEQ_LEN,
    VOCAB_SIZE,
    HIDDEN_SIZE,
    SPEAKER_DIM,
)


def download_model(model_id: str = "syvai/plapre-pico") -> Path:
    """Fetch the model snapshot from the HuggingFace Hub.

    Uses the Hub's local cache, so repeated runs do not re-download.

    Args:
        model_id: HuggingFace repo id to fetch.

    Returns:
        Path to the locally cached snapshot directory.
    """
    print(f"Downloading {model_id}...")
    cache_dir = snapshot_download(model_id)
    print(f"Model cached at: {cache_dir}")
    return Path(cache_dir)


def load_weights(model_dir: Path) -> dict[str, torch.Tensor]:
    """Load model.safetensors and speaker_proj.pt, cast bf16 → fp16.

    Speaker-projection tensors are namespaced under ``speaker_proj.`` so the
    two checkpoints can share one flat dict.

    Args:
        model_dir: Directory containing model.safetensors and speaker_proj.pt.

    Returns:
        Flat mapping of weight name → tensor (bf16 tensors cast to fp16).
    """

    def _fp16(t: torch.Tensor) -> torch.Tensor:
        # CoreML conversion targets fp16; bf16 is not representable there.
        return t.to(torch.float16) if t.dtype == torch.bfloat16 else t

    st_path = model_dir / "model.safetensors"
    print(f"Loading {st_path}...")
    weights = {name: _fp16(t) for name, t in load_file(str(st_path)).items()}

    sp_path = model_dir / "speaker_proj.pt"
    print(f"Loading {sp_path}...")
    sp_state = torch.load(str(sp_path), map_location="cpu", weights_only=True)
    weights.update({f"speaker_proj.{name}": _fp16(t) for name, t in sp_state.items()})

    print(f"Loaded {len(weights)} weight tensors")
    return weights


def _map_weight_name(hf_name: str) -> str | None:
    """Map HuggingFace weight name to our model's parameter name."""
    if hf_name == "model.embed_tokens.weight":
        return "embed_tokens.weight"
    if hf_name == "model.norm.weight":
        return "norm.weight"
    if hf_name == "lm_head.weight":
        return None  # tied to embed_tokens

    if hf_name.startswith("model.layers."):
        rest = hf_name[len("model.layers."):]
        parts = rest.split(".", 1)
        layer_idx = parts[0]
        component = parts[1]

        mapping = {
            "self_attn.q_proj.weight": "self_attn.q_proj.weight",
            "self_attn.k_proj.weight": "self_attn.k_proj.weight",
            "self_attn.v_proj.weight": "self_attn.v_proj.weight",
            "self_attn.o_proj.weight": "self_attn.o_proj.weight",
            "mlp.gate_proj.weight": "mlp.gate_proj.weight",
            "mlp.up_proj.weight": "mlp.up_proj.weight",
            "mlp.down_proj.weight": "mlp.down_proj.weight",
            "input_layernorm.weight": "input_layernorm.weight",
            "post_attention_layernorm.weight": "post_attention_layernorm.weight",
        }
        if component in mapping:
            return f"layers.{layer_idx}.{mapping[component]}"

    if hf_name.startswith("speaker_proj."):
        return hf_name

    print(f"  WARNING: unmapped weight: {hf_name}")
    return None


def populate_weights(model: torch.nn.Module, weights: dict[str, torch.Tensor]):
    """Copy checkpoint tensors into a PlaprePico model.

    Keys are renamed via ``_map_weight_name``; tensors whose shape disagrees
    with the model's parameter are skipped with a warning. Loading is
    non-strict, and missing keys that are KV-cache or RoPE buffers are
    expected (they are model-internal state, not checkpoint weights).
    """
    expected = model.state_dict()
    selected: dict[str, torch.Tensor] = {}

    for hf_name, tensor in weights.items():
        target = _map_weight_name(hf_name)
        if target is None or target not in expected:
            continue
        if expected[target].shape != tensor.shape:
            print(f"  Shape mismatch for {target}: "
                  f"expected {expected[target].shape}, got {tensor.shape}")
            continue
        selected[target] = tensor

    missing, unexpected = model.load_state_dict(selected, strict=False)
    # Cache/RoPE buffers are populated at runtime, not from the checkpoint.
    missing = [k for k in missing if not k.startswith(("k_cache_", "v_cache_", "rope_"))]
    if missing:
        print(f"  Missing weights: {missing}")
    if unexpected:
        print(f"  Unexpected weights: {unexpected}")
    print(f"  Loaded {len(selected)} weight tensors")


def build_kv_cache_states() -> list:
    """Build CoreML StateType list for 60 KV cache buffers.

    Emits one key and one value state per transformer layer, in
    (k_cache_i, v_cache_i) order, each shaped
    (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM) in fp16.
    """
    cache_shape = (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)

    def _cache_state(name: str):
        return ct.StateType(
            wrapped_type=ct.TensorType(shape=cache_shape, dtype=np.float16),
            name=name,
        )

    states = []
    for layer in range(NUM_LAYERS):
        states.append(_cache_state(f"k_cache_{layer}"))
        states.append(_cache_state(f"v_cache_{layer}"))
    return states



def convert_decode(model: PlaprePico, output_dir: Path):
    """Trace and convert decode model to CoreML.

    Traces a single-token decode step with torch.jit.trace, converts it to a
    stateful fp16 CoreML package targeting iOS 18, injects the KV-cache state
    update ops (see ``inject_state_updates``), and saves the result as
    PlaprePico.mlpackage under ``output_dir``.

    Args:
        model: Weight-populated PlaprePico decode wrapper (expected in fp16).
        output_dir: Directory to write PlaprePico.mlpackage into.

    Returns:
        Path to the saved .mlpackage.
    """
    model.eval()
    print("Tracing decode model...")

    # Example inputs for tracing: one token (batch 1, seq len 1).
    input_ids = torch.zeros(1, 1, dtype=torch.int32)
    # Attention mask over the full context: -inf everywhere, 0 (attendable)
    # over the prefill region. The runtime supplies the real mask per step.
    causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
    causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0

    # RoPE cos/sin for the current position; zeros suffice for tracing.
    cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
    sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)

    # One-hot over context positions selecting the KV-cache write slot.
    update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
    update_mask[0, 0, 0, 0] = 1.0  # any valid position for tracing

    speaker_embedding = torch.zeros(1, SPEAKER_DIM, dtype=torch.float16)
    # Scalar flag (0/1) — presumably selects speaker-embedding injection
    # inside the wrapper; TODO confirm against model_wrapper.
    is_speaker_step = torch.zeros(1, dtype=torch.float16)

    with torch.no_grad():
        traced = torch.jit.trace(model, (
            input_ids, causal_mask, cos, sin, update_mask,
            speaker_embedding, is_speaker_step,
        ))

    print("Converting decode to CoreML...")
    # Input shapes are fixed (no flexible shapes): single-token decode only.
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
            ct.TensorType(
                name="causal_mask",
                shape=(1, 1, 1, MAX_CONTEXT),
                dtype=np.float16,
            ),
            ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
            ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
            ct.TensorType(
                name="update_mask",
                shape=(1, 1, MAX_CONTEXT, 1),
                dtype=np.float16,
            ),
            ct.TensorType(
                name="speaker_embedding",
                shape=(1, SPEAKER_DIM),
                dtype=np.float16,
            ),
            ct.TensorType(
                name="is_speaker_step",
                shape=(1,),
                dtype=np.float16,
            ),
        ],
        outputs=[ct.TensorType(name="logits", dtype=np.float16)],
        states=build_kv_cache_states(),
        compute_precision=ct.precision.FLOAT16,
        # Stateful models (ct.StateType) require iOS 18+.
        minimum_deployment_target=ct.target.iOS18,
    )

    # Tracing loses buffer mutations, so the state writes must be re-added.
    inject_state_updates(mlmodel)

    out_path = output_dir / "PlaprePico.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved decode model to {out_path}")
    return out_path


def inject_state_updates(mlmodel):
    """Inject coreml_update_state ops into a converted stateful CoreML model.

    torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so coremltools
    can't generate coreml_update_state ops automatically. This walks the MIL graph,
    finds the read_state -> (cast?) -> mul -> add cache update pattern, and inserts
    coreml_update_state ops before the first consumer of each cache update.

    Mutates ``mlmodel`` in place (via its private ``_mil_program``); prints the
    resulting read_state / coreml_update_state counts as a sanity check.
    """
    # NOTE(review): relies on coremltools private API (_mil_program) — may
    # break across coremltools versions.
    prog = mlmodel._mil_program
    main_fn = prog.functions["main"]

    read_ops = list(main_fn.find_ops(op_type="read_state"))
    print(f"Found {len(read_ops)} read_state ops")

    # Phase 1: locate the `add` op that produces each cache's updated value.
    updates = []
    for read_op in read_ops:
        state_var = read_op.inputs["input"]
        output = read_op.outputs[0]

        # FLOAT32: read_state -> cast(fp16->fp32) -> mul -> add
        # FLOAT16: read_state -> mul -> add
        first_child = output.child_ops[0]
        search_output = first_child.outputs[0] if first_child.op_type == "cast" else output

        # mul: old cache masked by (1 - update_mask); add: merged new values.
        mul_op = next((c for c in search_output.child_ops if c.op_type == "mul"), None)
        if mul_op is None:
            print(f"  WARNING: no mul found for {state_var.name}")
            continue

        add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
        if add_op is None:
            print(f"  WARNING: no add found for {state_var.name}")
            continue

        updates.append((state_var, add_op))

    print(f"Injecting {len(updates)} coreml_update_state ops...")

    # Phase 2: insert the update ops. Placing each one before the add
    # output's first consumer keeps SSA ordering valid.
    block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
    with block:
        for state_var, add_op in updates:
            add_out = add_op.outputs[0]
            consumers = list(add_out.child_ops)
            if not consumers:
                print(f"  WARNING: no consumers for {state_var.name} add output")
                continue
            first_consumer = consumers[0]

            with mb.set_before_op(before_op=first_consumer):
                # State buffers are fp16; cast the update value if the graph
                # computed it in fp32.
                if str(add_out.dtype) == "fp16":
                    state_val = add_out
                else:
                    state_val = mb.cast(
                        x=add_out, dtype="fp16",
                        name=f"state_cast_{state_var.name}",
                    )
                mb.coreml_update_state(
                    state=state_var, value=state_val,
                    name=f"state_update_{state_var.name}",
                )

    # Sanity check: counts should match (one update per read on success).
    prog_str = str(prog)
    print(f"  read_state: {prog_str.count('read_state')}")
    print(f"  coreml_update_state: {prog_str.count('coreml_update_state')}")


def copy_assets(model_dir: Path, output_dir: Path):
    """Copy tokenizer.json, speakers.json, and RoPE tables to output root.

    Also writes rope_cos.npy / rope_sin.npy (fp16) for the iOS runtime to
    build the cos/sin/update_mask inputs, and a manifest.json describing the
    model configuration.

    Args:
        model_dir: Source directory containing tokenizer.json / speakers.json.
        output_dir: Destination directory (must already exist).
    """
    for filename in ["tokenizer.json", "speakers.json"]:
        src = model_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)
            # Fix: log the actual file being copied (was a literal placeholder).
            print(f"Copied {filename} to {output_dir}")
        else:
            print(f"  WARNING: {filename} not found in {model_dir}")

    # Export RoPE tables for the iOS runtime to build cos/sin/update_mask inputs
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    np.save(str(output_dir / "rope_cos.npy"), cos_full.numpy().astype(np.float16))
    np.save(str(output_dir / "rope_sin.npy"), sin_full.numpy().astype(np.float16))
    print(f"Exported RoPE tables to {output_dir}")

    # Model configuration the iOS runtime reads at load time.
    manifest = {
        "model": "plapre-pico",
        "version": "1.0",
        "context_length": MAX_CONTEXT,
        "prefill_length": PREFILL_SEQ_LEN,
        "vocab_size": VOCAB_SIZE,
        "num_layers": NUM_LAYERS,
        "hidden_size": HIDDEN_SIZE,
        "num_kv_heads": NUM_KV_HEADS,
        "head_dim": HEAD_DIM,
        "speaker_dim": SPEAKER_DIM,
        "precision": "float16",
    }
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"Wrote manifest to {manifest_path}")


def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
    """Run the full Plapre Pico conversion pipeline.

    Steps: download (if ``model_dir`` is None) → load weights → build and
    populate the decode wrapper → trace/convert to CoreML → copy runtime
    assets.

    Args:
        output_dir: Directory receiving the .mlpackage and assets (created
            if missing).
        model_dir: Optional local model directory; downloaded from the Hub
            when omitted.

    Returns:
        Path to PlaprePico.mlpackage.
    """
    src_dir = download_model() if model_dir is None else model_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    weights = load_weights(src_dir)

    print("\n=== Building decode model ===")
    decode_model = PlaprePico()
    populate_weights(decode_model, weights)
    # CoreML conversion runs in fp16; cast the whole module before tracing.
    package_path = convert_decode(decode_model.half(), output_dir)

    print("\n=== Copying assets ===")
    copy_assets(src_dir, output_dir)

    print(f"\nLLM conversion complete: {package_path}")
    return package_path


def main():
    """CLI entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
    parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    opts = parser.parse_args()

    # --model-dir is optional; None triggers a Hub download in convert_llm.
    local_dir = Path(opts.model_dir) if opts.model_dir else None
    convert_llm(output_dir=Path(opts.output_dir), model_dir=local_dir)


# Script entry point: run the conversion when executed directly.
if __name__ == "__main__":
    main()