Daniel Rothmann committed on
Commit
10bad22
·
1 Parent(s): 26b347d

Workaround for bad ML state on macOS

Browse files
EXPERIMENTS.md DELETED
File without changes
PlaprePico.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a70122791826c020dc3a1ee6bfadef2a5ac74d14e5e060e9dfab76c57284130b
3
  size 957824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0ea4fbe5939f8db381da0ccadf9e90b61c82f5f0eca58b46e89b3a5541a49f0
3
  size 957824
PlaprePico.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "9B90755A-51A8-4710-8C16-AB5B86538A1A": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Specification",
7
- "name": "model.mlmodel",
8
- "path": "com.apple.CoreML/model.mlmodel"
9
- },
10
- "D85A4A62-BBF3-4C94-848E-7E37C3571EA6": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Weights",
13
  "name": "weights",
14
  "path": "com.apple.CoreML/weights"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "9B90755A-51A8-4710-8C16-AB5B86538A1A"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "1F911078-42FE-4F91-A2D0-E5B86F87F7AD": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Weights",
7
  "name": "weights",
8
  "path": "com.apple.CoreML/weights"
9
+ },
10
+ "3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Specification",
13
+ "name": "model.mlmodel",
14
+ "path": "com.apple.CoreML/model.mlmodel"
15
  }
16
  },
17
+ "rootModelIdentifier": "3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD"
18
  }
scripts/convert.py CHANGED
@@ -207,7 +207,7 @@ def convert_decode(model: PlaprePico, output_dir: Path):
207
  ],
208
  outputs=[ct.TensorType(name="logits", dtype=np.float16)],
209
  states=build_kv_cache_states(),
210
- compute_precision=ct.precision.FLOAT32,
211
  minimum_deployment_target=ct.target.iOS18,
212
  )
213
 
 
207
  ],
208
  outputs=[ct.TensorType(name="logits", dtype=np.float16)],
209
  states=build_kv_cache_states(),
210
+ compute_precision=ct.precision.FLOAT16,
211
  minimum_deployment_target=ct.target.iOS18,
212
  )
213
 
swift-cli/Sources/main.swift CHANGED
@@ -378,56 +378,9 @@ func sampleFromLogitsFp16(_ ptr: UnsafeBufferPointer<Float16>, temperature: Floa
378
  return Int32(topIndices[topK - 1])
379
  }
380
 
381
- func sampleFromLogitsFp32(_ ptr: UnsafeBufferPointer<Float>, temperature: Float, topK: Int) -> Int32 {
382
- var topIndices = [Int](repeating: 0, count: topK)
383
- var topValues = [Float](repeating: -.greatestFiniteMagnitude, count: topK)
384
- var minIdx = 0
385
- for i in 0..<vocabSize {
386
- if ptr[i] > topValues[minIdx] {
387
- topValues[minIdx] = ptr[i]
388
- topIndices[minIdx] = i
389
- minIdx = 0
390
- for j in 1..<topK { if topValues[j] < topValues[minIdx] { minIdx = j } }
391
- }
392
- }
393
- if temperature <= 0 {
394
- var bestIdx = 0
395
- for j in 1..<topK { if topValues[j] > topValues[bestIdx] { bestIdx = j } }
396
- return Int32(topIndices[bestIdx])
397
- }
398
- var logits32 = [Float](repeating: 0, count: topK)
399
- for j in 0..<topK { logits32[j] = topValues[j] / temperature }
400
- let maxVal = logits32.max()!
401
- var exps = logits32.map { exp($0 - maxVal) }
402
- let sum = exps.reduce(0, +)
403
- for j in 0..<topK { exps[j] /= sum }
404
- let r = Float.random(in: 0..<1)
405
- var cumsum: Float = 0
406
- for j in 0..<topK {
407
- cumsum += exps[j]
408
- if cumsum >= r { return Int32(topIndices[j]) }
409
- }
410
- return Int32(topIndices[topK - 1])
411
- }
412
-
413
- func sampleFromLogits(_ logitsArr: MLMultiArray, temperature: Float = 0.8, topK: Int = 50) -> Int32 {
414
- // CoreML may report .float16 dataType but use float32 backing with FLOAT32 compute precision.
415
- // Try fp16 first; if values are NaN (fp32 data read as fp16), fall back to fp32.
416
- var isFp16 = true
417
- logitsArr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
418
- if ptr[0].isNaN && ptr[1].isNaN { isFp16 = false }
419
- }
420
- if isFp16 {
421
- return logitsArr.withUnsafeBufferPointer(ofType: Float16.self) { ptr -> Int32 in
422
- return sampleFromLogitsFp16(ptr, temperature: temperature, topK: topK)
423
- }
424
- } else {
425
- // fp32 backing behind fp16-declared output — use dataPointer directly
426
- let rawPtr = UnsafeBufferPointer(
427
- start: logitsArr.dataPointer.assumingMemoryBound(to: Float.self),
428
- count: vocabSize
429
- )
430
- return sampleFromLogitsFp32(rawPtr, temperature: temperature, topK: topK)
431
  }
432
  }
433
 
@@ -510,8 +463,11 @@ let decodeModel = try measure("Compile PlaprePico") { try compileModel(at: model
510
 
511
  // === Step 1: Prefill via decode model (one token at a time) ===
512
  print("\n--- Prefill (token-by-token through decode model) ---")
513
- let state = decodeModel.makeState()
514
- var lastLogitsArr: MLMultiArray!
 
 
 
515
 
516
  // Pre-allocate all input arrays ONCE
517
  let pInputIds = try! MLMultiArray(shape: [1, 1], dataType: .int32)
@@ -575,7 +531,11 @@ func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
575
  }
576
 
577
  let output = try decodeModel.prediction(from: inputProvider, using: state)
578
- lastLogitsArr = output.featureValue(for: "logits")!.multiArrayValue!
 
 
 
 
579
  }
580
 
581
  // The input sequence is: [placeholder(speaker), <text>, tokens..., <audio>]
@@ -589,15 +549,33 @@ func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
589
  let inputTokens: [Int32] = Array(inputSeq.prefix(inputLen))
590
  print("Processing \(inputTokens.count) input tokens...")
591
 
 
 
 
 
592
  let prefillStart = CFAbsoluteTimeGetCurrent()
593
- for (i, token) in inputTokens.enumerated() {
594
- try runDecodeStep(token: token, pos: i, isSpeaker: i == 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  }
596
  let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
597
  let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
598
  print(" ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)")
599
 
600
- let firstToken = sampleFromLogits(lastLogitsArr, temperature: 0.8, topK: 50)
601
  print("First generated token: \(firstToken)")
602
 
603
  // === Step 2: Autoregressive decode ===
@@ -614,7 +592,7 @@ for step in 1..<maxTokens {
614
  let pos = inputLen + step - 1
615
 
616
  try runDecodeStep(token: nextToken, pos: pos)
617
- nextToken = sampleFromLogits(lastLogitsArr, temperature: 0.8, topK: 50)
618
  generatedTokens.append(nextToken)
619
 
620
  if nextToken == eosToken {
 
378
  return Int32(topIndices[topK - 1])
379
  }
380
 
381
+ func sampleFromLogits(_ logits: [Float16], temperature: Float = 0.8, topK: Int = 50) -> Int32 {
382
+ return logits.withUnsafeBufferPointer { ptr -> Int32 in
383
+ return sampleFromLogitsFp16(ptr, temperature: temperature, topK: topK)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  }
385
  }
386
 
 
463
 
464
  // === Step 1: Prefill via decode model (one token at a time) ===
465
  print("\n--- Prefill (token-by-token through decode model) ---")
466
+
467
+ // MLState may contain uninitialized memory (NaN in KV cache).
468
+ // Retry makeState + first prefill step until logits are valid.
469
+ var state = decodeModel.makeState()
470
+ var lastLogits = [Float16](repeating: 0, count: vocabSize)
471
 
472
  // Pre-allocate all input arrays ONCE
473
  let pInputIds = try! MLMultiArray(shape: [1, 1], dataType: .int32)
 
531
  }
532
 
533
  let output = try decodeModel.prediction(from: inputProvider, using: state)
534
+ let arr = output.featureValue(for: "logits")!.multiArrayValue!
535
+ // Log shape/stride info on first call to diagnose backing type
536
+ arr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
537
+ for i in 0..<vocabSize { lastLogits[i] = ptr[i] }
538
+ }
539
  }
540
 
541
  // The input sequence is: [placeholder(speaker), <text>, tokens..., <audio>]
 
549
  let inputTokens: [Int32] = Array(inputSeq.prefix(inputLen))
550
  print("Processing \(inputTokens.count) input tokens...")
551
 
552
+ // Retry prefill if state has uninitialized NaN memory.
553
+ // MLState buffers are not guaranteed zero-initialized; NaN in KV cache
554
+ // propagates through Q@K^T and poisons softmax. Typically 1-3 attempts.
555
+ var prefillAttempt = 0
556
  let prefillStart = CFAbsoluteTimeGetCurrent()
557
+ while prefillAttempt < 20 {
558
+ prefillAttempt += 1
559
+ state = decodeModel.makeState()
560
+ // Reset causal mask and update mask for fresh prefill
561
+ pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
562
+ for i in 0..<maxContext { ptr[i] = Float16(-65504.0) }
563
+ }
564
+ pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
565
+ for i in 0..<maxContext { ptr[i] = Float16(0.0) }
566
+ }
567
+ try runDecodeStep(token: inputTokens[0], pos: 0, isSpeaker: true)
568
+ if !lastLogits[0].isNaN { break }
569
+ }
570
+ print(" Clean state after \(prefillAttempt) attempt(s)")
571
+ for i in 1..<inputTokens.count {
572
+ try runDecodeStep(token: inputTokens[i], pos: i, isSpeaker: false)
573
  }
574
  let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
575
  let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
576
  print(" ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)")
577
 
578
+ let firstToken = sampleFromLogits(lastLogits, temperature: 0.8, topK: 50)
579
  print("First generated token: \(firstToken)")
580
 
581
  // === Step 2: Autoregressive decode ===
 
592
  let pos = inputLen + step - 1
593
 
594
  try runDecodeStep(token: nextToken, pos: pos)
595
+ nextToken = sampleFromLogits(lastLogits, temperature: 0.8, topK: 50)
596
  generatedTokens.append(nextToken)
597
 
598
  if nextToken == eosToken {