| import Foundation |
| import CoreML |
|
|
| |
|
|
| |
| |
/// `MLFeatureProvider` that serves pre-built `MLMultiArray` inputs to Core ML by name.
///
/// `featureNames` always advertises the seven names listed below, regardless of which
/// arrays were actually supplied. A name with no entry in `arrays` makes
/// `featureValue(for:)` return `nil`.
class DecodeInputProvider: MLFeatureProvider {
    /// Input names this provider advertises to Core ML.
    let featureNames: Set<String> = [
        "input_ids", "causal_mask", "cos", "sin",
        "update_mask", "speaker_embedding", "is_speaker_step",
    ]

    /// Backing storage: feature name -> array handed to the model.
    let arrays: [String: MLMultiArray]

    /// Stores `inputs` as-is; no copying and no validation against `featureNames`.
    init(_ inputs: [String: MLMultiArray]) {
        arrays = inputs
    }

    /// Returns the array stored under `featureName` wrapped in an `MLFeatureValue`,
    /// or `nil` when no array was supplied for that name.
    func featureValue(for featureName: String) -> MLFeatureValue? {
        arrays[featureName].map { MLFeatureValue(multiArray: $0) }
    }
}
|
|
| |
|
|
| |
| |
/// Picks a token id from `logits` using top-k filtering followed by temperature-scaled
/// softmax sampling.
///
/// Only the first `PlapreConfig.vocabSize` entries of `logits` are considered (fewer if
/// `logits` is shorter). A logit is kept only if it is strictly greater than `-65504`
/// (the most negative finite `Float16`), so NaN logits and logits at or below that
/// value are never selected.
///
/// - Parameters:
///   - logits: Raw scores; the score at index `i` belongs to token id `i`.
///   - temperature: Softmax temperature. Values `<= 0` pick greedily, returning the
///     highest-scoring kept candidate.
///   - topK: Number of highest-scoring candidates kept before sampling. It is clamped
///     to `1...candidateCount`, so non-positive values behave like greedy top-1
///     instead of trapping.
/// - Returns: The chosen token id. Returns `0` when `logits` is empty or when no logit
///   beats the sentinel, matching the value left in an unfilled top-k slot.
func sampleFromLogits(_ logits: [Float16], temperature: Float = PlapreConfig.defaultTemperature, topK: Int = PlapreConfig.defaultTopK) -> Int32 {
    // Never index past the end of `logits`. UnsafeBufferPointer subscripts are only
    // bounds-checked in debug builds, so a short array would otherwise be read out of bounds.
    let candidateCount = min(logits.count, PlapreConfig.vocabSize)
    guard candidateCount > 0 else { return 0 }

    // A non-positive topK would trap below (negative array count / invalid 1..<0 range).
    // More slots than candidates would only add dead sentinel entries.
    let k = min(max(topK, 1), candidateCount)

    return logits.withUnsafeBufferPointer { ptr -> Int32 in
        // Unordered top-k buffer. `minSlot` always points at the smallest kept value,
        // which is the one a better candidate replaces.
        let sentinel = Float16(-65504.0)
        var topIndices = [Int](repeating: 0, count: k)
        var topValues = [Float16](repeating: sentinel, count: k)
        var minSlot = 0

        for i in 0..<candidateCount where ptr[i] > topValues[minSlot] {
            topValues[minSlot] = ptr[i]
            topIndices[minSlot] = i
            // Re-locate the minimum: O(k) per replacement, which is cheap for small k.
            minSlot = 0
            for j in 1..<k where topValues[j] < topValues[minSlot] {
                minSlot = j
            }
        }

        // Greedy decoding: return the best kept candidate.
        if temperature <= 0 {
            var bestSlot = 0
            for j in 1..<k where topValues[j] > topValues[bestSlot] {
                bestSlot = j
            }
            return Int32(topIndices[bestSlot])
        }

        // Softmax over the kept logits, computed in Float and shifted by the maximum
        // for numerical stability.
        let scaled = topValues.map { Float($0) / temperature }
        let maxVal = scaled.max() ?? 0
        let weights = scaled.map { exp($0 - maxVal) }
        let total = weights.reduce(0, +)
        let probs = weights.map { $0 / total }

        // Inverse-CDF draw with r in [0, 1). The comparison is strict so that a
        // zero-probability slot is never chosen when r == 0.
        let r = Float.random(in: 0..<1)
        var cumulative: Float = 0
        for j in 0..<k {
            cumulative += probs[j]
            if r < cumulative { return Int32(topIndices[j]) }
        }
        // Floating-point rounding can leave `cumulative` just below `r`.
        return Int32(topIndices[k - 1])
    }
}
|
|
| |
|
|
/// Renders a duration given in seconds as a short human-readable string.
///
/// - Below 1 ms (including negative values): microseconds with two decimals, e.g. `"500.00µs"`.
/// - Below 1 s: milliseconds with one decimal, e.g. `"500.0ms"`.
/// - Otherwise (including NaN, which matches neither range): seconds with two decimals, e.g. `"2.50s"`.
func formatTime(_ seconds: Double) -> String {
    switch seconds {
    case ..<0.001:
        return String(format: "%.2fµs", seconds * 1_000_000)
    case ..<1.0:
        return String(format: "%.1fms", seconds * 1000)
    default:
        return String(format: "%.2fs", seconds)
    }
}
|
|
/// Runs `block` once, prints how long it took under `label`, and returns its result.
///
/// Elapsed time is taken from `ProcessInfo.processInfo.systemUptime` rather than
/// `CFAbsoluteTimeGetCurrent()`. The latter follows the wall clock and can jump
/// (including backwards) when the system time is adjusted, which corrupts durations.
/// If `block` throws, the error propagates and no timing line is printed.
///
/// - Parameters:
///   - label: Text shown in the timing line.
///   - block: The work to time.
/// - Returns: Whatever `block` returns.
func measure<T>(_ label: String, _ block: () throws -> T) rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}
|
|
/// Awaits `block` once, prints how long it took under `label`, and returns its result.
///
/// The measured span covers the whole `await`, including any time the task spends
/// suspended. Elapsed time is taken from `ProcessInfo.processInfo.systemUptime` rather
/// than `CFAbsoluteTimeGetCurrent()`. The latter follows the wall clock and can jump
/// (including backwards) when the system time is adjusted.
/// If `block` throws, the error propagates and no timing line is printed.
///
/// - Parameters:
///   - label: Text shown in the timing line.
///   - block: The async work to time.
/// - Returns: Whatever `block` returns.
func measureAsync<T>(_ label: String, _ block: () async throws -> T) async rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try await block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}
|
|