Daniel Rothmann

Plapre Pico CoreML — Trial Log

Chronological record of attempts, failures, and fixes made while porting Plapre Pico from PyTorch to CoreML and optimizing it for real-time iOS inference.


Phase 1: CoreML Conversion

Trial 1 — Custom attention + RoPE

Approach: Rewrite all non-traceable HF components: explicit matmul attention, split-half RoPE with precomputed cos/sin, GQA via repeat_interleave, one-hot broadcast mask for KV cache writes.

Result: Success. Max abs error 0.02 vs HF, argmax exact match.
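The rewritten components reduce to plain tensor algebra. A NumPy sketch of the split-half rotation and the GQA head expansion (dims and head counts here are illustrative, and this mirrors the traced torch ops rather than reproducing them):

```python
# NumPy sketch of two of the rewritten pieces; shapes are illustrative.
import numpy as np

def rope_split_half(x, cos, sin):
    # Split-half (Llama-style) RoPE: pair dim i with dim i + d/2 and rotate
    # each pair by a position-dependent angle, using precomputed cos/sin.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * np.concatenate([cos, cos], axis=-1) + rotated * np.concatenate([sin, sin], axis=-1)

def expand_kv_heads(kv, n_rep):
    # GQA: repeat each KV head n_rep times so every query head has a K/V
    # (the stand-in for torch.repeat_interleave on the head axis).
    return np.repeat(kv, n_rep, axis=0)  # kv: (n_kv_heads, seq, head_dim)

seq, dim = 4, 8
angles = np.arange(seq)[:, None] * (1.0 / 10000 ** (np.arange(0, dim, 2) / dim))
cos, sin = np.cos(angles), np.sin(angles)  # (seq, dim/2), precomputed once
q = np.random.randn(seq, dim)
q_rot = rope_split_half(q, cos, sin)       # rotation: per-row norms are preserved
kv = np.random.randn(3, seq, dim)          # 3 KV heads, expanded 4x below
assert expand_kv_heads(kv, 4).shape == (12, seq, dim)
```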

Trial 2 — Speaker conditioning

Approach: A blend-factor flag (is_speaker_step) replaces the token embedding with speaker_proj(emb) at position 0, avoiding the sequence-length change of prepending.

Result: Identical logits to HF prepend approach.
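In effect the flag is a linear blend, which keeps the graph traceable with a fixed sequence length. A minimal sketch (the speaker_proj output is shown as a precomputed vector; names are illustrative):

```python
# Minimal sketch of the blend: a 0/1 flag selects the projected speaker
# embedding at position 0 and the ordinary token embedding everywhere else.
import numpy as np

def embed_step(token_emb, projected_speaker_emb, is_speaker_step):
    # is_speaker_step is 1.0 only at position 0.
    return is_speaker_step * projected_speaker_emb + (1.0 - is_speaker_step) * token_emb

token_emb = np.ones(4)
speaker_emb = np.full(4, 7.0)
assert np.array_equal(embed_step(token_emb, speaker_emb, 1.0), speaker_emb)
assert np.array_equal(embed_step(token_emb, speaker_emb, 0.0), token_emb)
```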

Trial 3 — Softmax NaN on padded positions

Bug: softmax([-inf, -inf, ...]) = NaN on padded rows, and the NaNs propagate through the model.

Fix: nan_to_num(0.0) after softmax.
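The failure mode and fix, reproduced in NumPy (mirrors the torch behaviour):

```python
# A fully masked (padded) row scores -inf everywhere; softmax then computes
# exp(-inf - (-inf)) = nan, and the NaNs spread. Zeroing after softmax
# replaces the poisoned rows with harmless zeros.
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

pad_row = np.full(4, -np.inf)
assert np.isnan(softmax(pad_row)).all()           # the bug
fixed = np.nan_to_num(softmax(pad_row), nan=0.0)  # the fix: nan_to_num(0.0)
assert (fixed == 0.0).all()
```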

Trial 4 — fp16 compute precision: all-zero output

Approach: compute_precision=ct.precision.FLOAT16.

Result: Fails with 3+ layers. RMSNorm's x.pow(2) overflows fp16 (layer 4 activations reach ~263, and 263^2 = 69169 > 65504, the fp16 max).

Workaround: compute_precision=FLOAT32 — works but slower.
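The overflow is easy to reproduce in NumPy fp16:

```python
# fp16's largest finite value is 65504; the ~263 activations seen at layer 4
# overflow as soon as RMSNorm squares them, which is why fp32 compute works.
import numpy as np

x = np.float16(263.0)
assert np.isinf(x * x)                      # 263^2 = 69169 > 65504 -> inf in fp16
assert np.isfinite(np.float32(263.0) ** 2)  # plenty of headroom in fp32
```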

Trial 5 — KV cache state not persisting

Bug: torch.jit.trace doesn't emit prim::SetAttr → no coreml_update_state ops → cache resets every call.

Fix: MIL post-processing (inject_state_updates.py) to inject 60 coreml_update_state ops.

Result: Prefill argmax matches PyTorch at all positions.


Phase 2: Kanade + Vocoder

Trial 6 — Kanade RoPE (split-half, wrong)

Bug: Kanade uses interleaved RoPE (complex multiply), not split-half like Llama.

Fix: Interleaved real-valued rotation. Mel diff: 0.014.
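The two layouts apply the same rotations but pair different dimensions: split-half pairs (i, i + d/2), while interleaved treats (2i, 2i+1) as the real/imaginary parts of a complex multiply. A NumPy sketch of the real-valued interleaved form:

```python
# Real-valued expansion of (x_even + i*x_odd) * (cos + i*sin), i.e. the
# interleaved (complex-multiply) RoPE that Kanade expects.
import numpy as np

def rope_interleaved(x, cos, sin):
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_odd * cos + x_even * sin
    return out

dim = 8
theta = np.random.rand(dim // 2)
x = np.random.randn(dim)
y = rope_interleaved(x, np.cos(theta), np.sin(theta))
# A rotation preserves the norm of each (even, odd) pair:
assert np.allclose(np.linalg.norm(y), np.linalg.norm(x))
```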

Trial 7 — Kanade missing local attention

Bug: Full attention instead of windowed (window_per_side=32). Mel diff: 5.58.

Fix: Banded attention mask. Mel diff: 0.45.
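A sketch of the banded additive mask, assuming window_per_side bounds |i - j| (the exact boundary convention in Kanade is an assumption here):

```python
# Local attention: position i may only attend to positions within
# window_per_side of it; everything else gets -inf before softmax.
import numpy as np

def banded_mask(seq_len, window_per_side):
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return np.where(np.abs(i - j) <= window_per_side, 0.0, -np.inf)

m = banded_mask(100, 32)
assert m[0, 32] == 0.0        # inside the window
assert np.isneginf(m[0, 33])  # outside: masked out
```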

Trial 8 — Vocoder iSTFT

Problem: torch.istft is unsupported by the converter; fold is too slow for 24k frames.

Fix: Matmul with precomputed iDFT basis + static slice-add overlap. Max diff: 0.003 vs torch.istft.
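A hedged NumPy sketch of the trick (illustrative n_fft/hop; it omits the synthesis-window and overlap normalisation a complete iSTFT also applies): each frame's inverse FFT becomes one matmul per precomputed real iDFT basis, and overlap-add becomes static slice-adds.

```python
import numpy as np

n_fft, hop = 16, 8
freqs = n_fft // 2 + 1
k = np.arange(freqs)[:, None]
n = np.arange(n_fft)[None, :]
scale = np.ones((freqs, 1))
scale[1:-1] = 2.0  # non-DC/Nyquist bins of a one-sided spectrum count twice
cos_basis = scale * np.cos(2 * np.pi * k * n / n_fft) / n_fft  # (freqs, n_fft)
sin_basis = scale * np.sin(2 * np.pi * k * n / n_fft) / n_fft

def istft_matmul(spec_re, spec_im):
    # spec_re/spec_im: (frames, freqs). One matmul per basis replaces the iFFT.
    frames = spec_re @ cos_basis - spec_im @ sin_basis  # (frames, n_fft)
    out = np.zeros((len(frames) - 1) * hop + n_fft)
    for t in range(len(frames)):  # static slice-add overlap (unrolled when traced)
        out[t * hop : t * hop + n_fft] += frames[t]
    return out

audio = istft_matmul(np.random.randn(4, freqs), np.random.randn(4, freqs))
assert audio.shape == ((4 - 1) * hop + n_fft,)
```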


Phase 3: Swift Pipeline

Trial 9 — MLMultiArray stride corruption

Bug: CoreML pads strides (stride 208 for a dimension of 200), so naively reading out or recreating arrays corrupts the data.

Fix: Pass MLMultiArray directly between models.

Trial 10 — Token-by-token prefill

Problem: MLState is per-model — can't share between prefill/decode models.

Fix: Process all prompt tokens through the decode model one at a time.

Trial 11 — BPE tokenizer

Bug: Naive char-level tokenizer produces wrong tokens.

Fix: Hardcoded tokens (CLI), swift-transformers package (iOS app).


Phase 4: GPU / ANE Acceleration

Trial 12 — GPU (.cpuAndGPU)

Error -14 on Mac M1. ~2 tok/s on iPhone 15.

Trial 13 — ANE (.cpuAndNeuralEngine)

Falls back to CPU. coreml_update_state not supported on ANE.

Trial 14 — Head padding (3→4 KV heads)

Still error -14. Not dimension-related.

Trial 15 — Stateless model (explicit KV I/O)

Works on GPU but 2.1 tok/s — 45MB KV cache copy per step. Reverted.


Phase 5: Quantization

Trial 16 — int8 quantization

120MB (vs 258MB fp16). Nearly identical logits. No speed improvement.

Trial 17 — int4 quantization

61MB. Some logit drift. No speed improvement.


Phase 6: FP16 Precision Fix

Trial 18 — Clamp RMSNorm (x.clamp(-240, 240))

Valid logits but incoherent speech.

Trial 19 — Mixed precision (fp16 matmul + fp32 RMSNorm)

Correct logits but 18,671 cast ops inserted. 10 tok/s on iPhone — same as fp32 baseline.

Trial 20 — fp16-safe RMSNorm with pre-scaling

```python
scale = x.abs().amax(-1, keepdim=True).clamp(min=1.0)
x_scaled = x / scale
variance = x_scaled.pow(2).mean(-1, keepdim=True)
x_norm = x_scaled * torch.rsqrt(variance + eps)
```

Scale cancels: (x/s) / sqrt(mean((x/s)^2)) = x / sqrt(mean(x^2)). Pure fp16, correct output, coherent speech.
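A NumPy check of the cancellation (eps set to 0 so the identity is exact; with the usual tiny eps the difference is negligible):

```python
import numpy as np

def rmsnorm_plain(x, eps=0.0):
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

def rmsnorm_prescaled(x, eps=0.0):
    # fp16-safe: divide by the per-row amax first so the squares can't overflow.
    scale = np.clip(np.abs(x).max(-1, keepdims=True), 1.0, None)
    x_scaled = x / scale
    return x_scaled / np.sqrt((x_scaled ** 2).mean(-1, keepdims=True) + eps)

x = np.random.randn(4, 64) * 263.0  # the magnitude that overflowed fp16
assert np.allclose(rmsnorm_prescaled(x), rmsnorm_plain(x))         # scale cancels
assert np.isfinite(rmsnorm_prescaled(x.astype(np.float16))).all()  # no fp16 overflow
```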


Phase 7: Performance Optimization

Trial 21 — Context length 2048→512

TTS rarely exceeds 200 positions. Decode: 10→16 tok/s.

Trial 22 — Pre-allocated MLMultiArrays + custom MLFeatureProvider

Allocate once, update in place. MLDictionaryFeatureProvider reads stale data — custom provider required. Decode: 16→19 tok/s.

Trial 23 — Direct fp16 top-K sampling

withUnsafeBufferPointer(ofType: Float16.self) instead of 20,802 NSNumber subscripts. Find top-K in fp16, convert only K values to Float32 for softmax. Decode: 19→50 tok/s.


Phase 8: macOS NaN Investigation

Model started producing non-deterministic all-NaN logits on Mac M1 (iPhone unaffected). Tried many things — coreml_update_state injection, fp16 vs fp32 compute, MLState initialisation, layer count, attention vs MLP isolation, bare matmul repros, coremltools 8.3 vs 9.0, every compute unit. All red herrings.

Real cause: freshly-allocated MLMultiArray input buffers that haven't been fully written via withUnsafeMutableBufferPointer cause CoreML prediction to read garbage on macOS. The CLI worked all along because its setup happens to fill every buffer; standalone tests failed because they passed untouched arrays straight to prediction.

Note: after try MLMultiArray(shape:..., dataType:...), immediately fill every element, even with zeros. Apple guarantees zero-init for MLState buffers but not for input MLMultiArray on macOS.


Results (iPhone 15 / A16, CPU Only)

| Trial | Change | Prefill | Decode | RTF |
|---|---|---|---|---|
| 4 | fp32, ctx 2048, naive Swift | 14 tok/s | 10 tok/s | 2.5x |
| 20 | + fp16-safe RMSNorm | 20 tok/s | 12 tok/s | 2.1x |
| 21 | + context 512 | 36 tok/s | 16 tok/s | ~1.6x |
| 22 | + pre-alloc + custom provider | 60 tok/s | 19 tok/s | ~1.3x |
| 23 | + direct fp16 sampling | 60 tok/s | 50 tok/s | ~0.5x |

~2x realtime on iPhone 15 CPU.

Bugs Found

| Bug | Symptom | Fix |
|---|---|---|
| Softmax on all -inf rows | NaN propagation | nan_to_num(0.0) |
| fp16 overflow in RMSNorm x.pow(2) | All-zero output | Pre-scale by amax |
| torch.jit.trace missing prim::SetAttr | KV cache resets every step | MIL post-processing |
| coreml_update_state on GPU/ANE | Error -14 / fallback | .cpuOnly |
| Kanade interleaved vs split-half RoPE | Wrong mel | Interleaved rotation |
| Missing local attention in Kanade | Mel diff 5.58 | Banded mask |
| MLMultiArray padded strides | Audio artifacts | Pass arrays directly |
| MLDictionaryFeatureProvider stale data | Wrong logits | Custom MLFeatureProvider |
| NSNumber logits read | 19 tok/s ceiling | Direct fp16 pointer + top-K |
| Uninitialised MLMultiArray inputs (macOS) | Non-deterministic NaN logits | Fill every input buffer at allocation, even with zeros |