Daniel Rothmann committed on
Commit ·
10bad22
1
Parent(s): 26b347d
Workaround for bad MLState on macOS
Browse files
EXPERIMENTS.md
DELETED
|
File without changes
|
PlaprePico.mlpackage/Data/com.apple.CoreML/model.mlmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 957824
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b0ea4fbe5939f8db381da0ccadf9e90b61c82f5f0eca58b46e89b3a5541a49f0
|
| 3 |
size 957824
|
PlaprePico.mlpackage/Manifest.json
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
{
|
| 2 |
"fileFormatVersion": "1.0.0",
|
| 3 |
"itemInfoEntries": {
|
| 4 |
-
"
|
| 5 |
-
"author": "com.apple.CoreML",
|
| 6 |
-
"description": "CoreML Model Specification",
|
| 7 |
-
"name": "model.mlmodel",
|
| 8 |
-
"path": "com.apple.CoreML/model.mlmodel"
|
| 9 |
-
},
|
| 10 |
-
"D85A4A62-BBF3-4C94-848E-7E37C3571EA6": {
|
| 11 |
"author": "com.apple.CoreML",
|
| 12 |
"description": "CoreML Model Weights",
|
| 13 |
"name": "weights",
|
| 14 |
"path": "com.apple.CoreML/weights"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
}
|
| 16 |
},
|
| 17 |
-
"rootModelIdentifier": "
|
| 18 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"fileFormatVersion": "1.0.0",
|
| 3 |
"itemInfoEntries": {
|
| 4 |
+
"1F911078-42FE-4F91-A2D0-E5B86F87F7AD": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"author": "com.apple.CoreML",
|
| 6 |
"description": "CoreML Model Weights",
|
| 7 |
"name": "weights",
|
| 8 |
"path": "com.apple.CoreML/weights"
|
| 9 |
+
},
|
| 10 |
+
"3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD": {
|
| 11 |
+
"author": "com.apple.CoreML",
|
| 12 |
+
"description": "CoreML Model Specification",
|
| 13 |
+
"name": "model.mlmodel",
|
| 14 |
+
"path": "com.apple.CoreML/model.mlmodel"
|
| 15 |
}
|
| 16 |
},
|
| 17 |
+
"rootModelIdentifier": "3E69D1BF-E09D-43D9-A7FE-E3B15CDDF0BD"
|
| 18 |
}
|
scripts/convert.py
CHANGED
|
@@ -207,7 +207,7 @@ def convert_decode(model: PlaprePico, output_dir: Path):
|
|
| 207 |
],
|
| 208 |
outputs=[ct.TensorType(name="logits", dtype=np.float16)],
|
| 209 |
states=build_kv_cache_states(),
|
| 210 |
-
compute_precision=ct.precision.
|
| 211 |
minimum_deployment_target=ct.target.iOS18,
|
| 212 |
)
|
| 213 |
|
|
|
|
| 207 |
],
|
| 208 |
outputs=[ct.TensorType(name="logits", dtype=np.float16)],
|
| 209 |
states=build_kv_cache_states(),
|
| 210 |
+
compute_precision=ct.precision.FLOAT16,
|
| 211 |
minimum_deployment_target=ct.target.iOS18,
|
| 212 |
)
|
| 213 |
|
swift-cli/Sources/main.swift
CHANGED
|
@@ -378,56 +378,9 @@ func sampleFromLogitsFp16(_ ptr: UnsafeBufferPointer<Float16>, temperature: Floa
|
|
| 378 |
return Int32(topIndices[topK - 1])
|
| 379 |
}
|
| 380 |
|
| 381 |
-
func
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
var minIdx = 0
|
| 385 |
-
for i in 0..<vocabSize {
|
| 386 |
-
if ptr[i] > topValues[minIdx] {
|
| 387 |
-
topValues[minIdx] = ptr[i]
|
| 388 |
-
topIndices[minIdx] = i
|
| 389 |
-
minIdx = 0
|
| 390 |
-
for j in 1..<topK { if topValues[j] < topValues[minIdx] { minIdx = j } }
|
| 391 |
-
}
|
| 392 |
-
}
|
| 393 |
-
if temperature <= 0 {
|
| 394 |
-
var bestIdx = 0
|
| 395 |
-
for j in 1..<topK { if topValues[j] > topValues[bestIdx] { bestIdx = j } }
|
| 396 |
-
return Int32(topIndices[bestIdx])
|
| 397 |
-
}
|
| 398 |
-
var logits32 = [Float](repeating: 0, count: topK)
|
| 399 |
-
for j in 0..<topK { logits32[j] = topValues[j] / temperature }
|
| 400 |
-
let maxVal = logits32.max()!
|
| 401 |
-
var exps = logits32.map { exp($0 - maxVal) }
|
| 402 |
-
let sum = exps.reduce(0, +)
|
| 403 |
-
for j in 0..<topK { exps[j] /= sum }
|
| 404 |
-
let r = Float.random(in: 0..<1)
|
| 405 |
-
var cumsum: Float = 0
|
| 406 |
-
for j in 0..<topK {
|
| 407 |
-
cumsum += exps[j]
|
| 408 |
-
if cumsum >= r { return Int32(topIndices[j]) }
|
| 409 |
-
}
|
| 410 |
-
return Int32(topIndices[topK - 1])
|
| 411 |
-
}
|
| 412 |
-
|
| 413 |
-
func sampleFromLogits(_ logitsArr: MLMultiArray, temperature: Float = 0.8, topK: Int = 50) -> Int32 {
|
| 414 |
-
// CoreML may report .float16 dataType but use float32 backing with FLOAT32 compute precision.
|
| 415 |
-
// Try fp16 first; if values are NaN (fp32 data read as fp16), fall back to fp32.
|
| 416 |
-
var isFp16 = true
|
| 417 |
-
logitsArr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
|
| 418 |
-
if ptr[0].isNaN && ptr[1].isNaN { isFp16 = false }
|
| 419 |
-
}
|
| 420 |
-
if isFp16 {
|
| 421 |
-
return logitsArr.withUnsafeBufferPointer(ofType: Float16.self) { ptr -> Int32 in
|
| 422 |
-
return sampleFromLogitsFp16(ptr, temperature: temperature, topK: topK)
|
| 423 |
-
}
|
| 424 |
-
} else {
|
| 425 |
-
// fp32 backing behind fp16-declared output — use dataPointer directly
|
| 426 |
-
let rawPtr = UnsafeBufferPointer(
|
| 427 |
-
start: logitsArr.dataPointer.assumingMemoryBound(to: Float.self),
|
| 428 |
-
count: vocabSize
|
| 429 |
-
)
|
| 430 |
-
return sampleFromLogitsFp32(rawPtr, temperature: temperature, topK: topK)
|
| 431 |
}
|
| 432 |
}
|
| 433 |
|
|
@@ -510,8 +463,11 @@ let decodeModel = try measure("Compile PlaprePico") { try compileModel(at: model
|
|
| 510 |
|
| 511 |
// === Step 1: Prefill via decode model (one token at a time) ===
|
| 512 |
print("\n--- Prefill (token-by-token through decode model) ---")
|
| 513 |
-
|
| 514 |
-
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
// Pre-allocate all input arrays ONCE
|
| 517 |
let pInputIds = try! MLMultiArray(shape: [1, 1], dataType: .int32)
|
|
@@ -575,7 +531,11 @@ func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
|
|
| 575 |
}
|
| 576 |
|
| 577 |
let output = try decodeModel.prediction(from: inputProvider, using: state)
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
}
|
| 580 |
|
| 581 |
// The input sequence is: [placeholder(speaker), <text>, tokens..., <audio>]
|
|
@@ -589,15 +549,33 @@ func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
|
|
| 589 |
let inputTokens: [Int32] = Array(inputSeq.prefix(inputLen))
|
| 590 |
print("Processing \(inputTokens.count) input tokens...")
|
| 591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
let prefillStart = CFAbsoluteTimeGetCurrent()
|
| 593 |
-
|
| 594 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
}
|
| 596 |
let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
|
| 597 |
let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
|
| 598 |
print(" ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)")
|
| 599 |
|
| 600 |
-
let firstToken = sampleFromLogits(
|
| 601 |
print("First generated token: \(firstToken)")
|
| 602 |
|
| 603 |
// === Step 2: Autoregressive decode ===
|
|
@@ -614,7 +592,7 @@ for step in 1..<maxTokens {
|
|
| 614 |
let pos = inputLen + step - 1
|
| 615 |
|
| 616 |
try runDecodeStep(token: nextToken, pos: pos)
|
| 617 |
-
nextToken = sampleFromLogits(
|
| 618 |
generatedTokens.append(nextToken)
|
| 619 |
|
| 620 |
if nextToken == eosToken {
|
|
|
|
| 378 |
return Int32(topIndices[topK - 1])
|
| 379 |
}
|
| 380 |
|
| 381 |
+
func sampleFromLogits(_ logits: [Float16], temperature: Float = 0.8, topK: Int = 50) -> Int32 {
|
| 382 |
+
return logits.withUnsafeBufferPointer { ptr -> Int32 in
|
| 383 |
+
return sampleFromLogitsFp16(ptr, temperature: temperature, topK: topK)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
}
|
| 385 |
}
|
| 386 |
|
|
|
|
| 463 |
|
| 464 |
// === Step 1: Prefill via decode model (one token at a time) ===
|
| 465 |
print("\n--- Prefill (token-by-token through decode model) ---")
|
| 466 |
+
|
| 467 |
+
// MLState may contain uninitialized memory (NaN in KV cache).
|
| 468 |
+
// Retry makeState + first prefill step until logits are valid.
|
| 469 |
+
var state = decodeModel.makeState()
|
| 470 |
+
var lastLogits = [Float16](repeating: 0, count: vocabSize)
|
| 471 |
|
| 472 |
// Pre-allocate all input arrays ONCE
|
| 473 |
let pInputIds = try! MLMultiArray(shape: [1, 1], dataType: .int32)
|
|
|
|
| 531 |
}
|
| 532 |
|
| 533 |
let output = try decodeModel.prediction(from: inputProvider, using: state)
|
| 534 |
+
let arr = output.featureValue(for: "logits")!.multiArrayValue!
|
| 535 |
+
// Copy the fp16 logits out of the prediction output into lastLogits so the caller can sample from them
|
| 536 |
+
arr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
|
| 537 |
+
for i in 0..<vocabSize { lastLogits[i] = ptr[i] }
|
| 538 |
+
}
|
| 539 |
}
|
| 540 |
|
| 541 |
// The input sequence is: [placeholder(speaker), <text>, tokens..., <audio>]
|
|
|
|
| 549 |
let inputTokens: [Int32] = Array(inputSeq.prefix(inputLen))
|
| 550 |
print("Processing \(inputTokens.count) input tokens...")
|
| 551 |
|
| 552 |
+
// Retry prefill if state has uninitialized NaN memory.
|
| 553 |
+
// MLState buffers are not guaranteed zero-initialized; NaN in KV cache
|
| 554 |
+
// propagates through Q@K^T and poisons softmax. Typically 1-3 attempts.
|
| 555 |
+
var prefillAttempt = 0
|
| 556 |
let prefillStart = CFAbsoluteTimeGetCurrent()
|
| 557 |
+
while prefillAttempt < 20 {
|
| 558 |
+
prefillAttempt += 1
|
| 559 |
+
state = decodeModel.makeState()
|
| 560 |
+
// Reset causal mask and update mask for fresh prefill
|
| 561 |
+
pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
|
| 562 |
+
for i in 0..<maxContext { ptr[i] = Float16(-65504.0) }
|
| 563 |
+
}
|
| 564 |
+
pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
|
| 565 |
+
for i in 0..<maxContext { ptr[i] = Float16(0.0) }
|
| 566 |
+
}
|
| 567 |
+
try runDecodeStep(token: inputTokens[0], pos: 0, isSpeaker: true)
|
| 568 |
+
if !lastLogits[0].isNaN { break }
|
| 569 |
+
}
|
| 570 |
+
print(" Clean state after \(prefillAttempt) attempt(s)")
|
| 571 |
+
for i in 1..<inputTokens.count {
|
| 572 |
+
try runDecodeStep(token: inputTokens[i], pos: i, isSpeaker: false)
|
| 573 |
}
|
| 574 |
let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
|
| 575 |
let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
|
| 576 |
print(" ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)")
|
| 577 |
|
| 578 |
+
let firstToken = sampleFromLogits(lastLogits, temperature: 0.8, topK: 50)
|
| 579 |
print("First generated token: \(firstToken)")
|
| 580 |
|
| 581 |
// === Step 2: Autoregressive decode ===
|
|
|
|
| 592 |
let pos = inputLen + step - 1
|
| 593 |
|
| 594 |
try runDecodeStep(token: nextToken, pos: pos)
|
| 595 |
+
nextToken = sampleFromLogits(lastLogits, temperature: 0.8, topK: 50)
|
| 596 |
generatedTokens.append(nextToken)
|
| 597 |
|
| 598 |
if nextToken == eosToken {
|