plapre-pico-coreml / swift-cli /Sources /CoreMLUtils.swift
Daniel Rothmann
Tidy up example CLI
1dfb01c
import Foundation
import CoreML
/// Compiles a CoreML model and loads it with the requested compute units.
///
/// The URL returned by `MLModel.compileModel(at:)` is loaded directly; it is
/// not moved or cleaned up here.
///
/// - Parameters:
///   - url: Location of the uncompiled model on disk.
///   - computeUnits: Hardware CoreML may use when running the model.
///     Defaults to `.cpuOnly`.
/// - Returns: The compiled and loaded model.
/// - Throws: `CoreMLError.modelNotFound` when nothing exists at `url`;
///   otherwise any error from `MLModel.compileModel(at:)` or
///   `MLModel(contentsOf:configuration:)`.
func compileModel(at url: URL, computeUnits: MLComputeUnits = .cpuOnly) throws -> MLModel {
    print(" Compiling \(url.lastPathComponent)...")
    // Fail with a clear, path-bearing error before CoreML reports something vaguer.
    guard FileManager.default.fileExists(atPath: url.path) else {
        throw CoreMLError.modelNotFound(url.path)
    }
    let compiledURL = try MLModel.compileModel(at: url)
    let config = MLModelConfiguration()
    config.computeUnits = computeUnits
    return try MLModel(contentsOf: compiledURL, configuration: config)
}
/// Creates a Float16 `MLMultiArray` of the given shape from `Float` values.
///
/// Each value is narrowed with `Float16(_:)`, so precision may be lost.
///
/// - Parameters:
///   - values: Elements in linear order. Must not hold more elements than
///     `shape` describes.
///   - shape: Dimensions of the array to create.
/// - Returns: A new array whose first `values.count` elements come from
///   `values` and whose remaining elements (if any) are zero.
/// - Precondition: `values.count` does not exceed the element count of `shape`.
func makeFloat16Array(_ values: [Float], shape: [Int]) -> MLMultiArray {
    // `try!`: a failure to allocate the array is treated as a programming error.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .float16)
    arr.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
        // Unsafe buffer subscripts are not bounds-checked in optimized builds,
        // so an oversized `values` would silently corrupt memory; trap instead.
        precondition(values.count <= ptr.count,
                     "makeFloat16Array: \(values.count) values exceed the \(ptr.count) elements of shape \(shape)")
        for (i, value) in values.enumerated() {
            ptr[i] = Float16(value)
        }
        // Zero any elements not covered by `values` instead of relying on the
        // initial contents of the freshly allocated storage.
        for i in values.count..<ptr.count {
            ptr[i] = 0
        }
    }
    return arr
}
/// Creates a Float32 `MLMultiArray` of the given shape from `Float` values.
///
/// - Parameters:
///   - values: Elements in linear order. Must not hold more elements than
///     `shape` describes.
///   - shape: Dimensions of the array to create.
/// - Returns: A new array whose first `values.count` elements come from
///   `values` and whose remaining elements (if any) are zero.
/// - Precondition: `values.count` does not exceed the element count of `shape`.
func makeFloat32Array(_ values: [Float], shape: [Int]) -> MLMultiArray {
    // `try!`: a failure to allocate the array is treated as a programming error.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .float32)
    arr.withUnsafeMutableBufferPointer(ofType: Float.self) { dst, _ in
        // Unsafe buffer subscripts are not bounds-checked in optimized builds,
        // so an oversized `values` would silently corrupt memory; trap instead.
        precondition(values.count <= dst.count,
                     "makeFloat32Array: \(values.count) values exceed the \(dst.count) elements of shape \(shape)")
        for (i, value) in values.enumerated() {
            dst[i] = value
        }
        // Zero any elements not covered by `values` instead of relying on the
        // initial contents of the freshly allocated storage.
        for i in values.count..<dst.count {
            dst[i] = 0
        }
    }
    return arr
}
/// Creates an Int32 `MLMultiArray` of the given shape from `Int32` values.
///
/// - Parameters:
///   - values: Elements in linear order. Must not hold more elements than
///     `shape` describes.
///   - shape: Dimensions of the array to create.
/// - Returns: A new array whose first `values.count` elements come from
///   `values` and whose remaining elements (if any) are zero.
/// - Precondition: `values.count` does not exceed the element count of `shape`.
func makeInt32Array(_ values: [Int32], shape: [Int]) -> MLMultiArray {
    // `try!`: a failure to allocate the array is treated as a programming error.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .int32)
    arr.withUnsafeMutableBufferPointer(ofType: Int32.self) { dst, _ in
        // Unsafe buffer subscripts are not bounds-checked in optimized builds,
        // so an oversized `values` would silently corrupt memory; trap instead.
        precondition(values.count <= dst.count,
                     "makeInt32Array: \(values.count) values exceed the \(dst.count) elements of shape \(shape)")
        for (i, value) in values.enumerated() {
            dst[i] = value
        }
        // Zero any elements not covered by `values` instead of relying on the
        // initial contents of the freshly allocated storage.
        for i in values.count..<dst.count {
            dst[i] = 0
        }
    }
    return arr
}
/// Reads every element of an `MLMultiArray` as `Float`.
///
/// Float32 arrays are copied straight out of the backing buffer. Arrays of any
/// other data type (for example Float16 tensors) are converted element by
/// element through `NSNumber`. Their raw storage is not read as Float32.
///
/// - Parameter array: The multi-array to read.
/// - Returns: `array.count` values in linear-index order.
func readFloat32Array(_ array: MLMultiArray) -> [Float] {
    let count = array.count
    guard array.dataType == .float32 else {
        // Slower, but converts each scalar according to the array's real type.
        return (0..<count).map { array[$0].floatValue }
    }
    return array.withUnsafeBufferPointer(ofType: Float.self) { ptr in
        Array(ptr.prefix(count))
    }
}
/// Errors thrown by the CoreML helpers when a model cannot be used.
enum CoreMLError: LocalizedError {
    /// A model was expected at the associated path, but nothing exists there.
    case modelNotFound(String)

    /// Human-readable message, surfaced through `localizedDescription`.
    var errorDescription: String? {
        switch self {
        case let .modelNotFound(missingPath):
            return "Model not found: " + missingPath
        }
    }
}