import CoreML
import Foundation

/// Core ML front end for a SaT token-classification model.
///
/// Exported models use a **fixed** sequence length (256 by default), take INT32
/// `input_ids` and `attention_mask` inputs, and emit one boundary logit per
/// token (`logits`, shape `[1, seqLen, 1]`, Float16). This type compiles the
/// package when needed, loads it, reads the sequence length off the model
/// description, and runs one window at a time.
final class SaTModel {
    let model: MLModel
    let seqLen: Int
    private let remap: VocabRemap?

    /// Loads the model at `modelURL` and validates the interface this wrapper relies on.
    ///
    /// - Parameters:
    ///   - modelURL: Location of the model. A URL whose path extension is
    ///     `mlmodelc` is loaded as-is; anything else is compiled first.
    ///   - remap: Optional vocabulary remap, stored for later use by this type.
    ///   - computeUnits: Compute units written into the model configuration.
    /// - Throws: Whatever Core ML throws while compiling or loading, or
    ///   `KitError.modelIncompatible` when the model has no `input_ids`
    ///   multi-array input, its last input dimension is 2 or less, or it has
    ///   no `logits` output.
    init(modelURL: URL, remap: VocabRemap?, computeUnits: MLComputeUnits) throws {
        // NOTE(review): `compileModel(at:)` writes its output to a temporary
        // location and nothing here removes it — confirm whether callers are
        // expected to pass a pre-compiled `.mlmodelc` to avoid repeated compiles.
        let isPrecompiled = modelURL.pathExtension == "mlmodelc"
        let compiledURL = try isPrecompiled ? modelURL : MLModel.compileModel(at: modelURL)

        let configuration = MLModelConfiguration()
        configuration.computeUnits = computeUnits

        let loadedModel = try MLModel(contentsOf: compiledURL, configuration: configuration)
        self.model = loadedModel
        self.remap = remap

        let description = loadedModel.modelDescription

        // The sequence length is taken from the last dimension of `input_ids`.
        guard
            let shapeConstraint = description.inputDescriptionsByName["input_ids"]?.multiArrayConstraint,
            let lastDimension = shapeConstraint.shape.last
        else {
            throw KitError.modelIncompatible("model has no 'input_ids' multi-array input")
        }

        let length = lastDimension.intValue
        self.seqLen = length
        if length <= 2 {
            throw KitError.modelIncompatible("sequence length too small: \(length)")
        }

        if description.outputDescriptionsByName["logits"] == nil {
            throw KitError.modelIncompatible("model has no 'logits' output")
        }
    }

    /// Run one window. `ids` and `mask` must each have length `seqLen`.
    /// Returns `seqLen` boundary logits (one per position). EN+ZH id remapping,
    /// if configured, is applied here so callers always pass full-vocab ids.
func predict(ids: [Int32], mask: [Int32]) throws -> [Float] { precondition(ids.count == seqLen && mask.count == seqLen, "window must equal seqLen") let idsArray = try MLMultiArray(shape: [1, NSNumber(value: seqLen)], dataType: .int32) let maskArray = try MLMultiArray(shape: [1, NSNumber(value: seqLen)], dataType: .int32) let idsPtr = idsArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen) let maskPtr = maskArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen) if let remap { for k in 0.. [Float] { var out = [Float](repeating: 0, count: count) let n = min(count, array.count) switch array.dataType { case .float16: let p = array.dataPointer.bindMemory(to: Float16.self, capacity: array.count) for k in 0..