| import Foundation |
| import CoreML |
|
|
| |
/// Compiles the Core ML model file at `url` and loads the result for CPU-only execution.
///
/// The file name is printed to standard output before any other work is done.
///
/// - Parameter url: Location of the model file to compile.
/// - Returns: The loaded model, created with `computeUnits` set to `.cpuOnly`.
/// - Throws: `CoreMLError.modelNotFound` when no file exists at `url.path`;
///   otherwise whatever `MLModel.compileModel(at:)` or
///   `MLModel(contentsOf:configuration:)` throws.
func compileModel(at url: URL) throws -> MLModel {
    print(" Compiling \(url.lastPathComponent)...")

    // Fail early with a descriptive error instead of letting the compiler
    // report a missing input.
    let sourcePath = url.path
    if !FileManager.default.fileExists(atPath: sourcePath) {
        throw CoreMLError.modelNotFound(sourcePath)
    }

    let cpuOnlyConfiguration = MLModelConfiguration()
    cpuOnlyConfiguration.computeUnits = .cpuOnly

    let compiledModelURL = try MLModel.compileModel(at: url)
    return try MLModel(contentsOf: compiledModelURL, configuration: cpuOnlyConfiguration)
}
|
|
| |
/// Builds a Float16 `MLMultiArray` of the given shape from `Float` values.
///
/// Each value is converted with `Float16(_:)` and stored at the same linear
/// offset in the array's backing buffer. The strides reported by CoreML are
/// ignored — NOTE(review): this assumes a freshly allocated array is laid out
/// contiguously; confirm.
///
/// - Parameters:
///   - values: Source values, written to buffer offsets `0..<values.count`.
///   - shape: Dimensions of the array to allocate.
/// - Returns: The newly allocated array. Elements at offsets `values.count`
///   and beyond are not written by this function.
/// - Precondition: `values.count` must not exceed the number of elements in
///   the allocated buffer.
func makeFloat16Array(_ values: [Float], shape: [Int]) -> MLMultiArray {
    // This function is non-throwing, so an allocation failure still traps here.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .float16)
    arr.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
        // The buffer subscript is only bounds-checked in debug builds, so an
        // oversized `values` would silently write past the allocation in release.
        precondition(
            values.count <= ptr.count,
            "makeFloat16Array: \(values.count) values do not fit in shape \(shape)"
        )
        for (offset, value) in values.enumerated() {
            ptr[offset] = Float16(value)
        }
    }
    return arr
}
|
|
| |
/// Builds a Float32 `MLMultiArray` of the given shape from `Float` values.
///
/// Values are stored unchanged at the same linear offset in the array's
/// backing buffer. The strides reported by CoreML are ignored —
/// NOTE(review): this assumes a freshly allocated array is laid out
/// contiguously; confirm.
///
/// - Parameters:
///   - values: Source values, written to buffer offsets `0..<values.count`.
///   - shape: Dimensions of the array to allocate.
/// - Returns: The newly allocated array. Elements at offsets `values.count`
///   and beyond are not written by this function.
/// - Precondition: `values.count` must not exceed the number of elements in
///   the allocated buffer.
func makeFloat32Array(_ values: [Float], shape: [Int]) -> MLMultiArray {
    // This function is non-throwing, so an allocation failure still traps here.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .float32)
    arr.withUnsafeMutableBufferPointer(ofType: Float.self) { dst, _ in
        // The buffer subscript is only bounds-checked in debug builds, so an
        // oversized `values` would silently write past the allocation in release.
        precondition(
            values.count <= dst.count,
            "makeFloat32Array: \(values.count) values do not fit in shape \(shape)"
        )
        for (offset, value) in values.enumerated() {
            dst[offset] = value
        }
    }
    return arr
}
|
|
| |
/// Builds an Int32 `MLMultiArray` of the given shape from `Int32` values.
///
/// Values are stored unchanged at the same linear offset in the array's
/// backing buffer. The strides reported by CoreML are ignored —
/// NOTE(review): this assumes a freshly allocated array is laid out
/// contiguously; confirm.
///
/// - Parameters:
///   - values: Source values, written to buffer offsets `0..<values.count`.
///   - shape: Dimensions of the array to allocate.
/// - Returns: The newly allocated array. Elements at offsets `values.count`
///   and beyond are not written by this function.
/// - Precondition: `values.count` must not exceed the number of elements in
///   the allocated buffer.
func makeInt32Array(_ values: [Int32], shape: [Int]) -> MLMultiArray {
    // This function is non-throwing, so an allocation failure still traps here.
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .int32)
    arr.withUnsafeMutableBufferPointer(ofType: Int32.self) { dst, _ in
        // The buffer subscript is only bounds-checked in debug builds, so an
        // oversized `values` would silently write past the allocation in release.
        precondition(
            values.count <= dst.count,
            "makeInt32Array: \(values.count) values do not fit in shape \(shape)"
        )
        for (offset, value) in values.enumerated() {
            dst[offset] = value
        }
    }
    return arr
}
|
|
| |
/// Copies the contents of `array` into a Swift `[Float]`.
///
/// Reads `array.count` consecutive `Float` elements starting at the beginning
/// of the array's backing buffer.
///
/// - Parameter array: The array to read.
///   NOTE(review): presumably this must have the `.float32` data type —
///   confirm against callers.
/// - Returns: `array.count` values in buffer order.
func readFloat32Array(_ array: MLMultiArray) -> [Float] {
    let elementCount = array.count
    return array.withUnsafeBufferPointer(ofType: Float.self) { buffer in
        (0..<elementCount).map { buffer[$0] }
    }
}
|
|
/// Errors raised by the Core ML helper functions in this file.
enum CoreMLError {
    /// No file exists at the associated path.
    case modelNotFound(String)
}

// MARK: - LocalizedError

extension CoreMLError: LocalizedError {
    /// Human-readable message describing the failure.
    var errorDescription: String? {
        switch self {
        case let .modelNotFound(missingPath):
            return "Model not found: \(missingPath)"
        }
    }
}
|
|