# plapre-pico-coreml/scripts/convert_llm.py
# Author: Daniel Rothmann — "Tidy up repo and conversion scripts" (commit e79fa0a)
#!/usr/bin/env python3
"""
Main conversion script: downloads Plapre Pico weights from HuggingFace,
builds the decode wrapper model, traces, and converts to CoreML.
Usage:
python convert_llm.py [--model-dir PATH] [--output-dir PATH]
If --model-dir is not provided, downloads from syvai/plapre-pico.
"""
import argparse
import json
import shutil
from pathlib import Path
import numpy as np
import torch
import coremltools as ct
from coremltools.converters.mil.mil import Builder as mb
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from attention import precompute_rope_frequencies
from model_wrapper import (
PlaprePico,
NUM_LAYERS,
NUM_KV_HEADS,
MAX_CONTEXT,
HEAD_DIM,
PREFILL_SEQ_LEN,
VOCAB_SIZE,
HIDDEN_SIZE,
SPEAKER_DIM,
)
def download_model(model_id: str = "syvai/plapre-pico") -> Path:
    """Fetch the model snapshot from the HuggingFace Hub.

    Returns the local directory the snapshot was cached to.
    """
    print(f"Downloading {model_id}...")
    local_dir = snapshot_download(model_id)
    print(f"Model cached at: {local_dir}")
    return Path(local_dir)
def load_weights(model_dir: Path) -> dict[str, torch.Tensor]:
    """Load model.safetensors and speaker_proj.pt, casting bf16 tensors to fp16.

    Returns a flat name -> tensor dict; speaker projection weights are
    namespaced under the "speaker_proj." prefix.
    """
    def _to_fp16(t: torch.Tensor) -> torch.Tensor:
        # Downstream conversion runs in fp16; bf16 is not supported there.
        return t.to(torch.float16) if t.dtype == torch.bfloat16 else t

    weights: dict[str, torch.Tensor] = {}

    safetensors_path = model_dir / "model.safetensors"
    print(f"Loading {safetensors_path}...")
    for name, tensor in load_file(str(safetensors_path)).items():
        weights[name] = _to_fp16(tensor)

    speaker_proj_path = model_dir / "speaker_proj.pt"
    print(f"Loading {speaker_proj_path}...")
    sp_weights = torch.load(str(speaker_proj_path), map_location="cpu", weights_only=True)
    for name, tensor in sp_weights.items():
        weights[f"speaker_proj.{name}"] = _to_fp16(tensor)

    print(f"Loaded {len(weights)} weight tensors")
    return weights
def _map_weight_name(hf_name: str) -> str | None:
"""Map HuggingFace weight name to our model's parameter name."""
if hf_name == "model.embed_tokens.weight":
return "embed_tokens.weight"
if hf_name == "model.norm.weight":
return "norm.weight"
if hf_name == "lm_head.weight":
return None # tied to embed_tokens
if hf_name.startswith("model.layers."):
rest = hf_name[len("model.layers."):]
parts = rest.split(".", 1)
layer_idx = parts[0]
component = parts[1]
mapping = {
"self_attn.q_proj.weight": "self_attn.q_proj.weight",
"self_attn.k_proj.weight": "self_attn.k_proj.weight",
"self_attn.v_proj.weight": "self_attn.v_proj.weight",
"self_attn.o_proj.weight": "self_attn.o_proj.weight",
"mlp.gate_proj.weight": "mlp.gate_proj.weight",
"mlp.up_proj.weight": "mlp.up_proj.weight",
"mlp.down_proj.weight": "mlp.down_proj.weight",
"input_layernorm.weight": "input_layernorm.weight",
"post_attention_layernorm.weight": "post_attention_layernorm.weight",
}
if component in mapping:
return f"layers.{layer_idx}.{mapping[component]}"
if hf_name.startswith("speaker_proj."):
return hf_name
print(f" WARNING: unmapped weight: {hf_name}")
return None
def populate_weights(model: torch.nn.Module, weights: dict[str, torch.Tensor]):
    """Copy mapped checkpoint tensors into a PlaprePico model.

    Shape mismatches are logged and skipped. Names that map but don't exist
    in the model are still handed to load_state_dict so they get reported in
    its "unexpected" list.
    """
    reference = model.state_dict()
    mapped: dict[str, torch.Tensor] = {}
    for hf_name, tensor in weights.items():
        our_name = _map_weight_name(hf_name)
        if our_name is None:
            continue
        expected = reference.get(our_name)
        if expected is not None and expected.shape != tensor.shape:
            print(f" Shape mismatch for {our_name}: "
                  f"expected {expected.shape}, got {tensor.shape}")
            continue
        mapped[our_name] = tensor

    missing, unexpected = model.load_state_dict(mapped, strict=False)
    # KV-cache and RoPE buffers are initialized by the model, not the checkpoint.
    missing = [k for k in missing if not k.startswith(("k_cache_", "v_cache_", "rope_"))]
    if missing:
        print(f" Missing weights: {missing}")
    if unexpected:
        print(f" Unexpected weights: {unexpected}")
    print(f" Loaded {len(mapped)} weight tensors")
def build_kv_cache_states() -> list:
    """Build the CoreML StateType list for the per-layer K and V cache buffers.

    Order matters for the runtime: k_cache_i immediately precedes v_cache_i
    for each layer i.
    """
    cache_shape = (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)

    def _state(name: str) -> ct.StateType:
        return ct.StateType(
            wrapped_type=ct.TensorType(shape=cache_shape, dtype=np.float16),
            name=name,
        )

    return [
        _state(f"{kind}_cache_{layer}")
        for layer in range(NUM_LAYERS)
        for kind in ("k", "v")
    ]
def convert_decode(model: PlaprePico, output_dir: Path):
    """Trace the single-token decode model and convert it to a CoreML package.

    Returns the path to the saved PlaprePico.mlpackage.
    """
    model.eval()

    # Example inputs for tracing: one decode step over the full context window.
    print("Tracing decode model...")
    example_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
    example_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
    example_update = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
    example_update[0, 0, 0, 0] = 1.0  # any valid position for tracing
    example_inputs = (
        torch.zeros(1, 1, dtype=torch.int32),                 # input_ids
        example_mask,                                         # causal_mask
        torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),  # cos
        torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),  # sin
        example_update,                                       # update_mask
        torch.zeros(1, SPEAKER_DIM, dtype=torch.float16),     # speaker_embedding
        torch.zeros(1, dtype=torch.float16),                  # is_speaker_step
    )
    with torch.no_grad():
        traced = torch.jit.trace(model, example_inputs)

    print("Converting decode to CoreML...")
    input_specs = [
        ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
        ct.TensorType(name="causal_mask", shape=(1, 1, 1, MAX_CONTEXT), dtype=np.float16),
        ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
        ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
        ct.TensorType(name="update_mask", shape=(1, 1, MAX_CONTEXT, 1), dtype=np.float16),
        ct.TensorType(name="speaker_embedding", shape=(1, SPEAKER_DIM), dtype=np.float16),
        ct.TensorType(name="is_speaker_step", shape=(1,), dtype=np.float16),
    ]
    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=[ct.TensorType(name="logits", dtype=np.float16)],
        states=build_kv_cache_states(),
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS18,
    )

    # Tracing can't capture buffer mutations, so patch state writes in manually.
    inject_state_updates(mlmodel)

    out_path = output_dir / "PlaprePico.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved decode model to {out_path}")
    return out_path
def inject_state_updates(mlmodel):
    """Inject coreml_update_state ops into a converted stateful CoreML model.
    torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so coremltools
    can't generate coreml_update_state ops automatically. This walks the MIL graph,
    finds the read_state -> (cast?) -> mul -> add cache update pattern, and inserts
    coreml_update_state ops before the first consumer of each cache update.
    """
    # NOTE(review): _mil_program is a private coremltools attribute — this pass
    # should be re-validated whenever the pinned coremltools version changes.
    prog = mlmodel._mil_program
    main_fn = prog.functions["main"]
    read_ops = list(main_fn.find_ops(op_type="read_state"))
    print(f"Found {len(read_ops)} read_state ops")
    # Pass 1: for each state buffer, find the op whose output is the updated
    # cache contents (the `add` at the end of the masked-update pattern).
    updates = []
    for read_op in read_ops:
        state_var = read_op.inputs["input"]
        output = read_op.outputs[0]
        # The traced pattern has two shapes depending on compute precision:
        # FLOAT32: read_state -> cast(fp16->fp32) -> mul -> add
        # FLOAT16: read_state -> mul -> add
        # Skip over the optional cast so both precisions search the same graph.
        first_child = output.child_ops[0]
        search_output = first_child.outputs[0] if first_child.op_type == "cast" else output
        mul_op = next((c for c in search_output.child_ops if c.op_type == "mul"), None)
        if mul_op is None:
            print(f" WARNING: no mul found for {state_var.name}")
            continue
        add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
        if add_op is None:
            print(f" WARNING: no add found for {state_var.name}")
            continue
        updates.append((state_var, add_op))
    # Pass 2: write each updated value back to its state buffer, inserting the
    # update op just before the value's first consumer so ordering is preserved.
    print(f"Injecting {len(updates)} coreml_update_state ops...")
    block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
    with block:
        for state_var, add_op in updates:
            add_out = add_op.outputs[0]
            consumers = list(add_out.child_ops)
            if not consumers:
                print(f" WARNING: no consumers for {state_var.name} add output")
                continue
            first_consumer = consumers[0]
            with mb.set_before_op(before_op=first_consumer):
                # State buffers are declared fp16; cast back when the update
                # subgraph ran in fp32.
                if str(add_out.dtype) == "fp16":
                    state_val = add_out
                else:
                    state_val = mb.cast(
                        x=add_out, dtype="fp16",
                        name=f"state_cast_{state_var.name}",
                    )
                mb.coreml_update_state(
                    state=state_var, value=state_val,
                    name=f"state_update_{state_var.name}",
                )
    # Sanity report: the two counts should line up one update per matched read.
    prog_str = str(prog)
    print(f" read_state: {prog_str.count('read_state')}")
    print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")
def copy_assets(model_dir: Path, output_dir: Path):
    """Copy tokenizer.json, speakers.json, and RoPE tables to output root.

    Also writes manifest.json describing the converted model's dimensions so
    the runtime can configure its buffers without hardcoding constants.

    Fix: the copy/missing log lines previously printed the literal string
    "(unknown)" instead of the actual filename.
    """
    for filename in ["tokenizer.json", "speakers.json"]:
        src = model_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)
            print(f"Copied {filename} to {output_dir}")
        else:
            print(f" WARNING: {filename} not found in {model_dir}")
    # Export RoPE tables for the iOS runtime to build cos/sin/update_mask inputs
    # NOTE(review): rope theta 100000.0 is hardcoded here — confirm it matches
    # the checkpoint's config.json.
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    np.save(str(output_dir / "rope_cos.npy"), cos_full.numpy().astype(np.float16))
    np.save(str(output_dir / "rope_sin.npy"), sin_full.numpy().astype(np.float16))
    print(f"Exported RoPE tables to {output_dir}")
    manifest = {
        "model": "plapre-pico",
        "version": "1.0",
        "context_length": MAX_CONTEXT,
        "prefill_length": PREFILL_SEQ_LEN,
        "vocab_size": VOCAB_SIZE,
        "num_layers": NUM_LAYERS,
        "hidden_size": HIDDEN_SIZE,
        "num_kv_heads": NUM_KV_HEADS,
        "head_dim": HEAD_DIM,
        "speaker_dim": SPEAKER_DIM,
        "precision": "float16",
    }
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"Wrote manifest to {manifest_path}")
def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
    """Run the full Plapre Pico LLM conversion pipeline.

    Steps: download (when model_dir is None) -> load weights -> build and
    trace the decode model -> convert to CoreML with injected state updates
    -> copy runtime assets. Returns the path to PlaprePico.mlpackage.
    """
    if model_dir is None:
        model_dir = download_model()
    output_dir.mkdir(parents=True, exist_ok=True)

    weights = load_weights(model_dir)

    print("\n=== Building decode model ===")
    decode = PlaprePico()
    populate_weights(decode, weights)
    # The conversion pipeline runs fp16 end-to-end; cast remaining fp32 params.
    out_path = convert_decode(decode.half(), output_dir)

    print("\n=== Copying assets ===")
    copy_assets(model_dir, output_dir)

    print(f"\nLLM conversion complete: {out_path}")
    return out_path
def main():
    """CLI entry point: parse arguments and run the LLM conversion."""
    parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
    parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    args = parser.parse_args()

    chosen_model_dir = Path(args.model_dir) if args.model_dir else None
    convert_llm(output_dir=Path(args.output_dir), model_dir=chosen_model_dir)


if __name__ == "__main__":
    main()