krmanik's picture
Upload folder using huggingface_hub
357ae2c verified
import CoreML
import Foundation
/// Thin Core ML wrapper around a SaT token-classification model.
///
/// The exported models have a **fixed** sequence length (256 by default), INT32
/// inputs (`input_ids`, `attention_mask`) and a single boundary logit per token
/// (`logits`, shape `[1, seqLen, 1]`, Float16). This type loads/compiles the
/// package, discovers the sequence length from the model, and runs one window.
final class SaTModel {
    /// The loaded Core ML model used for every prediction.
    let model: MLModel
    /// Window length, taken from the last dimension of the `input_ids` shape constraint.
    let seqLen: Int
    /// Optional id remapping applied to each input id inside `predict`.
    private let remap: VocabRemap?

    /// Loads the model at `modelURL` and validates the parts of its interface this type uses.
    ///
    /// - Parameters:
    ///   - modelURL: A compiled model (`.mlmodelc`) or anything else, which is
    ///     compiled first via `MLModel.compileModel(at:)`.
    ///   - remap: Optional id remapping; `nil` passes ids through unchanged.
    ///   - computeUnits: Compute units set on the model configuration.
    /// - Throws: Any error from compiling or loading the model, or
    ///   `KitError.modelIncompatible` when `input_ids` is not a multi-array
    ///   input, the sequence length is 2 or less, or there is no `logits` output.
    init(modelURL: URL, remap: VocabRemap?, computeUnits: MLComputeUnits) throws {
        let compiledURL: URL
        if modelURL.pathExtension == "mlmodelc" {
            compiledURL = modelURL
        } else {
            compiledURL = try MLModel.compileModel(at: modelURL)
        }

        let config = MLModelConfiguration()
        config.computeUnits = computeUnits
        self.model = try MLModel(contentsOf: compiledURL, configuration: config)
        self.remap = remap

        guard let desc = model.modelDescription.inputDescriptionsByName["input_ids"],
              let constraint = desc.multiArrayConstraint,
              constraint.shape.count >= 1 else {
            throw KitError.modelIncompatible("model has no 'input_ids' multi-array input")
        }
        // The sequence axis is the trailing dimension of the input shape.
        self.seqLen = constraint.shape[constraint.shape.count - 1].intValue
        guard seqLen > 2 else { throw KitError.modelIncompatible("sequence length too small: \(seqLen)") }

        guard model.modelDescription.outputDescriptionsByName["logits"] != nil else {
            throw KitError.modelIncompatible("model has no 'logits' output")
        }
    }

    /// Run one window. `ids` and `mask` must each have length `seqLen`
    /// (violating this traps via `precondition`).
    ///
    /// Id remapping, if configured, is applied here so callers always pass
    /// full-vocab ids.
    ///
    /// - Returns: `seqLen` boundary logits (one per position).
    /// - Throws: Errors from allocating the input arrays or running the
    ///   prediction, or `KitError.modelIncompatible` when the prediction has
    ///   no `logits` multi-array.
    func predict(ids: [Int32], mask: [Int32]) throws -> [Float] {
        precondition(ids.count == seqLen && mask.count == seqLen, "window must equal seqLen")

        let idsArray = try MLMultiArray(shape: [1, NSNumber(value: seqLen)], dataType: .int32)
        let maskArray = try MLMultiArray(shape: [1, NSNumber(value: seqLen)], dataType: .int32)
        let idsPtr = idsArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)
        let maskPtr = maskArray.dataPointer.bindMemory(to: Int32.self, capacity: seqLen)

        if let remap {
            for k in 0..<seqLen { idsPtr[k] = remap.map(ids[k]) }
        } else {
            for k in 0..<seqLen { idsPtr[k] = ids[k] }
        }
        for k in 0..<seqLen { maskPtr[k] = mask[k] }

        let input = try MLDictionaryFeatureProvider(dictionary: [
            "input_ids": idsArray,
            "attention_mask": maskArray,
        ])
        let output = try model.prediction(from: input)
        guard let logits = output.featureValue(for: "logits")?.multiArrayValue else {
            throw KitError.modelIncompatible("prediction returned no 'logits' array")
        }
        return Self.readLogits(logits, count: seqLen)
    }

    /// Read the leading `count` logical (row-major) elements of `array` as
    /// `Float`, regardless of the backing dtype (Float16 / Float32 / Double).
    ///
    /// Model outputs are not guaranteed to be densely packed: `array.strides`
    /// may include padding between elements. Each logical index is therefore
    /// translated to a physical offset through the strides instead of being
    /// used as a raw pointer offset. For a densely packed array the two agree,
    /// so logit `k` of a `[1, seqLen, 1]` output is still read from offset `k`.
    ///
    /// If `array` holds fewer than `count` elements, the remaining entries of
    /// the result stay `0`.
    private static func readLogits(_ array: MLMultiArray, count: Int) -> [Float] {
        var out = [Float](repeating: 0, count: count)
        let n = min(count, array.count)
        guard n > 0 else { return out }

        let shape = array.shape.map(\.intValue)
        let strides = array.strides.map(\.intValue)

        // Physical element offset of the `linear`-th element in row-major order.
        func offset(of linear: Int) -> Int {
            var remainder = linear
            var result = 0
            for axis in shape.indices.reversed() {
                result += (remainder % shape[axis]) * strides[axis]
                remainder /= shape[axis]
            }
            return result
        }

        // Number of elements spanned by the backing buffer, padding included.
        // `n > 0` guarantees every dimension is at least 1 here.
        let extent = zip(shape, strides).reduce(1) { $0 + ($1.0 - 1) * $1.1 }

        switch array.dataType {
        case .float16:
            let p = array.dataPointer.bindMemory(to: Float16.self, capacity: extent)
            for k in 0..<n { out[k] = Float(p[offset(of: k)]) }
        case .float32:
            let p = array.dataPointer.bindMemory(to: Float.self, capacity: extent)
            for k in 0..<n { out[k] = p[offset(of: k)] }
        case .double:
            let p = array.dataPointer.bindMemory(to: Double.self, capacity: extent)
            for k in 0..<n { out[k] = Float(p[offset(of: k)]) }
        default:
            // Integer or future dtypes: read via NSNumber bridging, using the
            // array's own linear-index subscript.
            for k in 0..<n { out[k] = array[k].floatValue }
        }
        return out
    }
}