File size: 3,234 Bytes
1dfb01c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | import Foundation
import CoreML
// MARK: - Decode Input Provider
/// An `MLFeatureProvider` backed by a dictionary of named `MLMultiArray`s.
///
/// `featureValue(for:)` builds a brand-new `MLFeatureValue` on every call instead of
/// storing one. The original author marks this as performance-critical: it is meant
/// to keep CoreML from caching stale array data.
/// NOTE(review): confirm this against the CoreML version in use.
///
/// Names advertised in `featureNames` but missing from `arrays` yield `nil`.
class DecodeInputProvider: MLFeatureProvider {
    /// Every input name this provider advertises to CoreML.
    let featureNames: Set<String> = Set([
        "input_ids", "causal_mask", "cos", "sin",
        "update_mask", "speaker_embedding", "is_speaker_step",
    ])

    /// Backing storage, keyed by feature name.
    let arrays: [String: MLMultiArray]

    init(_ arrays: [String: MLMultiArray]) {
        self.arrays = arrays
    }

    /// Wraps the array stored under `name` in a new `MLFeatureValue`, or returns `nil`
    /// when no array was supplied for that name.
    func featureValue(for name: String) -> MLFeatureValue? {
        arrays[name].map { MLFeatureValue(multiArray: $0) }
    }
}
// MARK: - Sampling
/// Samples a token index from `logits` using top-K filtering and temperature scaling.
///
/// Only the first `min(logits.count, PlapreConfig.vocabSize)` entries are scanned.
/// A logits buffer longer than the vocabulary is therefore fine, and a shorter one
/// can no longer cause an out-of-bounds read.
///
/// Performance: the `Float16` buffer is read through an unsafe pointer, and only a
/// small fixed-size top-K table is kept. Conversion to `Float` happens for the K kept
/// values only, not for the whole vocabulary.
///
/// - Parameters:
///   - logits: Raw (unnormalized) scores, one per token index.
///   - temperature: Divisor applied to the kept logits before softmax. Values `<= 0`
///     deterministically pick the highest-scoring kept token (greedy decoding).
///   - topK: How many of the highest-scoring tokens to keep. Clamped to the number of
///     scanned logits. Must be positive.
/// - Returns: The selected token index as `Int32`.
/// - Precondition: `topK > 0` and at least one logit is scanned.
func sampleFromLogits(_ logits: [Float16], temperature: Float = PlapreConfig.defaultTemperature, topK: Int = PlapreConfig.defaultTopK) -> Int32 {
    precondition(topK > 0, "sampleFromLogits: topK must be positive")
    // Never read past the end of `logits`, even if it is shorter than the vocabulary.
    let scanCount = min(logits.count, PlapreConfig.vocabSize)
    precondition(scanCount > 0, "sampleFromLogits: no logits to sample from")
    // A table larger than the candidate set would only hold placeholder slots
    // (index 0, sentinel value) that could then be sampled.
    let k = min(topK, scanCount)

    return logits.withUnsafeBufferPointer { ptr -> Int32 in
        // Unsorted top-K table. `minIdx` always points at the smallest kept value,
        // so rejecting a candidate costs a single comparison.
        var topIndices = [Int](repeating: 0, count: k)
        var topValues = [Float16](repeating: -Float16.greatestFiniteMagnitude, count: k)
        var minIdx = 0
        for i in 0..<scanCount {
            if ptr[i] > topValues[minIdx] {
                topValues[minIdx] = ptr[i]
                topIndices[minIdx] = i
                // Re-locate the new minimum; O(k), but only when the table changes.
                minIdx = 0
                for j in 1..<k where topValues[j] < topValues[minIdx] {
                    minIdx = j
                }
            }
        }

        // Greedy decoding: return the best kept token (first one on ties).
        if temperature <= 0 {
            var bestIdx = 0
            for j in 1..<k where topValues[j] > topValues[bestIdx] {
                bestIdx = j
            }
            return Int32(topIndices[bestIdx])
        }

        // Softmax over the kept logits, computed in Float and shifted by the
        // maximum for numerical stability.
        let scaled = topValues.map { Float($0) / temperature }
        let maxVal = scaled.max() ?? 0 // k >= 1, so never nil
        var probs = scaled.map { exp($0 - maxVal) }
        let sum = probs.reduce(0, +)
        for j in 0..<k { probs[j] /= sum }

        // Inverse-CDF sampling over the kept tokens.
        let r = Float.random(in: 0..<1)
        var cumsum: Float = 0
        for j in 0..<k {
            cumsum += probs[j]
            if cumsum >= r { return Int32(topIndices[j]) }
        }
        // Float rounding can leave the total slightly below `r`; fall back to the last slot.
        return Int32(topIndices[k - 1])
    }
}
// MARK: - Timing
/// Renders a duration in seconds using the most readable unit.
///
/// - Under 1 ms: microseconds with two decimals (e.g. `"500.00µs"`).
/// - Under 1 s: milliseconds with one decimal (e.g. `"250.0ms"`).
/// - Otherwise: seconds with two decimals (e.g. `"2.50s"`).
func formatTime(_ seconds: Double) -> String {
    switch seconds {
    case ..<0.001:
        return String(format: "%.2fµs", seconds * 1_000_000)
    case ..<1.0:
        return String(format: "%.1fms", seconds * 1000)
    default:
        return String(format: "%.2fs", seconds)
    }
}
/// Runs `block`, prints its elapsed time as ` ⏱ <label>: <duration>`, and returns its result.
///
/// Timing uses `ProcessInfo.systemUptime`, a monotonic clock. A wall-clock
/// adjustment during the measurement therefore cannot skew or negate the reported
/// duration (`CFAbsoluteTimeGetCurrent` makes no monotonicity guarantee).
/// If `block` throws, the error propagates and nothing is printed.
///
/// - Parameters:
///   - label: Text printed before the elapsed time.
///   - block: The work to time.
/// - Returns: Whatever `block` returns.
/// - Throws: Rethrows any error thrown by `block`.
func measure<T>(_ label: String, _ block: () throws -> T) rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}
/// Async counterpart of `measure`: awaits `block`, prints its elapsed time as
/// ` ⏱ <label>: <duration>`, and returns its result.
///
/// Timing uses `ProcessInfo.systemUptime`, a monotonic clock, so wall-clock
/// adjustments during the await cannot skew the reported duration. The measured
/// span includes any time spent suspended while awaiting.
/// If `block` throws, the error propagates and nothing is printed.
///
/// - Parameters:
///   - label: Text printed before the elapsed time.
///   - block: The async work to time.
/// - Returns: Whatever `block` returns.
/// - Throws: Rethrows any error thrown by `block`.
func measureAsync<T>(_ label: String, _ block: () async throws -> T) async rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try await block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}
|