#!/usr/bin/env python3
"""
Main conversion script: downloads Plapre Pico weights from HuggingFace,
builds the decode wrapper model, traces, and converts to CoreML.

Usage:
    python convert.py [--model-dir PATH] [--output-dir PATH]

If --model-dir is not provided, downloads from syvai/plapre-pico.
"""

import argparse
import json
import shutil
from pathlib import Path

import numpy as np
import torch
import coremltools as ct
from coremltools.converters.mil.mil import Builder as mb
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from attention import precompute_rope_frequencies
from model_wrapper import (
    PlaprePico,
    NUM_LAYERS,
    NUM_KV_HEADS,
    MAX_CONTEXT,
    HEAD_DIM,
    PREFILL_SEQ_LEN,
    VOCAB_SIZE,
    HIDDEN_SIZE,
    SPEAKER_DIM,
)


def download_model(model_id: str = "syvai/plapre-pico") -> Path:
    """Download model from HuggingFace Hub, return local path.

    Uses the hub's local cache, so repeated runs are cheap.
    """
    print(f"Downloading {model_id}...")
    path = snapshot_download(model_id)
    print(f"Model cached at: {path}")
    return Path(path)


def load_weights(model_dir: Path) -> dict[str, torch.Tensor]:
    """Load model.safetensors and speaker_proj.pt, cast bf16 -> fp16.

    CoreML conversion targets float16, so bfloat16 tensors are cast up
    front. Speaker projection weights are namespaced under
    ``speaker_proj.`` to match the wrapper model's parameter names.
    """
    weights: dict[str, torch.Tensor] = {}

    safetensors_path = model_dir / "model.safetensors"
    print(f"Loading {safetensors_path}...")
    st_weights = load_file(str(safetensors_path))
    for name, tensor in st_weights.items():
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float16)
        weights[name] = tensor

    speaker_proj_path = model_dir / "speaker_proj.pt"
    print(f"Loading {speaker_proj_path}...")
    # weights_only=True avoids arbitrary-code execution via pickle.
    sp_weights = torch.load(str(speaker_proj_path), map_location="cpu", weights_only=True)
    for name, tensor in sp_weights.items():
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float16)
        weights[f"speaker_proj.{name}"] = tensor

    print(f"Loaded {len(weights)} weight tensors")
    return weights


def _map_weight_name(hf_name: str) -> str | None:
    """Map HuggingFace weight name to our model's parameter name.

    Returns None for weights that should be skipped (e.g. lm_head,
    which is tied to the embedding) or that cannot be mapped.
    """
    if hf_name == "model.embed_tokens.weight":
        return "embed_tokens.weight"
    if hf_name == "model.norm.weight":
        return "norm.weight"
    if hf_name == "lm_head.weight":
        return None  # tied to embed_tokens
    if hf_name.startswith("model.layers."):
        rest = hf_name[len("model.layers."):]
        # "<idx>.<component path>" -> split off the layer index once.
        parts = rest.split(".", 1)
        layer_idx = parts[0]
        component = parts[1]
        mapping = {
            "self_attn.q_proj.weight": "self_attn.q_proj.weight",
            "self_attn.k_proj.weight": "self_attn.k_proj.weight",
            "self_attn.v_proj.weight": "self_attn.v_proj.weight",
            "self_attn.o_proj.weight": "self_attn.o_proj.weight",
            "mlp.gate_proj.weight": "mlp.gate_proj.weight",
            "mlp.up_proj.weight": "mlp.up_proj.weight",
            "mlp.down_proj.weight": "mlp.down_proj.weight",
            "input_layernorm.weight": "input_layernorm.weight",
            "post_attention_layernorm.weight": "post_attention_layernorm.weight",
        }
        if component in mapping:
            return f"layers.{layer_idx}.{mapping[component]}"
    if hf_name.startswith("speaker_proj."):
        return hf_name
    print(f" WARNING: unmapped weight: {hf_name}")
    return None


def populate_weights(model: torch.nn.Module, weights: dict[str, torch.Tensor]):
    """Load weights into a PlaprePico model.

    Uses strict=False because KV-cache and RoPE buffers are registered
    on the model but are not present in the checkpoint; those are
    filtered out of the missing-keys report.
    """
    state_dict = model.state_dict()
    new_state = {}
    for hf_name, tensor in weights.items():
        our_name = _map_weight_name(hf_name)
        if our_name is None:
            continue
        if our_name in state_dict:
            if state_dict[our_name].shape != tensor.shape:
                print(f" Shape mismatch for {our_name}: "
                      f"expected {state_dict[our_name].shape}, got {tensor.shape}")
                continue
            new_state[our_name] = tensor
    missing, unexpected = model.load_state_dict(new_state, strict=False)
    # Cache/RoPE buffers are expected to be absent from the checkpoint.
    missing = [k for k in missing if not k.startswith(("k_cache_", "v_cache_", "rope_"))]
    if missing:
        print(f" Missing weights: {missing}")
    if unexpected:
        print(f" Unexpected weights: {unexpected}")
    print(f" Loaded {len(new_state)} weight tensors")


def build_kv_cache_states() -> list:
    """Build CoreML StateType list for the per-layer KV cache buffers.

    One K and one V state per layer, each shaped
    (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM) in fp16.
    """
    states = []
    for i in range(NUM_LAYERS):
        states.append(
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM),
                    dtype=np.float16,
                ),
                name=f"k_cache_{i}",
            )
        )
        states.append(
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM),
                    dtype=np.float16,
                ),
                name=f"v_cache_{i}",
            )
        )
    return states


def convert_decode(model: PlaprePico, output_dir: Path):
    """Trace and convert the single-token decode model to CoreML.

    Builds dummy inputs for tracing (contents don't matter, only shapes
    and dtypes), converts with stateful KV caches targeting iOS 18, then
    post-processes the MIL program to inject state-update ops.
    Returns the path to the saved .mlpackage.
    """
    model.eval()
    print("Tracing decode model...")

    input_ids = torch.zeros(1, 1, dtype=torch.int32)
    # Start fully masked; unmask the prefill region so the trace sees a
    # realistic mix of 0 and -inf values.
    causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
    causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
    cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
    sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
    update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
    update_mask[0, 0, 0, 0] = 1.0  # any valid position for tracing
    speaker_embedding = torch.zeros(1, SPEAKER_DIM, dtype=torch.float16)
    is_speaker_step = torch.zeros(1, dtype=torch.float16)

    with torch.no_grad():
        traced = torch.jit.trace(model, (
            input_ids,
            causal_mask,
            cos,
            sin,
            update_mask,
            speaker_embedding,
            is_speaker_step,
        ))

    print("Converting decode to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
            ct.TensorType(
                name="causal_mask",
                shape=(1, 1, 1, MAX_CONTEXT),
                dtype=np.float16,
            ),
            ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
            ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
            ct.TensorType(
                name="update_mask",
                shape=(1, 1, MAX_CONTEXT, 1),
                dtype=np.float16,
            ),
            ct.TensorType(
                name="speaker_embedding",
                shape=(1, SPEAKER_DIM),
                dtype=np.float16,
            ),
            ct.TensorType(
                name="is_speaker_step",
                shape=(1,),
                dtype=np.float16,
            ),
        ],
        outputs=[ct.TensorType(name="logits", dtype=np.float16)],
        states=build_kv_cache_states(),
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS18,
    )

    inject_state_updates(mlmodel)

    out_path = output_dir / "PlaprePico.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved decode model to {out_path}")
    return out_path


def inject_state_updates(mlmodel):
    """Inject coreml_update_state ops into a converted stateful CoreML model.

    torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so
    coremltools can't generate coreml_update_state ops automatically.
    This walks the MIL graph, finds the read_state -> (cast?) -> mul -> add
    cache update pattern, and inserts coreml_update_state ops before the
    first consumer of each cache update.
    """
    prog = mlmodel._mil_program
    main_fn = prog.functions["main"]
    read_ops = list(main_fn.find_ops(op_type="read_state"))
    print(f"Found {len(read_ops)} read_state ops")

    # First pass: locate the (state, add) pairs without mutating the graph.
    updates = []
    for read_op in read_ops:
        state_var = read_op.inputs["input"]
        output = read_op.outputs[0]
        # FLOAT32: read_state -> cast(fp16->fp32) -> mul -> add
        # FLOAT16: read_state -> mul -> add
        first_child = output.child_ops[0]
        search_output = first_child.outputs[0] if first_child.op_type == "cast" else output
        mul_op = next((c for c in search_output.child_ops if c.op_type == "mul"), None)
        if mul_op is None:
            print(f" WARNING: no mul found for {state_var.name}")
            continue
        add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
        if add_op is None:
            print(f" WARNING: no add found for {state_var.name}")
            continue
        updates.append((state_var, add_op))

    print(f"Injecting {len(updates)} coreml_update_state ops...")
    block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
    with block:
        for state_var, add_op in updates:
            add_out = add_op.outputs[0]
            consumers = list(add_out.child_ops)
            if not consumers:
                print(f" WARNING: no consumers for {state_var.name} add output")
                continue
            first_consumer = consumers[0]
            # The update must happen before the value is consumed, so the
            # stored state reflects this step's cache write.
            with mb.set_before_op(before_op=first_consumer):
                if str(add_out.dtype) == "fp16":
                    state_val = add_out
                else:
                    state_val = mb.cast(
                        x=add_out,
                        dtype="fp16",
                        name=f"state_cast_{state_var.name}",
                    )
                mb.coreml_update_state(
                    state=state_var,
                    value=state_val,
                    name=f"state_update_{state_var.name}",
                )

    prog_str = str(prog)
    print(f" read_state: {prog_str.count('read_state')}")
    print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")


def copy_assets(model_dir: Path, output_dir: Path):
    """Copy tokenizer.json, speakers.json, and RoPE tables to output root.

    Also writes a manifest.json describing the converted model's
    dimensions so the iOS runtime doesn't hard-code them.
    """
    for filename in ["tokenizer.json", "speakers.json"]:
        src = model_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)
            # BUG FIX: f-string previously printed the literal "(unknown)"
            # instead of interpolating the filename.
            print(f"Copied {filename} to {output_dir}")
        else:
            print(f" WARNING: {filename} not found in {model_dir}")

    # Export RoPE tables for the iOS runtime to build cos/sin/update_mask inputs
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    np.save(str(output_dir / "rope_cos.npy"), cos_full.numpy().astype(np.float16))
    np.save(str(output_dir / "rope_sin.npy"), sin_full.numpy().astype(np.float16))
    print(f"Exported RoPE tables to {output_dir}")

    manifest = {
        "model": "plapre-pico",
        "version": "1.0",
        "context_length": MAX_CONTEXT,
        "prefill_length": PREFILL_SEQ_LEN,
        "vocab_size": VOCAB_SIZE,
        "num_layers": NUM_LAYERS,
        "hidden_size": HIDDEN_SIZE,
        "num_kv_heads": NUM_KV_HEADS,
        "head_dim": HEAD_DIM,
        "speaker_dim": SPEAKER_DIM,
        "precision": "float16",
    }
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"Wrote manifest to {manifest_path}")


def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
    """Convert Plapre Pico LLM end-to-end: download → load → trace →
    convert → inject state updates → copy assets.

    Returns path to PlaprePico.mlpackage."""
    if model_dir is None:
        model_dir = download_model()
    output_dir.mkdir(parents=True, exist_ok=True)

    weights = load_weights(model_dir)

    print("\n=== Building decode model ===")
    decode = PlaprePico()
    populate_weights(decode, weights)
    decode = decode.half()
    out_path = convert_decode(decode, output_dir)

    print("\n=== Copying assets ===")
    copy_assets(model_dir, output_dir)

    print(f"\nLLM conversion complete: {out_path}")
    return out_path


def main():
    parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
    parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
    parser.add_argument("--output-dir", type=str,
                        default=str(Path(__file__).parent.parent),
                        help="Output directory")
    args = parser.parse_args()
    convert_llm(
        output_dir=Path(args.output_dir),
        model_dir=Path(args.model_dir) if args.model_dir else None,
    )


if __name__ == "__main__":
    main()