File size: 3,234 Bytes
1dfb01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import Foundation
import CoreML

// MARK: - Decode Input Provider

/// `MLFeatureProvider` backed by a fixed dictionary of named `MLMultiArray`s.
///
/// Every call to `featureValue(for:)` wraps the stored array in a newly created
/// `MLFeatureValue` instead of handing out a cached wrapper. The original author
/// notes this is performance-critical because it keeps CoreML from caching stale
/// array data (NOTE(review): CoreML-side caching behavior is not visible here —
/// confirm before changing).
class DecodeInputProvider: MLFeatureProvider {
    /// Names advertised to CoreML. This set is fixed and does not depend on
    /// which keys are actually present in `arrays`.
    let featureNames: Set<String> = [
        "input_ids",
        "causal_mask",
        "cos",
        "sin",
        "update_mask",
        "speaker_embedding",
        "is_speaker_step"
    ]

    /// Backing storage, keyed by feature name.
    let arrays: [String: MLMultiArray]

    /// Creates a provider serving the given arrays.
    /// - Parameter arrays: Multi-arrays keyed by feature name.
    init(_ arrays: [String: MLMultiArray]) {
        self.arrays = arrays
    }

    /// Looks up `name` and wraps the stored array in a new `MLFeatureValue`.
    /// - Parameter name: Feature name requested by CoreML.
    /// - Returns: A new feature value, or `nil` when no array is stored under `name`.
    func featureValue(for name: String) -> MLFeatureValue? {
        arrays[name].map { MLFeatureValue(multiArray: $0) }
    }
}

// MARK: - Sampling

/// Samples a token id from `logits` using top-K filtering and temperature scaling.
///
/// Operates directly on the `Float16` buffer (no per-element boxing) and keeps
/// only the K best candidates in small scratch arrays.
///
/// - Parameters:
///   - logits: Raw scores indexed by token id. At most the first
///     `PlapreConfig.vocabSize` entries are examined; shorter arrays are
///     scanned up to their own length.
///   - temperature: Divisor applied to the logits before softmax. Values `<= 0`
///     select the highest-scoring token deterministically (greedy decoding).
///   - topK: Number of highest-scoring tokens to sample among. Values below 1
///     are treated as 1; values larger than the number of examined logits are
///     clamped to that number.
/// - Returns: The chosen token id. Returns 0 when `logits` is empty or when no
///   examined logit exceeds the most negative finite `Float16` value.
func sampleFromLogits(_ logits: [Float16], temperature: Float = PlapreConfig.defaultTemperature, topK: Int = PlapreConfig.defaultTopK) -> Int32 {
    return logits.withUnsafeBufferPointer { ptr -> Int32 in
        // Never index past the end of the buffer: an unsafe buffer pointer is
        // not bounds-checked in optimized builds.
        let candidateCount = min(ptr.count, PlapreConfig.vocabSize)
        guard candidateCount > 0 else { return 0 }

        // Clamp K so the `1..<k` ranges below are always well-formed and no
        // unfilled "phantom" slot (index 0, sentinel score) can be sampled.
        let k = min(max(topK, 1), candidateCount)

        var topIndices = [Int](repeating: 0, count: k)
        var topValues = [Float16](repeating: Float16(-65504.0), count: k)
        // Slot currently holding the smallest kept value (the next to evict).
        var minIdx = 0

        for i in 0..<candidateCount {
            let value = ptr[i]
            // NaN compares false here and is therefore never selected.
            guard value > topValues[minIdx] else { continue }
            topValues[minIdx] = value
            topIndices[minIdx] = i
            // Re-locate the weakest kept candidate after the replacement.
            minIdx = 0
            for j in 1..<k where topValues[j] < topValues[minIdx] { minIdx = j }
        }

        // The best kept candidate is both the greedy answer and the softmax
        // stabilizer, so find it once.
        var bestIdx = 0
        for j in 1..<k where topValues[j] > topValues[bestIdx] { bestIdx = j }

        if temperature <= 0 {
            return Int32(topIndices[bestIdx])
        }

        // Subtract the maximum before dividing by the temperature: this is
        // mathematically the same softmax, but cannot overflow to
        // `inf - inf = NaN` when the temperature is very small.
        let maxValue = Float(topValues[bestIdx])
        let weights = topValues.map { exp((Float($0) - maxValue) / temperature) }
        let total = weights.reduce(0, +)

        // Inverse-CDF sampling over the normalized top-K distribution.
        let r = Float.random(in: 0..<1)
        var cumulative: Float = 0
        for j in 0..<k {
            cumulative += weights[j] / total
            if cumulative >= r { return Int32(topIndices[j]) }
        }
        // Rounding can leave the cumulative sum marginally below `r`.
        return Int32(topIndices[k - 1])
    }
}

// MARK: - Timing

/// Renders a duration in seconds as a short human-readable string.
///
/// - Parameter seconds: Duration in seconds.
/// - Returns: Microseconds with two decimals when `seconds` is below 0.001,
///   milliseconds with one decimal when below 1.0, otherwise seconds with two
///   decimals.
func formatTime(_ seconds: Double) -> String {
    switch seconds {
    case ..<0.001:
        return String(format: "%.2fµs", seconds * 1_000_000)
    case ..<1.0:
        return String(format: "%.1fms", seconds * 1000)
    default:
        return String(format: "%.2fs", seconds)
    }
}

/// Runs `block`, prints how long it took, and returns its result.
///
/// Elapsed time is read from `ProcessInfo.systemUptime` rather than the wall
/// clock, so the measurement is not distorted if the system date is adjusted
/// while `block` runs. Nothing is printed when `block` throws; the error
/// propagates to the caller.
///
/// - Parameters:
///   - label: Text printed in front of the formatted duration.
///   - block: The work to time.
/// - Returns: Whatever `block` returns.
/// - Throws: Rethrows any error thrown by `block`.
func measure<T>(_ label: String, _ block: () throws -> T) rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print("  ⏱ \(label): \(formatTime(elapsed))")
    return result
}

/// Awaits `block`, prints how long it took, and returns its result.
///
/// Elapsed time is read from `ProcessInfo.systemUptime` rather than the wall
/// clock, so the measurement is not distorted if the system date is adjusted
/// while `block` runs. The reported duration includes any time the task spent
/// suspended. Nothing is printed when `block` throws; the error propagates to
/// the caller.
///
/// - Parameters:
///   - label: Text printed in front of the formatted duration.
///   - block: The asynchronous work to time.
/// - Returns: Whatever `block` returns.
/// - Throws: Rethrows any error thrown by `block`.
func measureAsync<T>(_ label: String, _ block: () async throws -> T) async rethrows -> T {
    let start = ProcessInfo.processInfo.systemUptime
    let result = try await block()
    let elapsed = ProcessInfo.processInfo.systemUptime - start
    print("  ⏱ \(label): \(formatTime(elapsed))")
    return result
}