| |
| """ |
| Main conversion script: downloads Plapre Pico weights from HuggingFace, |
| builds the decode wrapper model, traces, and converts to CoreML. |
| |
| Usage: |
| python convert.py [--model-dir PATH] [--output-dir PATH] |
| |
| If --model-dir is not provided, downloads from syvai/plapre-pico. |
| """ |
|
|
| import argparse |
| import json |
| import shutil |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import coremltools as ct |
| from coremltools.converters.mil.mil import Builder as mb |
| from huggingface_hub import snapshot_download |
| from safetensors.torch import load_file |
|
|
| from attention import precompute_rope_frequencies |
| from model_wrapper import ( |
| PlaprePico, |
| NUM_LAYERS, |
| NUM_KV_HEADS, |
| MAX_CONTEXT, |
| HEAD_DIM, |
| PREFILL_SEQ_LEN, |
| VOCAB_SIZE, |
| HIDDEN_SIZE, |
| SPEAKER_DIM, |
| ) |
|
|
|
|
def download_model(model_id: str = "syvai/plapre-pico") -> Path:
    """Fetch the model snapshot from the HuggingFace Hub.

    Returns the local cache directory as a Path. The download is cached by
    huggingface_hub, so repeated calls are cheap.
    """
    print(f"Downloading {model_id}...")
    local_dir = Path(snapshot_download(model_id))
    print(f"Model cached at: {local_dir}")
    return local_dir
|
|
|
|
def load_weights(model_dir: Path) -> dict[str, torch.Tensor]:
    """Load model.safetensors and speaker_proj.pt, cast bf16 → fp16.

    Speaker-projection tensors are namespaced under "speaker_proj." so they
    can be told apart from the transformer weights downstream.
    """
    def _to_fp16(t: torch.Tensor) -> torch.Tensor:
        # CoreML has no bfloat16; everything else passes through unchanged.
        return t.to(torch.float16) if t.dtype == torch.bfloat16 else t

    weights: dict[str, torch.Tensor] = {}

    st_path = model_dir / "model.safetensors"
    print(f"Loading {st_path}...")
    for name, tensor in load_file(str(st_path)).items():
        weights[name] = _to_fp16(tensor)

    sp_path = model_dir / "speaker_proj.pt"
    print(f"Loading {sp_path}...")
    # weights_only=True avoids executing arbitrary pickled code from the file.
    sp_state = torch.load(str(sp_path), map_location="cpu", weights_only=True)
    for name, tensor in sp_state.items():
        weights[f"speaker_proj.{name}"] = _to_fp16(tensor)

    print(f"Loaded {len(weights)} weight tensors")
    return weights
|
|
|
|
| def _map_weight_name(hf_name: str) -> str | None: |
| """Map HuggingFace weight name to our model's parameter name.""" |
| if hf_name == "model.embed_tokens.weight": |
| return "embed_tokens.weight" |
| if hf_name == "model.norm.weight": |
| return "norm.weight" |
| if hf_name == "lm_head.weight": |
| return None |
|
|
| if hf_name.startswith("model.layers."): |
| rest = hf_name[len("model.layers."):] |
| parts = rest.split(".", 1) |
| layer_idx = parts[0] |
| component = parts[1] |
|
|
| mapping = { |
| "self_attn.q_proj.weight": "self_attn.q_proj.weight", |
| "self_attn.k_proj.weight": "self_attn.k_proj.weight", |
| "self_attn.v_proj.weight": "self_attn.v_proj.weight", |
| "self_attn.o_proj.weight": "self_attn.o_proj.weight", |
| "mlp.gate_proj.weight": "mlp.gate_proj.weight", |
| "mlp.up_proj.weight": "mlp.up_proj.weight", |
| "mlp.down_proj.weight": "mlp.down_proj.weight", |
| "input_layernorm.weight": "input_layernorm.weight", |
| "post_attention_layernorm.weight": "post_attention_layernorm.weight", |
| } |
| if component in mapping: |
| return f"layers.{layer_idx}.{mapping[component]}" |
|
|
| if hf_name.startswith("speaker_proj."): |
| return hf_name |
|
|
| print(f" WARNING: unmapped weight: {hf_name}") |
| return None |
|
|
|
|
def populate_weights(model: torch.nn.Module, weights: dict[str, torch.Tensor]):
    """Copy converted HF weights into a PlaprePico model.

    Tensors whose mapped name exists in the model but has a different shape
    are skipped with a report. KV-cache and RoPE buffers are expected to be
    absent from the checkpoint, so they are excluded from the missing report.
    """
    expected = model.state_dict()
    to_load: dict[str, torch.Tensor] = {}

    for hf_name, tensor in weights.items():
        target = _map_weight_name(hf_name)
        if target is None:
            continue
        if target in expected and expected[target].shape != tensor.shape:
            print(f" Shape mismatch for {target}: "
                  f"expected {expected[target].shape}, got {tensor.shape}")
            continue
        # Unknown names are loaded anyway and surface as "unexpected" below.
        to_load[target] = tensor

    missing, unexpected = model.load_state_dict(to_load, strict=False)
    # Cache/RoPE buffers are populated at runtime, not from the checkpoint.
    missing = [name for name in missing
               if not name.startswith(("k_cache_", "v_cache_", "rope_"))]
    if missing:
        print(f" Missing weights: {missing}")
    if unexpected:
        print(f" Unexpected weights: {unexpected}")
    print(f" Loaded {len(to_load)} weight tensors")
|
|
|
|
def build_kv_cache_states() -> list:
    """Build the CoreML StateType list for the per-layer KV cache buffers.

    One key and one value buffer per transformer layer, each shaped
    (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM) in fp16.
    """
    cache_shape = (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)

    def _state(name: str):
        return ct.StateType(
            wrapped_type=ct.TensorType(shape=cache_shape, dtype=np.float16),
            name=name,
        )

    states = []
    for layer in range(NUM_LAYERS):
        # Order matters: k then v for each layer, matching the wrapper model.
        states.append(_state(f"k_cache_{layer}"))
        states.append(_state(f"v_cache_{layer}"))
    return states
|
|
|
|
|
|
def convert_decode(model: PlaprePico, output_dir: Path):
    """Trace the single-token decode model and convert it to a stateful
    CoreML package saved under *output_dir*.

    Returns the path to the saved .mlpackage.
    """
    model.eval()
    print("Tracing decode model...")

    # Example inputs for tracing: one decode step that attends over the
    # prefilled prefix (mask is 0 there, -inf everywhere else).
    token = torch.zeros(1, 1, dtype=torch.int32)
    mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
    mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
    cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
    sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
    # update_mask selects the single cache slot this step writes into.
    upd_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
    upd_mask[0, 0, 0, 0] = 1.0
    speaker = torch.zeros(1, SPEAKER_DIM, dtype=torch.float16)
    speaker_flag = torch.zeros(1, dtype=torch.float16)

    example_inputs = (token, mask, cos, sin, upd_mask, speaker, speaker_flag)
    with torch.no_grad():
        traced = torch.jit.trace(model, example_inputs)

    fp16 = np.float16
    coreml_inputs = [
        ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
        ct.TensorType(name="causal_mask", shape=(1, 1, 1, MAX_CONTEXT), dtype=fp16),
        ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=fp16),
        ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=fp16),
        ct.TensorType(name="update_mask", shape=(1, 1, MAX_CONTEXT, 1), dtype=fp16),
        ct.TensorType(name="speaker_embedding", shape=(1, SPEAKER_DIM), dtype=fp16),
        ct.TensorType(name="is_speaker_step", shape=(1,), dtype=fp16),
    ]

    print("Converting decode to CoreML...")
    # iOS18 is the minimum target that supports stateful models.
    mlmodel = ct.convert(
        traced,
        inputs=coreml_inputs,
        outputs=[ct.TensorType(name="logits", dtype=fp16)],
        states=build_kv_cache_states(),
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS18,
    )

    # Traced buffer mutations are not lowered to state updates automatically;
    # patch the converted MIL graph by hand.
    inject_state_updates(mlmodel)

    out_path = output_dir / "PlaprePico.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved decode model to {out_path}")
    return out_path
|
|
|
|
def inject_state_updates(mlmodel):
    """Inject coreml_update_state ops into a converted stateful CoreML model.

    torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so coremltools
    can't generate coreml_update_state ops automatically. This walks the MIL graph,
    finds the read_state -> (cast?) -> mul -> add cache update pattern, and inserts
    coreml_update_state ops before the first consumer of each cache update.
    """
    prog = mlmodel._mil_program
    main_fn = prog.functions["main"]

    read_ops = list(main_fn.find_ops(op_type="read_state"))
    print(f"Found {len(read_ops)} read_state ops")
    # Guard: with no read_state ops (e.g. conversion stopped emitting them),
    # indexing read_ops[0] below would raise IndexError. Bail out instead.
    if not read_ops:
        print(" WARNING: no read_state ops found; skipping state-update injection")
        return

    updates = []
    for read_op in read_ops:
        state_var = read_op.inputs["input"]
        output = read_op.outputs[0]

        # The converter sometimes inserts a cast directly after the state read;
        # if so, continue pattern-matching from the cast's output.
        first_child = output.child_ops[0]
        search_output = first_child.outputs[0] if first_child.op_type == "cast" else output

        # Cache update lowers to: old * keep_mask (mul) + new_value (add).
        mul_op = next((c for c in search_output.child_ops if c.op_type == "mul"), None)
        if mul_op is None:
            print(f" WARNING: no mul found for {state_var.name}")
            continue

        add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
        if add_op is None:
            print(f" WARNING: no add found for {state_var.name}")
            continue

        updates.append((state_var, add_op))

    print(f"Injecting {len(updates)} coreml_update_state ops...")

    # All read_state ops live in the same block; reuse the list computed above
    # instead of re-running find_ops (which previously also crashed when empty).
    block = read_ops[0].enclosing_block
    with block:
        for state_var, add_op in updates:
            add_out = add_op.outputs[0]
            consumers = list(add_out.child_ops)
            if not consumers:
                print(f" WARNING: no consumers for {state_var.name} add output")
                continue
            first_consumer = consumers[0]

            # The update must be ordered before the first op that reads the new
            # cache value; the state buffers are fp16, so cast if needed.
            with mb.set_before_op(before_op=first_consumer):
                if str(add_out.dtype) == "fp16":
                    state_val = add_out
                else:
                    state_val = mb.cast(
                        x=add_out, dtype="fp16",
                        name=f"state_cast_{state_var.name}",
                    )
                mb.coreml_update_state(
                    state=state_var, value=state_val,
                    name=f"state_update_{state_var.name}",
                )

    # Sanity report: on success the two counts should match.
    prog_str = str(prog)
    print(f" read_state: {prog_str.count('read_state')}")
    print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")
|
|
|
|
def copy_assets(model_dir: Path, output_dir: Path):
    """Copy tokenizer.json, speakers.json, and RoPE tables to output root.

    Also precomputes the full RoPE cos/sin tables as fp16 .npy files and
    writes a manifest.json describing the exported model.
    """
    for filename in ["tokenizer.json", "speakers.json"]:
        src = model_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)
            # Fix: messages previously printed a literal placeholder instead of
            # the actual filename, making the log useless.
            print(f"Copied {filename} to {output_dir}")
        else:
            print(f" WARNING: {filename} not found in {model_dir}")

    # Full-context RoPE tables; the runtime slices per-position rows from these.
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    np.save(str(output_dir / "rope_cos.npy"), cos_full.numpy().astype(np.float16))
    np.save(str(output_dir / "rope_sin.npy"), sin_full.numpy().astype(np.float16))
    print(f"Exported RoPE tables to {output_dir}")

    # Manifest lets the consuming app validate shapes without loading the model.
    manifest = {
        "model": "plapre-pico",
        "version": "1.0",
        "context_length": MAX_CONTEXT,
        "prefill_length": PREFILL_SEQ_LEN,
        "vocab_size": VOCAB_SIZE,
        "num_layers": NUM_LAYERS,
        "hidden_size": HIDDEN_SIZE,
        "num_kv_heads": NUM_KV_HEADS,
        "head_dim": HEAD_DIM,
        "speaker_dim": SPEAKER_DIM,
        "precision": "float16",
    }
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"Wrote manifest to {manifest_path}")
|
|
|
|
def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
    """Convert Plapre Pico LLM end-to-end: download → load → trace → convert →
    inject state updates → copy assets. Returns path to PlaprePico.mlpackage."""
    src_dir = model_dir if model_dir is not None else download_model()

    output_dir.mkdir(parents=True, exist_ok=True)
    weights = load_weights(src_dir)

    print("\n=== Building decode model ===")
    decode_model = PlaprePico()
    populate_weights(decode_model, weights)
    # Cast the whole module to fp16 after loading so every parameter matches
    # the CoreML compute precision.
    package_path = convert_decode(decode_model.half(), output_dir)

    print("\n=== Copying assets ===")
    copy_assets(src_dir, output_dir)

    print(f"\nLLM conversion complete: {package_path}")
    return package_path
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the full conversion."""
    parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
    parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    args = parser.parse_args()

    # --model-dir is optional; without it, convert_llm downloads from the Hub.
    local_dir = Path(args.model_dir) if args.model_dir else None
    convert_llm(output_dir=Path(args.output_dir), model_dir=local_dir)
|
|
|
|
# Script entry point: python convert.py [--model-dir PATH] [--output-dir PATH]
if __name__ == "__main__":
    main()
|
|