Daniel Rothmann commited on
Commit
e79fa0a
·
1 Parent(s): fed9119

Tidy up repo and conversion scripts

Browse files
KanadeDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a443e8a50213febe82723e6bc41696fa420a0b8050eb626bffdc76f1a7c36e0b
3
- size 199420
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee0b31302fb2709b2ad5cde88ed77e8dde6170f4b711a03835b1ee1b17fb60d1
3
+ size 199364
KanadeDecoder.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "C2BF25B7-4DB5-4C6F-A4B8-4BC558D0C7E7": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Specification",
7
- "name": "model.mlmodel",
8
- "path": "com.apple.CoreML/model.mlmodel"
9
- },
10
- "FFA4DD20-41E6-4D3A-B810-5F438BD1787B": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "C2BF25B7-4DB5-4C6F-A4B8-4BC558D0C7E7"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "0E204957-794C-49DA-B444-2E03C7B62509": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
+ },
10
+ "9CF9C1BF-D9A1-4092-ADCC-2D70ECBAAFB3": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Specification",
13
+ "name": "model.mlmodel",
14
+ "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
+ "rootModelIdentifier": "9CF9C1BF-D9A1-4092-ADCC-2D70ECBAAFB3"
18
  }
PlaprePico.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b0ea4fbe5939f8db381da0ccadf9e90b61c82f5f0eca58b46e89b3a5541a49f0
3
- size 957824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb9711d00982520e62e667a7f524df950ca1d1080991bdd9cba6d2327b891511
3
+ size 956591
PlaprePico.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "1F911078-42FE-4F91-A2D0-E5B86F87F7AD": {
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
  },
10
- "3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Specification",
13
  "name": "model.mlmodel",
14
  "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
- "rootModelIdentifier": "3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "6F441A1D-F82F-44B2-A345-005396E0926A": {
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
  },
10
+ "FAB5B87A-358B-488E-A85C-97F3F8FDFA45": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Specification",
13
  "name": "model.mlmodel",
14
  "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
+ "rootModelIdentifier": "FAB5B87A-358B-488E-A85C-97F3F8FDFA45"
18
  }
PlaprePico_int4.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a04406d6835f50b4a6a94b028474305217518acd50352b91ba071afbb7a70a45
3
- size 935132
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a156dd782a9aabc95ce6b1e1a34b05a4c71557a96e36cb69e07833a65c530e04
3
+ size 986129
PlaprePico_int4.mlpackage/Data/com.apple.CoreML/weights/weight.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9ef8ca208c59506b6b4205b59c27ec1e2a838338e969c271fb90db332a0c68e
3
- size 59654148
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbfed1ebb98d1541a462d6ec13e11b22d086c6ea7088ecdcc1c3c687fd99ee2f
3
+ size 59614916
PlaprePico_int4.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "B95C3123-5C93-4485-B674-506310CC30FA": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Weights",
7
- "name": "weights",
8
- "path": "com.apple.CoreML/weights"
9
- },
10
- "EF542C51-0401-4C0A-9713-6B90D89D10D4": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Specification",
13
  "name": "model.mlmodel",
14
  "path": "com.apple.CoreML/model.mlmodel"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "EF542C51-0401-4C0A-9713-6B90D89D10D4"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "ABEB1845-9A9A-4CDB-AACD-335B4EEE0328": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Specification",
7
  "name": "model.mlmodel",
8
  "path": "com.apple.CoreML/model.mlmodel"
9
+ },
10
+ "F3186CBA-EC22-47C9-AAD1-AE8E1C7669C8": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Weights",
13
+ "name": "weights",
14
+ "path": "com.apple.CoreML/weights"
15
  }
16
  },
17
+ "rootModelIdentifier": "ABEB1845-9A9A-4CDB-AACD-335B4EEE0328"
18
  }
PlaprePico_int8.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dac4b9812c8cd6a53a43d2e1d5dc92daf112461026334e39ce104e7b4d20c488
3
- size 935132
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dd7913ab08a8436f8be1d66604de5af00494f3e7c6512590e002678417dc1ee
3
+ size 986129
PlaprePico_int8.mlpackage/Data/com.apple.CoreML/weights/weight.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fa4953349237cf80de350967e34ef51e1532eb0e1f4472c474047a0bc12bc0e8
3
- size 118766148
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f3f2ba6cb15545f208241e0c8e3bc148167ee252c6d25029ef0b36c795f48f9
3
+ size 118726916
PlaprePico_int8.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "8A809F7D-AC26-4562-8336-52FEB0717F2B": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Specification",
7
- "name": "model.mlmodel",
8
- "path": "com.apple.CoreML/model.mlmodel"
9
- },
10
- "B96C4545-8329-431E-A1F1-6DB58F13ACA9": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "8A809F7D-AC26-4562-8336-52FEB0717F2B"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "0758A291-73CC-483A-896D-F8A0679A8DDB": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
+ },
10
+ "A80CC513-9349-459D-9D4B-327E783D5596": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Specification",
13
+ "name": "model.mlmodel",
14
+ "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
+ "rootModelIdentifier": "A80CC513-9349-459D-9D4B-327E783D5596"
18
  }
Vocoder.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:88536c7f82ce5963c40ab46ab192452ddd1af731ecd4e08a40ea827fc544fbb6
3
- size 1298694
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:982ea75647bdbbcec542f1b19b739ac2620e115e75a27a05e5c6eb9794422c41
3
+ size 1298631
Vocoder.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "1865D6B1-DF08-4C5C-8B25-53058EF04D75": {
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Specification",
7
  "name": "model.mlmodel",
8
  "path": "com.apple.CoreML/model.mlmodel"
9
  },
10
- "6D12622B-E675-4537-9163-574EA27CA0C1": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
15
  }
16
  },
17
- "rootModelIdentifier": "1865D6B1-DF08-4C5C-8B25-53058EF04D75"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "62101658-5B4C-4B17-B935-8DB8A3E815C9": {
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Specification",
7
  "name": "model.mlmodel",
8
  "path": "com.apple.CoreML/model.mlmodel"
9
  },
10
+ "DB7344E2-6983-4F26-9909-B8957A9147DE": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
15
  }
16
  },
17
+ "rootModelIdentifier": "62101658-5B4C-4B17-B935-8DB8A3E815C9"
18
  }
manifest.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "model": "plapre-pico",
3
  "version": "1.0",
4
- "context_length": 2048,
5
  "prefill_length": 512,
6
  "vocab_size": 20802,
7
  "num_layers": 30,
 
1
  {
2
  "model": "plapre-pico",
3
  "version": "1.0",
4
+ "context_length": 512,
5
  "prefill_length": 512,
6
  "vocab_size": 20802,
7
  "num_layers": 30,
scripts/build.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unified build entry point for the Plapre Pico CoreML pipeline.
4
+
5
+ Builds all three CoreML models (Plapre Pico LLM, Kanade decoder, HiFT vocoder)
6
+ and optionally produces quantized variants of the LLM.
7
+
8
+ Usage:
9
+ python scripts/build.py # build all 3 models
10
+ python scripts/build.py --quantize int4 # + PlaprePico_int4.mlpackage
11
+ python scripts/build.py --quantize int4 --quantize int8 # both quantizations
12
+ python scripts/build.py --skip llm # only rebuild audio models
13
+ python scripts/build.py --skip audio # only rebuild LLM
14
+ python scripts/build.py --output-dir PATH # default: repo root
15
+ python scripts/build.py --num-tokens 100 # passed through to audio
16
+ """
17
+
18
+ import argparse
19
+ from pathlib import Path
20
+
21
+ from convert_llm import convert_llm
22
+ from convert_audio import convert_audio
23
+ from quantize import quantize_model
24
+
25
+
26
+ REPO_ROOT = Path(__file__).parent.parent
27
+
28
+
29
+ def _dir_size_mb(path: Path) -> float:
30
+ if not path.exists():
31
+ return 0.0
32
+ return sum(f.stat().st_size for f in path.rglob("*") if f.is_file()) / 1e6
33
+
34
+
35
+ def main():
36
+ parser = argparse.ArgumentParser(description="Build full Plapre Pico CoreML pipeline")
37
+ parser.add_argument("--output-dir", type=str, default=str(REPO_ROOT))
38
+ parser.add_argument("--model-dir", type=str, default=None,
39
+ help="Local Plapre Pico HF snapshot (otherwise downloaded)")
40
+ parser.add_argument("--num-tokens", type=int, default=100,
41
+ help="Audio token count for vocoder mel length")
42
+ parser.add_argument("--quantize", action="append", choices=["int4", "int8"], default=[],
43
+ help="Produce quantized LLM variant(s); may be repeated")
44
+ parser.add_argument("--skip", action="append", choices=["llm", "audio"], default=[],
45
+ help="Skip a stage")
46
+ args = parser.parse_args()
47
+
48
+ output_dir = Path(args.output_dir)
49
+ output_dir.mkdir(parents=True, exist_ok=True)
50
+ artifacts: list[Path] = []
51
+
52
+ if "llm" not in args.skip:
53
+ print("\n========== LLM ==========")
54
+ llm_path = convert_llm(
55
+ output_dir=output_dir,
56
+ model_dir=Path(args.model_dir) if args.model_dir else None,
57
+ )
58
+ artifacts.append(llm_path)
59
+ for filename in ["rope_cos.npy", "rope_sin.npy", "manifest.json",
60
+ "tokenizer.json", "speakers.json"]:
61
+ p = output_dir / filename
62
+ if p.exists():
63
+ artifacts.append(p)
64
+ else:
65
+ llm_path = output_dir / "PlaprePico.mlpackage"
66
+
67
+ for q in args.quantize:
68
+ print(f"\n========== Quantize {q} ==========")
69
+ if not llm_path.exists():
70
+ print(f" SKIP: {llm_path} not found (run without --skip llm first)")
71
+ continue
72
+ bits = int(q[3:])
73
+ out = output_dir / f"PlaprePico_{q}.mlpackage"
74
+ quantize_model(llm_path, out, bits)
75
+ artifacts.append(out)
76
+
77
+ if "audio" not in args.skip:
78
+ print("\n========== Audio (Kanade + Vocoder) ==========")
79
+ kanade_path, vocoder_path = convert_audio(output_dir, args.num_tokens)
80
+ artifacts.extend([kanade_path, vocoder_path])
81
+
82
+ print("\n========== Build summary ==========")
83
+ for p in artifacts:
84
+ size = _dir_size_mb(p) if p.is_dir() else (p.stat().st_size / 1e6 if p.exists() else 0)
85
+ print(f" {p.name:40s} {size:8.1f} MB")
86
+ print(f"\nOutput directory: {output_dir}")
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
scripts/{convert_kanade.py → convert_audio.py} RENAMED
@@ -682,7 +682,13 @@ def main():
682
  help="Fixed number of audio tokens (determines mel length)",
683
  )
684
  args = parser.parse_args()
685
- output_dir = Path(args.output_dir)
 
 
 
 
 
 
686
  output_dir.mkdir(parents=True, exist_ok=True)
687
 
688
  print("Loading Kanade model...")
@@ -690,20 +696,20 @@ def main():
690
  patch_kanade_for_coreml(kanade)
691
  vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()
692
 
693
- # Compute mel_length for this token count
694
  mel_length = kanade._calculate_target_mel_length(
695
- kanade._calculate_original_audio_length(args.num_tokens)
696
  )
697
 
698
  print(f"\n=== Converting Kanade decoder ===")
699
- convert_kanade_decoder(kanade, args.num_tokens, output_dir)
700
 
701
  print(f"\n=== Converting full vocoder (mel → waveform) ===")
702
  convert_full_vocoder(vocoder, mel_length, output_dir)
703
 
704
- print("\nDone!")
705
- print(f" KanadeDecoder: {args.num_tokens} tokens → mel (80, {mel_length})")
706
  print(f" Vocoder: mel (80, {mel_length}) → waveform")
 
707
 
708
 
709
  def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
 
682
  help="Fixed number of audio tokens (determines mel length)",
683
  )
684
  args = parser.parse_args()
685
+ convert_audio(Path(args.output_dir), args.num_tokens)
686
+
687
+
688
+ def convert_audio(output_dir: Path, num_tokens: int = 100) -> tuple[Path, Path]:
689
+ """Convert Kanade decoder + HiFT vocoder to CoreML.
690
+
691
+ Returns (KanadeDecoder.mlpackage, Vocoder.mlpackage) paths."""
692
  output_dir.mkdir(parents=True, exist_ok=True)
693
 
694
  print("Loading Kanade model...")
 
696
  patch_kanade_for_coreml(kanade)
697
  vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()
698
 
 
699
  mel_length = kanade._calculate_target_mel_length(
700
+ kanade._calculate_original_audio_length(num_tokens)
701
  )
702
 
703
  print(f"\n=== Converting Kanade decoder ===")
704
+ convert_kanade_decoder(kanade, num_tokens, output_dir)
705
 
706
  print(f"\n=== Converting full vocoder (mel → waveform) ===")
707
  convert_full_vocoder(vocoder, mel_length, output_dir)
708
 
709
+ print("\nAudio conversion complete!")
710
+ print(f" KanadeDecoder: {num_tokens} tokens → mel (80, {mel_length})")
711
  print(f" Vocoder: mel (80, {mel_length}) → waveform")
712
+ return output_dir / "KanadeDecoder.mlpackage", output_dir / "Vocoder.mlpackage"
713
 
714
 
715
  def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
scripts/{convert.py → convert_llm.py} RENAMED
@@ -17,6 +17,7 @@ from pathlib import Path
17
  import numpy as np
18
  import torch
19
  import coremltools as ct
 
20
  from huggingface_hub import snapshot_download
21
  from safetensors.torch import load_file
22
 
@@ -166,7 +167,7 @@ def convert_decode(model: PlaprePico, output_dir: Path):
166
  sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
167
 
168
  update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
169
- update_mask[0, 0, PREFILL_SEQ_LEN, 0] = 1.0
170
 
171
  speaker_embedding = torch.zeros(1, SPEAKER_DIM, dtype=torch.float16)
172
  is_speaker_step = torch.zeros(1, dtype=torch.float16)
@@ -211,9 +212,78 @@ def convert_decode(model: PlaprePico, output_dir: Path):
211
  minimum_deployment_target=ct.target.iOS18,
212
  )
213
 
 
 
214
  out_path = output_dir / "PlaprePico.mlpackage"
215
  mlmodel.save(str(out_path))
216
  print(f"Saved decode model to {out_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
 
219
  def copy_assets(model_dir: Path, output_dir: Path):
@@ -251,33 +321,37 @@ def copy_assets(model_dir: Path, output_dir: Path):
251
  print(f"Wrote manifest to {manifest_path}")
252
 
253
 
254
- def main():
255
- parser = argparse.ArgumentParser(description="Convert Plapre Pico to CoreML")
256
- parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
257
- parser.add_argument("--output-dir", type=str, default=str(Path(__file__).parent.parent), help="Output directory")
258
- args = parser.parse_args()
259
-
260
- if args.model_dir:
261
- model_dir = Path(args.model_dir)
262
- else:
263
  model_dir = download_model()
264
 
265
- output_dir = Path(args.output_dir)
266
  output_dir.mkdir(parents=True, exist_ok=True)
267
-
268
  weights = load_weights(model_dir)
269
 
270
  print("\n=== Building decode model ===")
271
  decode = PlaprePico()
272
  populate_weights(decode, weights)
273
  decode = decode.half()
274
- convert_decode(decode, output_dir)
275
 
276
  print("\n=== Copying assets ===")
277
  copy_assets(model_dir, output_dir)
278
 
279
- print("\nConversion complete!")
280
- print(f"Output: {output_dir}")
 
 
 
 
 
 
 
 
 
 
 
281
 
282
 
283
  if __name__ == "__main__":
 
17
  import numpy as np
18
  import torch
19
  import coremltools as ct
20
+ from coremltools.converters.mil.mil import Builder as mb
21
  from huggingface_hub import snapshot_download
22
  from safetensors.torch import load_file
23
 
 
167
  sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
168
 
169
  update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
170
+ update_mask[0, 0, 0, 0] = 1.0 # any valid position for tracing
171
 
172
  speaker_embedding = torch.zeros(1, SPEAKER_DIM, dtype=torch.float16)
173
  is_speaker_step = torch.zeros(1, dtype=torch.float16)
 
212
  minimum_deployment_target=ct.target.iOS18,
213
  )
214
 
215
+ inject_state_updates(mlmodel)
216
+
217
  out_path = output_dir / "PlaprePico.mlpackage"
218
  mlmodel.save(str(out_path))
219
  print(f"Saved decode model to {out_path}")
220
+ return out_path
221
+
222
+
223
+ def inject_state_updates(mlmodel):
224
+ """Inject coreml_update_state ops into a converted stateful CoreML model.
225
+
226
+ torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so coremltools
227
+ can't generate coreml_update_state ops automatically. This walks the MIL graph,
228
+ finds the read_state -> (cast?) -> mul -> add cache update pattern, and inserts
229
+ coreml_update_state ops before the first consumer of each cache update.
230
+ """
231
+ prog = mlmodel._mil_program
232
+ main_fn = prog.functions["main"]
233
+
234
+ read_ops = list(main_fn.find_ops(op_type="read_state"))
235
+ print(f"Found {len(read_ops)} read_state ops")
236
+
237
+ updates = []
238
+ for read_op in read_ops:
239
+ state_var = read_op.inputs["input"]
240
+ output = read_op.outputs[0]
241
+
242
+ # FLOAT32: read_state -> cast(fp16->fp32) -> mul -> add
243
+ # FLOAT16: read_state -> mul -> add
244
+ first_child = output.child_ops[0]
245
+ search_output = first_child.outputs[0] if first_child.op_type == "cast" else output
246
+
247
+ mul_op = next((c for c in search_output.child_ops if c.op_type == "mul"), None)
248
+ if mul_op is None:
249
+ print(f" WARNING: no mul found for {state_var.name}")
250
+ continue
251
+
252
+ add_op = next((c for c in mul_op.outputs[0].child_ops if c.op_type == "add"), None)
253
+ if add_op is None:
254
+ print(f" WARNING: no add found for {state_var.name}")
255
+ continue
256
+
257
+ updates.append((state_var, add_op))
258
+
259
+ print(f"Injecting {len(updates)} coreml_update_state ops...")
260
+
261
+ block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
262
+ with block:
263
+ for state_var, add_op in updates:
264
+ add_out = add_op.outputs[0]
265
+ consumers = list(add_out.child_ops)
266
+ if not consumers:
267
+ print(f" WARNING: no consumers for {state_var.name} add output")
268
+ continue
269
+ first_consumer = consumers[0]
270
+
271
+ with mb.set_before_op(before_op=first_consumer):
272
+ if str(add_out.dtype) == "fp16":
273
+ state_val = add_out
274
+ else:
275
+ state_val = mb.cast(
276
+ x=add_out, dtype="fp16",
277
+ name=f"state_cast_{state_var.name}",
278
+ )
279
+ mb.coreml_update_state(
280
+ state=state_var, value=state_val,
281
+ name=f"state_update_{state_var.name}",
282
+ )
283
+
284
+ prog_str = str(prog)
285
+ print(f" read_state: {prog_str.count('read_state')}")
286
+ print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")
287
 
288
 
289
  def copy_assets(model_dir: Path, output_dir: Path):
 
321
  print(f"Wrote manifest to {manifest_path}")
322
 
323
 
324
+ def convert_llm(output_dir: Path, model_dir: Path | None = None) -> Path:
325
+ """Convert Plapre Pico LLM end-to-end: download → load → trace → convert →
326
+ inject state updates copy assets. Returns path to PlaprePico.mlpackage."""
327
+ if model_dir is None:
 
 
 
 
 
328
  model_dir = download_model()
329
 
 
330
  output_dir.mkdir(parents=True, exist_ok=True)
 
331
  weights = load_weights(model_dir)
332
 
333
  print("\n=== Building decode model ===")
334
  decode = PlaprePico()
335
  populate_weights(decode, weights)
336
  decode = decode.half()
337
+ out_path = convert_decode(decode, output_dir)
338
 
339
  print("\n=== Copying assets ===")
340
  copy_assets(model_dir, output_dir)
341
 
342
+ print(f"\nLLM conversion complete: {out_path}")
343
+ return out_path
344
+
345
+
346
+ def main():
347
+ parser = argparse.ArgumentParser(description="Convert Plapre Pico LLM to CoreML")
348
+ parser.add_argument("--model-dir", type=str, help="Path to downloaded model directory")
349
+ parser.add_argument("--output-dir", type=str, default=str(Path(__file__).parent.parent), help="Output directory")
350
+ args = parser.parse_args()
351
+ convert_llm(
352
+ output_dir=Path(args.output_dir),
353
+ model_dir=Path(args.model_dir) if args.model_dir else None,
354
+ )
355
 
356
 
357
  if __name__ == "__main__":
scripts/inject_state_updates.py DELETED
@@ -1,172 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Post-process a traced CoreML model to inject coreml_update_state ops.
4
-
5
- torch.jit.trace doesn't emit prim::SetAttr for buffer mutations, so coremltools
6
- can't generate coreml_update_state ops automatically. This script:
7
- 1. Loads the converted model
8
- 2. Finds the read_state -> computation -> cache update pattern
9
- 3. Injects coreml_update_state after each cache update
10
- 4. Saves the fixed model
11
- """
12
-
13
- import sys
14
- import numpy as np
15
- import torch
16
- import coremltools as ct
17
- from coremltools.converters.mil.mil import Builder as mb
18
- from coremltools.converters.mil.mil import types
19
- from pathlib import Path
20
-
21
- sys.path.insert(0, str(Path(__file__).parent))
22
-
23
- from model_wrapper import (
24
- PlaprePico, HIDDEN_SIZE, HEAD_DIM, MAX_CONTEXT,
25
- NUM_KV_HEADS, NUM_LAYERS, VOCAB_SIZE, SPEAKER_DIM,
26
- )
27
- from convert import load_weights, _map_weight_name, build_kv_cache_states
28
- from huggingface_hub import snapshot_download
29
-
30
-
31
- def convert_and_fix_decode(output_path: Path):
32
- """Convert decode model and inject coreml_update_state ops."""
33
- model_dir = Path(snapshot_download("syvai/plapre-pico"))
34
- weights = load_weights(model_dir)
35
-
36
- model = PlaprePico()
37
- sd = model.state_dict()
38
- ns = {}
39
- for k, v in weights.items():
40
- n = _map_weight_name(k)
41
- if n and n in sd and sd[n].shape == v.shape:
42
- ns[n] = v
43
- model.load_state_dict(ns, strict=False)
44
- model = model.half().eval()
45
-
46
- inputs = (
47
- torch.zeros(1, 1, dtype=torch.int32),
48
- torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16),
49
- torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),
50
- torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),
51
- torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16),
52
- torch.zeros(1, SPEAKER_DIM, dtype=torch.float16),
53
- torch.zeros(1, dtype=torch.float16),
54
- )
55
-
56
- print("Tracing...")
57
- with torch.no_grad():
58
- traced = torch.jit.trace(model, inputs)
59
-
60
- print("Converting to CoreML...")
61
- mlmodel = ct.convert(
62
- traced,
63
- inputs=[
64
- ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
65
- ct.TensorType(name="causal_mask", shape=(1, 1, 1, MAX_CONTEXT), dtype=np.float16),
66
- ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
67
- ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
68
- ct.TensorType(name="update_mask", shape=(1, 1, MAX_CONTEXT, 1), dtype=np.float16),
69
- ct.TensorType(name="speaker_embedding", shape=(1, SPEAKER_DIM), dtype=np.float16),
70
- ct.TensorType(name="is_speaker_step", shape=(1,), dtype=np.float16),
71
- ],
72
- outputs=[ct.TensorType(name="logits", dtype=np.float16)],
73
- states=build_kv_cache_states(),
74
- compute_precision=ct.precision.FLOAT16,
75
- minimum_deployment_target=ct.target.iOS18,
76
- )
77
-
78
- prog = mlmodel._mil_program
79
- main_fn = prog.functions["main"]
80
-
81
- # Find all read_state ops and trace to their cache update (add) ops
82
- read_ops = list(main_fn.find_ops(op_type="read_state"))
83
- print(f"Found {len(read_ops)} read_state ops")
84
-
85
- updates = []
86
- for read_op in read_ops:
87
- state_var = read_op.inputs["input"]
88
- output = read_op.outputs[0]
89
-
90
- # Follow the graph from read_state to the cache update (add) op.
91
- # With FLOAT32 precision: read_state -> cast(fp16->fp32) -> mul -> add
92
- # With FLOAT16 precision: read_state -> mul -> add (no cast needed)
93
- first_child = output.child_ops[0]
94
- if first_child.op_type == "cast":
95
- search_output = first_child.outputs[0]
96
- else:
97
- search_output = output
98
-
99
- mul_op = None
100
- for child in search_output.child_ops:
101
- if child.op_type == "mul":
102
- mul_op = child
103
- break
104
-
105
- if mul_op is None:
106
- print(f" WARNING: no mul found for {state_var.name}")
107
- continue
108
-
109
- mul_out = mul_op.outputs[0]
110
-
111
- add_op = None
112
- for child in mul_out.child_ops:
113
- if child.op_type == "add":
114
- add_op = child
115
- break
116
-
117
- if add_op is None:
118
- print(f" WARNING: no add found for {state_var.name}")
119
- continue
120
-
121
- updates.append((state_var, add_op))
122
-
123
- print(f"Injecting {len(updates)} coreml_update_state ops...")
124
-
125
- # Get the block
126
- block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
127
- injected = 0
128
-
129
- with block:
130
- for state_var, add_op in updates:
131
- add_out = add_op.outputs[0]
132
-
133
- # Find the first consumer of add_out to insert before it
134
- consumers = list(add_out.child_ops)
135
- if not consumers:
136
- print(f" WARNING: no consumers for {state_var.name} add output")
137
- continue
138
-
139
- # Insert cast fp32->fp16 and coreml_update_state before the first consumer
140
- first_consumer = consumers[0]
141
-
142
- with mb.set_before_op(before_op=first_consumer):
143
- # Cast to fp16 if needed (fp32 precision produces fp32 add output)
144
- if str(add_out.dtype) == "fp16":
145
- state_val = add_out
146
- else:
147
- state_val = mb.cast(
148
- x=add_out, dtype="fp16",
149
- name=f"state_cast_{state_var.name}",
150
- )
151
- # Write updated cache back to state
152
- updated_val = mb.coreml_update_state(
153
- state=state_var, value=state_val,
154
- name=f"state_update_{state_var.name}",
155
- )
156
-
157
- injected += 1
158
-
159
- # Verify state injection
160
- prog_str = str(prog)
161
- print(f"After state injection:")
162
- print(f" read_state: {prog_str.count('read_state')}")
163
- print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")
164
-
165
- print(f"Saving to {output_path}...")
166
- mlmodel.save(str(output_path))
167
- print("Done!")
168
-
169
-
170
- if __name__ == "__main__":
171
- output = Path(__file__).parent.parent / "PlaprePico.mlpackage"
172
- convert_and_fix_decode(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/mixed_precision.py DELETED
@@ -1,184 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Convert PlaprePico with mixed precision: fp16 matmuls, fp32 RMSNorm+softmax.
4
-
5
- Uses coremltools' FP16ComputePrecision pass with op_selector to selectively
6
- downcast ops to fp16 while keeping numerically sensitive ops in fp32.
7
- """
8
-
9
- import sys
10
- import numpy as np
11
- import torch
12
- import coremltools as ct
13
- from coremltools.converters.mil.mil import Builder as mb
14
- from coremltools.converters.mil.mil.passes.defs.quantization import FP16ComputePrecision
15
- from pathlib import Path
16
-
17
- sys.path.insert(0, str(Path(__file__).parent))
18
-
19
- from model_wrapper import (
20
- PlaprePico, HIDDEN_SIZE, HEAD_DIM, MAX_CONTEXT,
21
- NUM_KV_HEADS, NUM_LAYERS, VOCAB_SIZE, SPEAKER_DIM,
22
- )
23
- from convert import load_weights, _map_weight_name, build_kv_cache_states
24
- from huggingface_hub import snapshot_download
25
-
26
-
27
- # Ops that MUST stay in fp32 (overflow-prone)
28
- FP32_OPS = {"reduce_mean", "softmax", "rsqrt", "reduce_sum", "pow"}
29
-
30
-
31
- def mixed_precision_selector(op):
32
- """Return True if this op should be cast to fp16, False to keep in fp32."""
33
- if op.op_type in FP32_OPS:
34
- return False
35
- return True
36
-
37
-
38
- def convert_mixed_precision(output_path: Path):
39
- """Convert with mixed precision: fp16 everywhere except RMSNorm+softmax."""
40
- model_dir = Path(snapshot_download("syvai/plapre-pico"))
41
- weights = load_weights(model_dir)
42
-
43
- model = PlaprePico()
44
- sd = model.state_dict()
45
- ns = {}
46
- for k, v in weights.items():
47
- n = _map_weight_name(k)
48
- if n and n in sd and sd[n].shape == v.shape:
49
- ns[n] = v
50
- model.load_state_dict(ns, strict=False)
51
- model = model.half().eval()
52
-
53
- inputs = (
54
- torch.zeros(1, 1, dtype=torch.int32),
55
- torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16),
56
- torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),
57
- torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16),
58
- torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16),
59
- torch.zeros(1, SPEAKER_DIM, dtype=torch.float16),
60
- torch.zeros(1, dtype=torch.float16),
61
- )
62
-
63
- print("Tracing...")
64
- with torch.no_grad():
65
- traced = torch.jit.trace(model, inputs)
66
-
67
- # Convert with NO automatic precision — we'll apply it manually
68
- print("Converting to CoreML (no precision pass)...")
69
- mlmodel = ct.convert(
70
- traced,
71
- inputs=[
72
- ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32),
73
- ct.TensorType(name="causal_mask", shape=(1, 1, 1, MAX_CONTEXT), dtype=np.float16),
74
- ct.TensorType(name="cos", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
75
- ct.TensorType(name="sin", shape=(1, 1, 1, HEAD_DIM), dtype=np.float16),
76
- ct.TensorType(name="update_mask", shape=(1, 1, MAX_CONTEXT, 1), dtype=np.float16),
77
- ct.TensorType(name="speaker_embedding", shape=(1, SPEAKER_DIM), dtype=np.float16),
78
- ct.TensorType(name="is_speaker_step", shape=(1,), dtype=np.float16),
79
- ],
80
- outputs=[ct.TensorType(name="logits", dtype=np.float16)],
81
- states=build_kv_cache_states(),
82
- compute_precision=ct.precision.FLOAT32,
83
- minimum_deployment_target=ct.target.iOS18,
84
- )
85
-
86
- prog = mlmodel._mil_program
87
- main_fn = prog.functions["main"]
88
-
89
- # === Step 1: Inject coreml_update_state ops FIRST (before fp16 pass changes graph) ===
90
- print("Injecting coreml_update_state ops...")
91
- read_ops = list(main_fn.find_ops(op_type="read_state"))
92
- print(f" Found {len(read_ops)} read_state ops")
93
-
94
- updates = []
95
- for read_op in read_ops:
96
- state_var = read_op.inputs["input"]
97
- output = read_op.outputs[0]
98
-
99
- first_child = output.child_ops[0]
100
- if first_child.op_type == "cast":
101
- search_output = first_child.outputs[0]
102
- else:
103
- search_output = output
104
-
105
- mul_op = None
106
- for child in search_output.child_ops:
107
- if child.op_type == "mul":
108
- mul_op = child
109
- break
110
-
111
- if mul_op is None:
112
- print(f" WARNING: no mul found for {state_var.name}")
113
- continue
114
-
115
- mul_out = mul_op.outputs[0]
116
-
117
- add_op = None
118
- for child in mul_out.child_ops:
119
- if child.op_type == "add":
120
- add_op = child
121
- break
122
-
123
- if add_op is None:
124
- print(f" WARNING: no add found for {state_var.name}")
125
- continue
126
-
127
- updates.append((state_var, add_op))
128
-
129
- print(f" Injecting {len(updates)} coreml_update_state ops...")
130
-
131
- block = main_fn.find_ops(op_type="read_state")[0].enclosing_block
132
-
133
- with block:
134
- for state_var, add_op in updates:
135
- add_out = add_op.outputs[0]
136
- consumers = list(add_out.child_ops)
137
- if not consumers:
138
- continue
139
-
140
- first_consumer = consumers[0]
141
- with mb.set_before_op(before_op=first_consumer):
142
- if str(add_out.dtype) == "fp16":
143
- state_val = add_out
144
- else:
145
- state_val = mb.cast(
146
- x=add_out, dtype="fp16",
147
- name=f"state_cast_{state_var.name}",
148
- )
149
- mb.coreml_update_state(
150
- state=state_var, value=state_val,
151
- name=f"state_update_{state_var.name}",
152
- )
153
-
154
- prog_str = str(prog)
155
- print(f" read_state: {prog_str.count('read_state')}")
156
- print(f" coreml_update_state: {prog_str.count('coreml_update_state')}")
157
-
158
- # === Step 2: Apply selective fp16 cast pass ===
159
- print("\nApplying mixed precision (fp16 matmuls, fp32 RMSNorm+softmax)...")
160
- mil_str = str(prog)
161
- print(f" Before: {mil_str.count('cast')} cast ops")
162
-
163
- fp16_pass = FP16ComputePrecision(op_selector=mixed_precision_selector)
164
- fp16_pass.apply(prog)
165
-
166
- mil_str = str(prog)
167
- print(f" After: {mil_str.count('cast')} cast ops")
168
-
169
- # Verify sensitive ops stayed fp32
170
- fp32_softmax = sum(1 for line in mil_str.split('\n') if 'softmax' in line and 'fp32' in line)
171
- fp32_reduce = sum(1 for line in mil_str.split('\n') if 'reduce_mean' in line and 'fp32' in line)
172
- fp16_linear = sum(1 for line in mil_str.split('\n') if 'linear' in line and 'fp16' in line)
173
- fp16_matmul = sum(1 for line in mil_str.split('\n') if 'matmul' in line and 'fp16' in line)
174
- print(f" fp32 softmax: {fp32_softmax}, fp32 reduce_mean: {fp32_reduce}")
175
- print(f" fp16 linear: {fp16_linear}, fp16 matmul: {fp16_matmul}")
176
-
177
- print(f"\nSaving to {output_path}...")
178
- mlmodel.save(str(output_path))
179
- print("Done!")
180
-
181
-
182
- if __name__ == "__main__":
183
- output = Path(__file__).parent.parent / "PlaprePico.mlpackage"
184
- convert_mixed_precision(output)