File size: 12,662 Bytes
cb20bed 95c6137 cb20bed e79fa0a cb20bed 95c6137 cb20bed 95c6137 cb20bed 95c6137 cb20bed 95c6137 cb20bed e79fa0a cb20bed d12b4ea fad9fad cb20bed d12b4ea cb20bed d12b4ea cb20bed fad9fad d12b4ea fad9fad cb20bed 10bad22 cb20bed e79fa0a 95c6137 cb20bed e79fa0a cb20bed a2c97d7 cb20bed a2c97d7 cb20bed a2c97d7 cb20bed e79fa0a cb20bed 95c6137 cb20bed e79fa0a cb20bed e79fa0a cb20bed | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 | #!/usr/bin/env python3
"""
Main conversion script: downloads Plapre Pico weights from HuggingFace,
builds the decode wrapper model, traces, and converts to CoreML.
Usage:
python convert.py [--model-dir PATH] [--output-dir PATH]
If --model-dir is not provided, downloads from syvai/plapre-pico.
"""
import argparse
import json
import shutil
from pathlib import Path
import numpy as np
import torch
import coremltools as ct
from coremltools.converters.mil.mil import Builder as mb
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from attention import precompute_rope_frequencies
from model_wrapper import (
PlaprePico,
NUM_LAYERS,
NUM_KV_HEADS,
MAX_CONTEXT,
HEAD_DIM,
PREFILL_SEQ_LEN,
VOCAB_SIZE,
HIDDEN_SIZE,
SPEAKER_DIM,
)
def download_model(model_id: str = "syvai/plapre-pico") -> Path:
    """Fetch a snapshot of *model_id* from the Hugging Face Hub.

    Args:
        model_id: Hub repository id to download.

    Returns:
        The local snapshot directory reported by ``snapshot_download``.
    """
    print(f"Downloading {model_id}...")
    snapshot_dir = snapshot_download(model_id)
    print(f"Model cached at: {snapshot_dir}")
    return Path(snapshot_dir)
def load_weights(model_dir: Path) -> dict[str, torch.Tensor]:
    """Read the checkpoint tensors found in *model_dir*.

    Loads ``model.safetensors`` (keys kept as-is) and ``speaker_proj.pt``
    (keys prefixed with ``"speaker_proj."``). Any bfloat16 tensor is cast to
    float16; tensors of other dtypes are returned unchanged.

    Args:
        model_dir: Directory containing both checkpoint files.

    Returns:
        Mapping of checkpoint name to tensor.
    """
    def _as_fp16(t: torch.Tensor) -> torch.Tensor:
        # Only bf16 is converted; fp32/fp16 tensors pass through untouched.
        return t.to(torch.float16) if t.dtype == torch.bfloat16 else t

    safetensors_path = model_dir / "model.safetensors"
    print(f"Loading {safetensors_path}...")
    weights = {
        name: _as_fp16(t)
        for name, t in load_file(str(safetensors_path)).items()
    }

    speaker_proj_path = model_dir / "speaker_proj.pt"
    print(f"Loading {speaker_proj_path}...")
    speaker_state = torch.load(str(speaker_proj_path), map_location="cpu", weights_only=True)
    weights.update(
        (f"speaker_proj.{name}", _as_fp16(t)) for name, t in speaker_state.items()
    )

    print(f"Loaded {len(weights)} weight tensors")
    return weights
def _map_weight_name(hf_name: str) -> str | None:
"""Map HuggingFace weight name to our model's parameter name."""
if hf_name == "model.embed_tokens.weight":
return "embed_tokens.weight"
if hf_name == "model.norm.weight":
return "norm.weight"
if hf_name == "lm_head.weight":
return None # tied to embed_tokens
if hf_name.startswith("model.layers."):
rest = hf_name[len("model.layers."):]
parts = rest.split(".", 1)
layer_idx = parts[0]
component = parts[1]
mapping = {
"self_attn.q_proj.weight": "self_attn.q_proj.weight",
"self_attn.k_proj.weight": "self_attn.k_proj.weight",
"self_attn.v_proj.weight": "self_attn.v_proj.weight",
"self_attn.o_proj.weight": "self_attn.o_proj.weight",
"mlp.gate_proj.weight": "mlp.gate_proj.weight",
"mlp.up_proj.weight": "mlp.up_proj.weight",
"mlp.down_proj.weight": "mlp.down_proj.weight",
"input_layernorm.weight": "input_layernorm.weight",
"post_attention_layernorm.weight": "post_attention_layernorm.weight",
}
if component in mapping:
return f"layers.{layer_idx}.{mapping[component]}"
if hf_name.startswith("speaker_proj."):
return hf_name
print(f" WARNING: unmapped weight: {hf_name}")
return None
def populate_weights(model: torch.nn.Module, weights: dict[str, torch.Tensor]):
    """Load checkpoint tensors into a PlaprePico model.

    Every checkpoint name is translated with ``_map_weight_name``. A tensor is
    skipped when:

    * it has no mapping;
    * the mapped name is not a key of ``model.state_dict()``;
    * its shape differs from the model's.

    The last two cases are reported. The remaining tensors are loaded with
    ``strict=False``.

    Missing model keys starting with ``k_cache_``, ``v_cache_`` or ``rope_``
    are not reported. These are presumably buffers the model initialises
    itself rather than checkpoint weights; confirm against model_wrapper.

    Args:
        model: Module populated in place.
        weights: Checkpoint tensors keyed by HuggingFace name.
    """
    expected = model.state_dict()
    new_state = {}
    absent_from_model = []
    for hf_name, tensor in weights.items():
        our_name = _map_weight_name(hf_name)
        if our_name is None:
            continue
        target = expected.get(our_name)
        if target is None:
            # Mapped, but the model has no such key: report it instead of
            # dropping it silently, so mapping typos/renames get noticed.
            absent_from_model.append(our_name)
            continue
        if target.shape != tensor.shape:
            print(f" Shape mismatch for {our_name}: "
                  f"expected {target.shape}, got {tensor.shape}")
            continue
        new_state[our_name] = tensor

    missing, unexpected = model.load_state_dict(new_state, strict=False)
    missing = [k for k in missing if not k.startswith(("k_cache_", "v_cache_", "rope_"))]
    if absent_from_model:
        print(f" Not found in model (skipped): {absent_from_model}")
    if missing:
        print(f" Missing weights: {missing}")
    if unexpected:
        print(f" Unexpected weights: {unexpected}")
    print(f" Loaded {len(new_state)} weight tensors")
def build_kv_cache_states() -> list:
    """Describe the model's KV caches as CoreML state inputs.

    Returns ``2 * NUM_LAYERS`` ``ct.StateType`` entries, ordered
    ``k_cache_0, v_cache_0, k_cache_1, v_cache_1, ...``. Each wraps an fp16
    tensor of shape ``(1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)``.
    """
    cache_shape = (1, NUM_KV_HEADS, MAX_CONTEXT, HEAD_DIM)
    return [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=cache_shape, dtype=np.float16),
            name=f"{kind}_cache_{layer}",
        )
        for layer in range(NUM_LAYERS)
        for kind in ("k", "v")
    ]
def convert_decode(model: PlaprePico, output_dir: Path):
    """Trace the decode model and save it as a stateful CoreML package.

    Steps:

    1. Put the model in eval mode and trace it with placeholder fp16 inputs.
    2. Convert to an iOS 18 ML program with the KV caches declared as states
       (``build_kv_cache_states``).
    3. Patch the program with ``inject_state_updates``.
    4. Save to ``output_dir / "PlaprePico.mlpackage"``.

    Args:
        model: The PlaprePico decode wrapper.
        output_dir: Existing directory that receives the package.

    Returns:
        Path of the saved ``.mlpackage``.
    """
    model.eval()
    print("Tracing decode model...")

    # Placeholder example inputs for torch.jit.trace, in forward() order.
    fp16 = torch.float16
    example_ids = torch.zeros(1, 1, dtype=torch.int32)
    example_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=fp16)
    example_mask[..., :PREFILL_SEQ_LEN] = 0.0
    example_cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=fp16)
    example_sin = torch.zeros_like(example_cos)
    example_update = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=fp16)
    example_update[0, 0, 0, 0] = 1.0  # any valid position for tracing
    example_speaker = torch.zeros(1, SPEAKER_DIM, dtype=fp16)
    example_flag = torch.zeros(1, dtype=fp16)
    example_inputs = (
        example_ids, example_mask, example_cos, example_sin,
        example_update, example_speaker, example_flag,
    )
    with torch.no_grad():
        traced = torch.jit.trace(model, example_inputs)

    print("Converting decode to CoreML...")
    # (name, shape, dtype) for each CoreML input, matching the order above.
    input_specs = (
        ("input_ids", (1, 1), np.int32),
        ("causal_mask", (1, 1, 1, MAX_CONTEXT), np.float16),
        ("cos", (1, 1, 1, HEAD_DIM), np.float16),
        ("sin", (1, 1, 1, HEAD_DIM), np.float16),
        ("update_mask", (1, 1, MAX_CONTEXT, 1), np.float16),
        ("speaker_embedding", (1, SPEAKER_DIM), np.float16),
        ("is_speaker_step", (1,), np.float16),
    )
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name=name, shape=shape, dtype=dtype)
            for name, shape, dtype in input_specs
        ],
        outputs=[ct.TensorType(name="logits", dtype=np.float16)],
        states=build_kv_cache_states(),
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS18,
    )

    # NOTE(review): inject_state_updates edits mlmodel._mil_program in place.
    # Confirm that mlmodel.save() actually persists those edits. It may write
    # the spec serialized during ct.convert instead of re-serializing the
    # program. Reload the saved package and check it for coreml_update_state ops.
    inject_state_updates(mlmodel)

    out_path = output_dir / "PlaprePico.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved decode model to {out_path}")
    return out_path
def _find_cache_update_add(read_op):
    """Return the ``add`` op computing the updated cache read by *read_op*, or None.

    Recognises ``read_state -> mul -> add`` (FLOAT16 compute precision) and
    ``read_state -> cast -> mul -> add`` (FLOAT32 compute precision). A
    warning is printed when no match is found.
    """
    state_name = read_op.inputs["input"].name
    read_out = read_op.outputs[0]
    # Direct consumers first (fp16 graph), then consumers behind any cast
    # (fp32 graph). Checking every cast child, not just the first consumer,
    # means consumer ordering does not matter.
    search_vars = [read_out] + [
        child.outputs[0] for child in read_out.child_ops if child.op_type == "cast"
    ]
    mul_op = next(
        (c for var in search_vars for c in var.child_ops if c.op_type == "mul"),
        None,
    )
    if mul_op is None:
        print(f" WARNING: no mul found for {state_name}")
        return None
    add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
    if add_op is None:
        print(f" WARNING: no add found for {state_name}")
    return add_op


def inject_state_updates(mlmodel):
    """Inject coreml_update_state ops into a converted stateful CoreML model.

    torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so
    coremltools can't generate coreml_update_state ops automatically. This
    walks the "main" function of ``mlmodel._mil_program``, finds each
    read_state -> (cast?) -> mul -> add cache-update pattern, and inserts a
    coreml_update_state op that writes the add result back to the state. The
    value is cast to fp16 first if the add result is not already fp16.

    The program is modified in place and nothing is returned. States whose
    pattern is not recognised are reported and left without an update, and a
    final warning is printed when not every read state gets written back.

    NOTE(review): confirm that the caller's subsequent ``mlmodel.save()``
    persists these in-place program edits. ``save`` may write the spec
    produced during ``ct.convert`` rather than re-serializing the program.
    """
    # Local import: the MIL type objects are only needed here.
    from coremltools.converters.mil.mil import types

    main_fn = mlmodel._mil_program.functions["main"]
    read_ops = list(main_fn.find_ops(op_type="read_state"))
    print(f"Found {len(read_ops)} read_state ops")
    if not read_ops:
        # No states: nothing to patch (and no block to anchor insertions in).
        return

    updates = []
    for read_op in read_ops:
        add_op = _find_cache_update_add(read_op)
        if add_op is not None:
            updates.append((read_op.inputs["input"], add_op))

    print(f"Injecting {len(updates)} coreml_update_state ops...")
    block = read_ops[0].enclosing_block
    with block:
        for state_var, add_op in updates:
            add_out = add_op.outputs[0]
            consumers = list(add_out.child_ops)
            if not consumers:
                print(f" WARNING: no consumers for {state_var.name} add output")
                continue
            # A consumer of the add necessarily comes after it within the
            # block, so inserting before one keeps add_out defined.
            with mb.set_before_op(before_op=consumers[0]):
                state_val = add_out
                # Compare type objects directly instead of their str() form.
                if add_out.dtype != types.fp16:
                    state_val = mb.cast(
                        x=add_out, dtype="fp16",
                        name=f"state_cast_{state_var.name}",
                    )
                mb.coreml_update_state(
                    state=state_var, value=state_val,
                    name=f"state_update_{state_var.name}",
                )

    # Count real ops instead of substring hits in str(prog), which can also
    # match variable or op names.
    n_updates = len(main_fn.find_ops(op_type="coreml_update_state"))
    print(f" read_state: {len(read_ops)}")
    print(f" coreml_update_state: {n_updates}")
    if n_updates < len(read_ops):
        print(f" WARNING: {len(read_ops) - n_updates} state(s) have no coreml_update_state")
def copy_assets(model_dir: Path, output_dir: Path, rope_theta: float = 100000.0):
    """Copy tokenizer.json, speakers.json, and RoPE tables to output root.

    Steps:

    * Copy ``tokenizer.json`` and ``speakers.json`` from *model_dir*, printing
      a warning for any that is missing.
    * Export fp16 RoPE cos/sin tables as ``rope_cos.npy`` / ``rope_sin.npy``.
    * Write ``manifest.json`` with the model constants.

    Args:
        model_dir: Directory of the source checkpoint.
        output_dir: Destination directory; it is not created here.
        rope_theta: RoPE base passed to ``precompute_rope_frequencies``. It
            should match the base the exported model uses.
    """
    for filename in ("tokenizer.json", "speakers.json"):
        src = model_dir / filename
        if src.exists():
            shutil.copy2(src, output_dir / filename)
            print(f"Copied {filename} to {output_dir}")
        else:
            print(f" WARNING: {filename} not found in {model_dir}")

    # Export RoPE tables for the iOS runtime to build cos/sin/update_mask inputs
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, rope_theta)
    np.save(str(output_dir / "rope_cos.npy"), cos_full.numpy().astype(np.float16))
    np.save(str(output_dir / "rope_sin.npy"), sin_full.numpy().astype(np.float16))
    print(f"Exported RoPE tables to {output_dir}")

    manifest = {
        "model": "plapre-pico",
        "version": "1.0",
        "context_length": MAX_CONTEXT,
        "prefill_length": PREFILL_SEQ_LEN,
        "vocab_size": VOCAB_SIZE,
        "num_layers": NUM_LAYERS,
        "hidden_size": HIDDEN_SIZE,
        "num_kv_heads": NUM_KV_HEADS,
        "head_dim": HEAD_DIM,
        "speaker_dim": SPEAKER_DIM,
        "precision": "float16",
    }
    manifest_path = output_dir / "manifest.json"
    # Explicit encoding so the output does not depend on the platform locale.
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"Wrote manifest to {manifest_path}")
def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
    """Convert Plapre Pico LLM end-to-end and return the path to PlaprePico.mlpackage.

    Pipeline: download (only when *model_dir* is None) → load weights →
    build and populate the decode wrapper → cast to fp16 → trace/convert
    (including state-update injection) → copy assets.

    Args:
        output_dir: Destination directory; created if missing. ``str`` is
            also accepted.
        model_dir: Local checkpoint directory (``str`` also accepted). When
            omitted, the model is downloaded from the Hugging Face Hub.

    Returns:
        Path to the saved ``PlaprePico.mlpackage``.

    Raises:
        FileNotFoundError: If *model_dir* is given but is not a directory.
    """
    output_dir = Path(output_dir)
    if model_dir is None:
        model_dir = download_model()
    else:
        model_dir = Path(model_dir)
        # Fail early with a clear message instead of deep inside load_weights.
        if not model_dir.is_dir():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    weights = load_weights(model_dir)
    print("\n=== Building decode model ===")
    decode = PlaprePico()
    populate_weights(decode, weights)
    # load_state_dict copies the tensors into the model, so the checkpoint
    # dict is no longer needed; release it before tracing and conversion.
    del weights
    decode = decode.half()
    out_path = convert_decode(decode, output_dir)

    print("\n=== Copying assets ===")
    copy_assets(model_dir, output_dir)
    print(f"\nLLM conversion complete: {out_path}")
    return out_path
def main():
    """Command-line entry point: parse the CLI flags and run ``convert_llm``.

    ``--model-dir`` selects a local checkpoint; when it is absent (or empty)
    the model is downloaded. ``--output-dir`` defaults to the parent of this
    script's directory.
    """
    default_output = str(Path(__file__).parent.parent)
    parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
    parser.add_argument(
        "--model-dir",
        type=str,
        help="Path to downloaded model directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=default_output,
        help="Output directory",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    model_dir = Path(args.model_dir) if args.model_dir else None
    convert_llm(output_dir=output_dir, model_dir=model_dir)
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|