| import CoreML |
| import Foundation |
|
|
| |
| |
| |
| |
| |
| |
/// Wraps a Core ML model that consumes fixed-size `input_ids` / `attention_mask`
/// windows and produces a `logits` output.
///
/// The window length (`seqLen`) is taken from the last dimension of the model's
/// `input_ids` input description, so callers must always pass windows of exactly
/// that size to `predict(ids:mask:)`.
final class SaTModel {
    /// The loaded Core ML model.
    let model: MLModel
    /// Number of token positions per window (last dimension of `input_ids`).
    let seqLen: Int
    /// Optional id remapping applied to every input id before inference.
    private let remap: VocabRemap?

    /// Loads the model at `modelURL` and validates its interface.
    ///
    /// - Parameters:
    ///   - modelURL: A compiled model (`.mlmodelc`) or any other model file, which
    ///     is compiled first via `MLModel.compileModel(at:)`.
    ///   - remap: Optional id remapping applied to input ids in `predict`.
    ///   - computeUnits: Compute units passed to the model configuration.
    /// - Throws: `KitError.modelIncompatible` when the model has no `input_ids`
    ///   multi-array input, its window length is not greater than 2, or it has no
    ///   `logits` output; rethrows any Core ML compile/load error.
    init(modelURL: URL, remap: VocabRemap?, computeUnits: MLComputeUnits) throws {
        let compiledURL: URL
        if modelURL.pathExtension == "mlmodelc" {
            compiledURL = modelURL
        } else {
            compiledURL = try MLModel.compileModel(at: modelURL)
        }
        let config = MLModelConfiguration()
        config.computeUnits = computeUnits
        self.model = try MLModel(contentsOf: compiledURL, configuration: config)
        self.remap = remap

        guard let desc = model.modelDescription.inputDescriptionsByName["input_ids"],
              let constraint = desc.multiArrayConstraint,
              constraint.shape.count >= 1 else {
            throw KitError.modelIncompatible("model has no 'input_ids' multi-array input")
        }
        // The window length is the innermost dimension of the input shape.
        self.seqLen = constraint.shape[constraint.shape.count - 1].intValue
        guard seqLen > 2 else { throw KitError.modelIncompatible("sequence length too small: \(seqLen)") }
        guard model.modelDescription.outputDescriptionsByName["logits"] != nil else {
            throw KitError.modelIncompatible("model has no 'logits' output")
        }
    }

    /// Runs the model on a single window and returns `seqLen` logit values.
    ///
    /// - Parameters:
    ///   - ids: Input ids; must contain exactly `seqLen` elements. Each id is passed
    ///     through `remap` first when a remap was supplied.
    ///   - mask: Attention mask; must contain exactly `seqLen` elements.
    /// - Returns: `seqLen` floats read from the `logits` output (zero-filled past the
    ///   end if the output holds fewer elements).
    /// - Throws: `KitError.modelIncompatible` when the prediction has no `logits`
    ///   multi-array; rethrows Core ML allocation and prediction errors.
    /// - Precondition: `ids.count == seqLen && mask.count == seqLen`.
    func predict(ids: [Int32], mask: [Int32]) throws -> [Float] {
        precondition(ids.count == seqLen && mask.count == seqLen, "window must equal seqLen")

        // Both inputs share the same [1, seqLen] int32 layout.
        let inputShape: [NSNumber] = [1, NSNumber(value: seqLen)]
        let idsArray = try MLMultiArray(shape: inputShape, dataType: .int32)
        let maskArray = try MLMultiArray(shape: inputShape, dataType: .int32)
        let idsPtr = idsArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)
        let maskPtr = maskArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)
        // Decide once whether remapping applies instead of testing per element.
        if let remap {
            for k in 0..<seqLen { idsPtr[k] = remap.map(ids[k]) }
        } else {
            for k in 0..<seqLen { idsPtr[k] = ids[k] }
        }
        for k in 0..<seqLen { maskPtr[k] = mask[k] }

        let input = try MLDictionaryFeatureProvider(dictionary: [
            "input_ids": idsArray,
            "attention_mask": maskArray,
        ])
        let output = try model.prediction(from: input)
        guard let logits = output.featureValue(for: "logits")?.multiArrayValue else {
            throw KitError.modelIncompatible("prediction returned no 'logits' array")
        }
        return Self.readLogits(logits, count: seqLen)
    }

    /// Copies the first `count` logical (row-major) elements of `array` into a
    /// `[Float]`, converting from the array's element type.
    ///
    /// Elements are located through `array.strides` rather than by assuming the
    /// backing buffer is densely packed, so outputs with padded strides are read
    /// correctly. If `array` holds fewer than `count` elements, the remainder of the
    /// result stays zero.
    ///
    /// NOTE(review): this takes the first `count` elements in row-major order, which
    /// presumes one logit per window position (e.g. shape `[1, seqLen]` or
    /// `[1, seqLen, 1]`) — confirm against the exported model.
    private static func readLogits(_ array: MLMultiArray, count: Int) -> [Float] {
        var out = [Float](repeating: 0, count: count)
        let n = min(count, array.count)
        guard n > 0 else { return out }

        let shape = array.shape.map(\.intValue)
        let strides = array.strides.map(\.intValue)

        // Per-axis coordinates of each logical element we are going to read.
        let coordinates: [[Int]] = (0..<n).map { linear in
            var coords = [Int](repeating: 0, count: shape.count)
            var remainder = linear
            for axis in shape.indices.reversed() {
                coords[axis] = remainder % shape[axis]
                remainder /= shape[axis]
            }
            return coords
        }
        // Physical element offset of each coordinate inside the backing buffer.
        let offsets: [Int] = coordinates.map { coords in
            zip(coords, strides).reduce(0) { $0 + $1.0 * $1.1 }
        }
        // Number of elements spanned by the buffer, including any stride padding.
        let extent = zip(shape, strides).reduce(1) { $0 + ($1.0 - 1) * $1.1 }

        switch array.dataType {
        case .float16:
            let p = array.dataPointer.bindMemory(to: Float16.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = Float(p[offset]) }
        case .float32:
            let p = array.dataPointer.bindMemory(to: Float.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = p[offset] }
        case .double:
            let p = array.dataPointer.bindMemory(to: Double.self, capacity: extent)
            for (k, offset) in offsets.enumerated() { out[k] = Float(p[offset]) }
        default:
            // Other element types: go through the per-axis subscript and let
            // NSNumber perform the conversion.
            for (k, coords) in coordinates.enumerated() {
                out[k] = array[coords.map { NSNumber(value: $0) }].floatValue
            }
        }
        return out
    }
}
|
|