import Foundation
import CoreML

// MARK: - Decode Input Provider

/// Custom `MLFeatureProvider` that wraps a fixed dictionary of `MLMultiArray`s and
/// builds a fresh `MLFeatureValue` on every `featureValue(for:)` call.
///
/// Per the original author's note this is performance-critical: handing out a new
/// feature value each time is intended to keep CoreML from reusing stale array data.
class DecodeInputProvider: MLFeatureProvider {
    /// Names of the inputs this provider advertises to the model.
    let featureNames: Set<String> = [
        "input_ids",
        "causal_mask",
        "cos",
        "sin",
        "update_mask",
        "speaker_embedding",
        "is_speaker_step"
    ]

    /// Backing arrays, keyed by feature name.
    let arrays: [String: MLMultiArray]

    /// - Parameter arrays: The arrays to serve, keyed by feature name.
    init(_ arrays: [String: MLMultiArray]) {
        self.arrays = arrays
    }

    /// Returns a newly created feature value wrapping the array stored under `name`,
    /// or `nil` when no array was supplied for that name.
    func featureValue(for name: String) -> MLFeatureValue? {
        guard let array = arrays[name] else { return nil }
        return MLFeatureValue(multiArray: array)
    }
}

// MARK: - Sampling

/// Samples a token index from `logits` using top-K filtering and temperature scaling.
///
/// Performance-critical: reads the `Float16` buffer directly to avoid NSNumber overhead.
///
/// - Parameters:
///   - logits: Unnormalised scores, one per candidate token.
///   - temperature: Softmax temperature. A non-positive (or NaN) value selects the
///     highest-scoring candidate deterministically.
///   - topK: Number of highest-scoring candidates to sample from. Clamped to
///     `1...logits.count`, so an oversized or non-positive value cannot trap.
/// - Returns: The index of the chosen logit, or `0` when `logits` is empty.
func sampleFromLogits(_ logits: [Float16],
                      temperature: Float = PlapreConfig.defaultTemperature,
                      topK: Int = PlapreConfig.defaultTopK) -> Int32 {
    return logits.withUnsafeBufferPointer { ptr -> Int32 in
        // Never track more slots than there are logits, and always at least one.
        let k = min(max(topK, 1), ptr.count)
        guard k > 0 else { return 0 }  // empty input: nothing to sample from

        // Slots start at the lowest finite Float16 (-65504) so any larger logit displaces them.
        var topIndices = [Int](repeating: 0, count: k)
        var topValues = [Float16](repeating: -Float16.greatestFiniteMagnitude, count: k)
        var minSlot = 0  // slot currently holding the smallest kept value

        for i in 0..<ptr.count where ptr[i] > topValues[minSlot] {
            topValues[minSlot] = ptr[i]
            topIndices[minSlot] = i
            // Re-locate the smallest kept value; k is small, so a linear scan is fine.
            minSlot = 0
            for j in 1..<k where topValues[j] < topValues[minSlot] {
                minSlot = j
            }
        }

        // Arg-max over the kept candidates: used for greedy decoding and as the softmax shift.
        var bestSlot = 0
        for j in 1..<k where topValues[j] > topValues[bestSlot] {
            bestSlot = j
        }

        // NOTE(review): the exact greedy-decoding threshold was lost from the original text;
        // "temperature is not strictly positive" is assumed here — confirm against upstream.
        // A non-finite best logit would also poison the softmax, so it falls back to arg-max.
        guard temperature > 0, topValues[bestSlot].isFinite else {
            return Int32(topIndices[bestSlot])
        }

        // Temperature-scaled softmax weights (left unnormalised). Subtracting the maximum
        // before dividing keeps every exponent <= 0, so a tiny temperature cannot overflow.
        // NOTE(review): this section is reconstructed from partially lost source — confirm.
        let maxLogit = Float(topValues[bestSlot])
        var logits32 = [Float](repeating: 0, count: k)
        var total: Float = 0
        for j in 0..<k {
            let weight = exp((Float(topValues[j]) - maxLogit) / temperature)
            logits32[j] = weight
            total += weight
        }

        // `total` >= 1 because the best candidate contributes exp(0) == 1.
        let r = Float.random(in: 0..<total)
        var cumulative: Float = 0
        for j in 0..<k {
            cumulative += logits32[j]
            if r < cumulative {
                return Int32(topIndices[j])
            }
        }
        // Floating-point rounding can leave `r` just above the final sum; use the last slot.
        return Int32(topIndices[k - 1])
    }
}

// MARK: - Timing

/// Formats a duration in seconds as microseconds (below 1 ms), milliseconds (below 1 s),
/// or seconds otherwise.
func formatTime(_ seconds: Double) -> String {
    if seconds < 0.001 {
        return String(format: "%.2fµs", seconds * 1_000_000)
    }
    if seconds < 1.0 {
        return String(format: "%.1fms", seconds * 1000)
    }
    return String(format: "%.2fs", seconds)
}

/// Runs `block`, prints how long it took under `label`, and returns its result.
///
/// Nothing is printed when `block` throws; the error propagates to the caller.
func measure<T>(_ label: String, _ block: () throws -> T) rethrows -> T {
    let start = CFAbsoluteTimeGetCurrent()
    let result = try block()
    let elapsed = CFAbsoluteTimeGetCurrent() - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}

/// Async counterpart of `measure`: awaits `block`, prints how long it took under
/// `label`, and returns its result.
///
/// Nothing is printed when `block` throws; the error propagates to the caller.
func measureAsync<T>(_ label: String, _ block: () async throws -> T) async rethrows -> T {
    let start = CFAbsoluteTimeGetCurrent()
    let result = try await block()
    let elapsed = CFAbsoluteTimeGetCurrent() - start
    print(" ⏱ \(label): \(formatTime(elapsed))")
    return result
}