File size: 4,211 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import CoreML
import Foundation

/// Thin Core ML wrapper around a SaT token-classification model.
///
/// The exported models have a **fixed** sequence length (256 by default), INT32
/// inputs (`input_ids`, `attention_mask`) and a single boundary logit per token
/// (`logits`, shape `[1, seqLen, 1]`, Float16). This type loads/compiles the
/// package, discovers the sequence length from the model, and runs one window.
final class SaTModel {
    /// The loaded Core ML model that every call to `predict` runs.
    let model: MLModel
    /// Window length, read from the last dimension of the model's `input_ids` input.
    let seqLen: Int
    /// Optional id remapping applied to `ids` before they are handed to the model.
    private let remap: VocabRemap?

    /// Loads the model at `modelURL` and validates the inputs/outputs this type relies on.
    ///
    /// - Parameters:
    ///   - modelURL: A compiled model (`.mlmodelc`), which is loaded directly, or any
    ///     other path, which is compiled with `MLModel.compileModel(at:)` first.
    ///   - remap: Optional id remapping applied in `predict`; `nil` passes ids through.
    ///   - computeUnits: Compute units set on the model configuration.
    /// - Throws: `KitError.modelIncompatible` when the model has no `input_ids`
    ///   multi-array input, when its sequence length is not greater than 2, or when
    ///   it has no `logits` output; also rethrows Core ML compile/load errors.
    init(modelURL: URL, remap: VocabRemap?, computeUnits: MLComputeUnits) throws {
        let compiledURL: URL
        if modelURL.pathExtension == "mlmodelc" {
            compiledURL = modelURL
        } else {
            // NOTE(review): `compileModel(at:)` writes its result to a temporary
            // location that this initializer never removes or caches — confirm
            // whether the caller is expected to handle that.
            compiledURL = try MLModel.compileModel(at: modelURL)
        }
        let config = MLModelConfiguration()
        config.computeUnits = computeUnits
        let loaded = try MLModel(contentsOf: compiledURL, configuration: config)
        let description = loaded.modelDescription

        // The sequence length is the trailing dimension of `input_ids`.
        guard let inputDescription = description.inputDescriptionsByName["input_ids"],
              let constraint = inputDescription.multiArrayConstraint,
              let lastDimension = constraint.shape.last else {
            throw KitError.modelIncompatible("model has no 'input_ids' multi-array input")
        }
        let length = lastDimension.intValue
        guard length > 2 else { throw KitError.modelIncompatible("sequence length too small: \(length)") }
        guard description.outputDescriptionsByName["logits"] != nil else {
            throw KitError.modelIncompatible("model has no 'logits' output")
        }

        self.model = loaded
        self.remap = remap
        self.seqLen = length
    }

    /// Runs one window through the model.
    ///
    /// - Parameters:
    ///   - ids: Token ids; must contain exactly `seqLen` elements. When a remap is
    ///     configured each id is passed through `remap.map` before inference.
    ///   - mask: Attention mask; must contain exactly `seqLen` elements.
    /// - Returns: `seqLen` values read from the model's `logits` output.
    /// - Throws: `KitError.modelIncompatible` when the prediction carries no `logits`
    ///   multi-array; also rethrows Core ML allocation/prediction errors.
    /// - Precondition: `ids.count == seqLen && mask.count == seqLen`.
    func predict(ids: [Int32], mask: [Int32]) throws -> [Float] {
        precondition(ids.count == seqLen && mask.count == seqLen, "window must equal seqLen")

        let shape: [NSNumber] = [1, NSNumber(value: seqLen)]
        let idsArray = try MLMultiArray(shape: shape, dataType: .int32)
        let maskArray = try MLMultiArray(shape: shape, dataType: .int32)

        // Both arrays were just allocated as `[1, seqLen]` Int32, so they are
        // filled through flat offsets 0..<seqLen.
        let idsPtr = idsArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)
        let maskPtr = maskArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)
        if let remap {
            for k in 0..<seqLen { idsPtr[k] = remap.map(ids[k]) }
        } else {
            for k in 0..<seqLen { idsPtr[k] = ids[k] }
        }
        for k in 0..<seqLen { maskPtr[k] = mask[k] }

        let input = try MLDictionaryFeatureProvider(dictionary: [
            "input_ids": idsArray,
            "attention_mask": maskArray,
        ])
        let output = try model.prediction(from: input)
        guard let logits = output.featureValue(for: "logits")?.multiArrayValue else {
            throw KitError.modelIncompatible("prediction returned no 'logits' array")
        }
        return Self.readLogits(logits, count: seqLen)
    }

    /// Reads the leading `count` elements of `array` (in logical row-major order)
    /// as `Float`, regardless of the backing dtype (Float16 / Float32 / Double /
    /// anything else via `NSNumber`).
    ///
    /// Elements are located through `array.strides` rather than by flat offset, so
    /// the result is correct even when the buffer is not densely packed. For a
    /// densely packed array this yields exactly the flat offsets `0..<count`.
    ///
    /// - Parameters:
    ///   - array: The multi-array to read from.
    ///   - count: Number of values to return.
    /// - Returns: An array of `count` floats. If `array` holds fewer than `count`
    ///   elements, the remaining entries are left at `0`.
    private static func readLogits(_ array: MLMultiArray, count: Int) -> [Float] {
        var out = [Float](repeating: 0, count: count)
        let n = min(count, array.count)
        guard n > 0 else { return out }

        let shape = array.shape.map(\.intValue)
        let strides = array.strides.map(\.intValue)

        // Logical multi-index of each element we need, and the matching physical
        // element offset obtained by weighting each coordinate with its stride.
        let indices = (0..<n).map { multiIndex(forLogical: $0, shape: shape) }
        let offsets = indices.map { index -> Int in
            var offset = 0
            for (coordinate, stride) in zip(index, strides) {
                offset += coordinate * stride
            }
            return offset
        }
        // Number of elements the typed pointer must span to reach the furthest offset.
        let extent = (offsets.max() ?? 0) + 1

        switch array.dataType {
        case .float16:
            let p = array.dataPointer.bindMemory(to: Float16.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = Float(p[offset]) }
        case .float32:
            let p = array.dataPointer.bindMemory(to: Float.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = p[offset] }
        case .double:
            let p = array.dataPointer.bindMemory(to: Double.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = Float(p[offset]) }
        default:
            // Integer or future dtypes: read via NSNumber bridging, addressing each
            // element by its multi-dimensional index so layout does not matter.
            for (k, index) in indices.enumerated() {
                out[k] = array[index.map { NSNumber(value: $0) }].floatValue
            }
        }
        return out
    }

    /// Converts a logical row-major element number into per-axis coordinates for
    /// an array of the given `shape` (every dimension must be non-zero).
    private static func multiIndex(forLogical logical: Int, shape: [Int]) -> [Int] {
        var index = [Int](repeating: 0, count: shape.count)
        var remainder = logical
        // Peel coordinates off from the fastest-varying (last) axis outward.
        for axis in shape.indices.reversed() {
            index[axis] = remainder % shape[axis]
            remainder /= shape[axis]
        }
        return index
    }
}