Spaces:
Runtime error
Runtime error
| // Copyright 2019 The TensorFlow Authors. All Rights Reserved. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| import Accelerate | |
| import CoreImage | |
| import Foundation | |
| import TensorFlowLite | |
| import UIKit | |
/// Handles all data preprocessing and runs inference on a given frame by invoking the
/// TensorFlow Lite `Interpreter`. It then converts the output tensor to `Float` values.
class ModelDataHandler {

  // MARK: - Private Properties

  /// TensorFlow Lite `Interpreter` object for performing inference on a given model.
  private var interpreter: Interpreter

  /// Input `Tensor` of the model, fetched once after the tensors are allocated.
  private var inputTensor: Tensor

  //private var heatsTensor: Tensor
  //private var offsetsTensor: Tensor

  /// Output `Tensor` of the model. Refreshed after every successful `invoke()`.
  private var outputTensor: Tensor

  // MARK: - Initialization

  /// Creates a `ModelDataHandler` by loading the model named by `Model.file` from the app's
  /// main bundle and allocating the interpreter's tensors.
  ///
  /// - Parameters:
  ///   - threadCount: Number of threads handed to the interpreter options. Defaults to
  ///     `Constants.defaultThreadCount`.
  ///   - delegate: Hardware delegate to use. Defaults to `Constants.defaultDelegate`.
  /// - Throws: Any error raised while creating the `Interpreter`, allocating its tensors, or
  ///   fetching the input/output tensors.
  /// - Note: Stops with `fatalError` when the model file is not found in the main bundle.
  init(
    threadCount: Int = Constants.defaultThreadCount,
    delegate: Delegates = Constants.defaultDelegate
  ) throws {
    // Construct the path to the model file.
    guard
      let modelPath = Bundle.main.path(
        forResource: Model.file.name,
        ofType: Model.file.extension
      )
    else {
      fatalError("Failed to load the model file with name: \(Model.file.name).")
    }

    // Specify the options for the `Interpreter`.
    var options = Interpreter.Options()
    options.threadCount = threadCount

    // Specify the delegates for the `Interpreter`. The switch is exhaustive on purpose so that
    // adding a new `Delegates` case becomes a compile error here.
    let delegates: [Delegate]?
    switch delegate {
    case .Metal:
      delegates = [MetalDelegate()]
    case .CoreML:
      // `CoreMLDelegate()` is failable; run without a delegate when it cannot be created.
      if let coreMLDelegate = CoreMLDelegate() {
        delegates = [coreMLDelegate]
      } else {
        delegates = nil
      }
    case .CPU:
      delegates = nil
    }

    // Create the `Interpreter`.
    interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates)

    // Allocate memory for the model's input `Tensor`s.
    try interpreter.allocateTensors()

    // Get allocated input and output `Tensor`s.
    inputTensor = try interpreter.input(at: 0)
    outputTensor = try interpreter.output(at: 0)
    //heatsTensor = try interpreter.output(at: 0)
    //offsetsTensor = try interpreter.output(at: 1)

    /*
    // Check if input and output `Tensor`s are in the expected formats.
    guard (inputTensor.dataType == .uInt8) == Model.isQuantized else {
      fatalError("Unexpected Model: quantization is \(!Model.isQuantized)")
    }
    guard inputTensor.shape.dimensions[0] == Model.input.batchSize,
      inputTensor.shape.dimensions[1] == Model.input.height,
      inputTensor.shape.dimensions[2] == Model.input.width,
      inputTensor.shape.dimensions[3] == Model.input.channelSize
    else {
      fatalError("Unexpected Model: input shape")
    }
    guard heatsTensor.shape.dimensions[0] == Model.output.batchSize,
      heatsTensor.shape.dimensions[1] == Model.output.height,
      heatsTensor.shape.dimensions[2] == Model.output.width,
      heatsTensor.shape.dimensions[3] == Model.output.keypointSize
    else {
      fatalError("Unexpected Model: heat tensor")
    }
    guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize,
      offsetsTensor.shape.dimensions[1] == Model.output.height,
      offsetsTensor.shape.dimensions[2] == Model.output.width,
      offsetsTensor.shape.dimensions[3] == Model.output.offsetSize
    else {
      fatalError("Unexpected Model: offset tensor")
    }
    */
  }

  /// Runs the Midas model on the given area of the given image.
  ///
  /// - Parameters:
  ///   - on: Input image to run the model.
  ///   - from: Range of input image to run the model.
  ///   - to: Size of view to render the result. Currently unused by the active code path; kept
  ///     so that the signature stays stable for callers.
  /// - Returns: The output values as `Float`s, the width and height declared in `Model.output`,
  ///   and the time consumed by every step in milliseconds. `nil` if preprocessing, inference,
  ///   or the conversion of the output tensor fails.
  func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize)
    //-> (Result, Times)?
    //-> (FlatArray<Float32>, Times)?
    -> ([Float], Int, Int, Times)?
  {
    // Preprocessing: crop and resize the frame into the model's input bytes.
    let preprocessingStartTime = Date()
    guard let data = preprocess(of: pixelbuffer, from: source) else {
      os_log("Preprocessing failed", type: .error)
      return nil
    }
    let preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000

    // Inference: bail out on failure so that a stale `outputTensor` from an earlier frame is
    // never reported as the result for this frame. The error is logged by `inference(from:)`.
    let inferenceStartTime = Date()
    guard inference(from: data) else {
      return nil
    }
    let inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000

    // Postprocessing: the tensor-to-`Float` conversion is the postprocessing step, so it has to
    // happen inside the timed region.
    let postprocessingStartTime = Date()
    //guard let result = postprocess(to: dest) else {
    //  os_log("Postprocessing failed", type: .error)
    //  return nil
    //}
    guard let results = floatValues(of: outputTensor) else {
      return nil
    }
    let postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000

    let times = Times(
      preprocessing: preprocessingTime,
      inference: inferenceTime,
      postprocessing: postprocessingTime)

    // `results` holds the output tensor, so report the output dimensions (the same values as
    // the input dimensions for this model).
    return (results, Model.output.width, Model.output.height, times)
  }

  // MARK: - Private functions to run model

  /// Preprocesses the given rectangle of the image to be `Data` of the desired size by cropping
  /// and resizing it.
  ///
  /// - Parameters:
  ///   - of: Input image to crop and resize.
  ///   - from: Target area to be cropped and resized.
  /// - Returns: The cropped and resized image. `nil` if it can not be processed.
  private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? {
    let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
    assert(sourcePixelFormat == kCVPixelFormatType_32BGRA)

    // Resize `targetSquare` of input image to `modelSize`.
    let modelSize = CGSize(width: Model.input.width, height: Model.input.height)
    guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) else {
      return nil
    }

    // Remove the alpha component from the image buffer to get the initialized `Data`.
    guard let inputData = thumbnail.rgbData(isModelQuantized: Model.isQuantized) else {
      os_log("Failed to convert the image buffer to RGB data.", type: .error)
      return nil
    }

    return inputData
  }

  /// Converts the contents of the given `Tensor` to an array of `Float`s.
  ///
  /// `uInt8` tensors are dequantized with the tensor's quantization parameters; `float32`
  /// tensors are reinterpreted as they are.
  ///
  /// - Parameter tensor: The tensor to convert.
  /// - Returns: The converted values. `nil` if the data type is unsupported, the quantization
  ///   parameters are missing, or the byte count does not match whole `Float32` values.
  private func floatValues(of tensor: Tensor) -> [Float]? {
    switch tensor.dataType {
    case .uInt8:
      guard let quantization = tensor.quantizationParameters else {
        os_log(
          "No results returned because the quantization values for the output tensor are nil.",
          type: .error)
        return nil
      }
      return [UInt8](tensor.data).map {
        quantization.scale * Float(Int($0) - quantization.zeroPoint)
      }
    case .float32:
      guard let values = [Float32](unsafeData: tensor.data) else {
        os_log(
          "No results returned because the output tensor data is not a whole number of floats.",
          type: .error)
        return nil
      }
      return values
    default:
      os_log(
        "Output tensor data type %@ is unsupported for this example app.", type: .error,
        String(describing: tensor.dataType))
      return nil
    }
  }

  /*
  /// Postprocesses output `Tensor`s to `Result` with size of view to render the result.
  ///
  /// - Parameters:
  ///   - to: Size of view to be displaied.
  /// - Returns: Postprocessed `Result`. `nil` if it can not be processed.
  private func postprocess(to viewSize: CGSize) -> Result? {
    // MARK: Formats output tensors
    // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type
    // `FlatArray`.
    let heats = FlatArray<Float32>(tensor: heatsTensor)
    let offsets = FlatArray<Float32>(tensor: offsetsTensor)

    // MARK: Find position of each key point
    // Finds the (row, col) locations of where the keypoints are most likely to be. The highest
    // `heats[0, row, col, keypoint]` value, the more likely `keypoint` being located in (`row`,
    // `col`).
    let keypointPositions = (0..<Model.output.keypointSize).map { keypoint -> (Int, Int) in
      var maxValue = heats[0, 0, 0, keypoint]
      var maxRow = 0
      var maxCol = 0
      for row in 0..<Model.output.height {
        for col in 0..<Model.output.width {
          if heats[0, row, col, keypoint] > maxValue {
            maxValue = heats[0, row, col, keypoint]
            maxRow = row
            maxCol = col
          }
        }
      }
      return (maxRow, maxCol)
    }

    // MARK: Calculates total confidence score
    // Calculates total confidence score of each key position.
    let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in
      accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset])
    }
    let totalScore = totalScoreSum / Float32(Model.output.keypointSize)

    // MARK: Calculate key point position on model input
    // Calculates `KeyPoint` coordination model input image with `offsets` adjustment.
    let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in
      let (y, x) = elem
      let yCoord =
        Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height)
        + offsets[0, y, x, index]
      let xCoord =
        Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width)
        + offsets[0, y, x, index + Model.output.keypointSize]
      return (y: yCoord, x: xCoord)
    }

    // MARK: Transform key point position and make lines
    // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn.
    var result = Result(dots: [], lines: [], score: totalScore)
    var bodyPartToDotMap = [BodyPart: CGPoint]()
    for (index, part) in BodyPart.allCases.enumerated() {
      let position = CGPoint(
        x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width),
        y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height)
      )
      bodyPartToDotMap[part] = position
      result.dots.append(position)
    }

    do {
      try result.lines = BodyPart.lines.map { map throws -> Line in
        guard let from = bodyPartToDotMap[map.from] else {
          throw PostprocessError.missingBodyPart(of: map.from)
        }
        guard let to = bodyPartToDotMap[map.to] else {
          throw PostprocessError.missingBodyPart(of: map.to)
        }
        return Line(from: from, to: to)
      }
    } catch PostprocessError.missingBodyPart(let missingPart) {
      os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue)
      return nil
    } catch {
      os_log("Postprocessing error: %s", type: .error, error.localizedDescription)
      return nil
    }

    return result
  }
  */

  /// Runs inference with the given `Data` and refreshes `outputTensor` on success.
  ///
  /// - Parameter data: `Data` of the input image to run the model on.
  /// - Returns: `true` if the input was copied, the interpreter was invoked, and the output
  ///   tensor was fetched; `false` if any of those steps threw. The error is logged here.
  @discardableResult
  private func inference(from data: Data) -> Bool {
    do {
      // Copy the initialized `Data` to the input `Tensor`.
      try interpreter.copy(data, toInputAt: 0)

      // Run inference by invoking the `Interpreter`.
      try interpreter.invoke()

      // Get the output `Tensor` to process the inference results.
      outputTensor = try interpreter.output(at: 0)
      //heatsTensor = try interpreter.output(at: 0)
      //offsetsTensor = try interpreter.output(at: 1)
      return true
    } catch {
      // `%@` is the specifier for a Swift `String` argument; `%s` expects a C string.
      os_log(
        "Failed to invoke the interpreter with error: %@", type: .error,
        error.localizedDescription)
      return false
    }
  }

  /// Returns value within [0,1].
  ///
  /// Only referenced by the commented-out `postprocess(to:)` in this class.
  private func sigmoid(_ x: Float32) -> Float32 {
    return (1.0 / (1.0 + exp(-x)))
  }
}
// MARK: - Data types for inference result

/// A single detected key point: which body part it is, where it is, and how confident the
/// detection is. Every field has a default, so `KeyPoint()` is a valid placeholder value.
struct KeyPoint {
  /// Body part this key point belongs to. Defaults to the nose.
  var bodyPart: BodyPart = .NOSE

  /// Location of the key point. Defaults to the origin.
  var position: CGPoint = .zero

  /// Score attached to the key point. Defaults to zero.
  var score: Float = 0
}
/// A straight segment between two points.
struct Line {
  /// Start point (`from`) and end point (`to`) of the segment.
  let from, to: CGPoint
}
/// Time consumed by each stage of one model run. `ModelDataHandler.runMidas` fills these in
/// milliseconds.
struct Times {
  /// Durations of the preprocessing, inference, and postprocessing stages, in that order.
  var preprocessing, inference, postprocessing: Double
}
/// Drawable outcome of a pose estimation: the points, the segments connecting them, and an
/// overall score.
struct Result {
  /// Points to draw.
  var dots: [CGPoint]

  /// Segments to draw.
  var lines: [Line]

  /// Overall score of the result.
  var score: Float
}
/// The 17 body parts of the pose skeleton. The raw value is a lowercase, human-readable name.
enum BodyPart: String, CaseIterable {
  case NOSE = "nose"
  case LEFT_EYE = "left eye"
  case RIGHT_EYE = "right eye"
  case LEFT_EAR = "left ear"
  case RIGHT_EAR = "right ear"
  case LEFT_SHOULDER = "left shoulder"
  case RIGHT_SHOULDER = "right shoulder"
  case LEFT_ELBOW = "left elbow"
  case RIGHT_ELBOW = "right elbow"
  case LEFT_WRIST = "left wrist"
  case RIGHT_WRIST = "right wrist"
  case LEFT_HIP = "left hip"
  case RIGHT_HIP = "right hip"
  case LEFT_KNEE = "left knee"
  case RIGHT_KNEE = "right knee"
  case LEFT_ANKLE = "left ankle"
  case RIGHT_ANKLE = "right ankle"

  /// List of lines connecting each part, as `(from, to)` pairs of body parts.
  static let lines: [(from: BodyPart, to: BodyPart)] = [
    // Arms and shoulders.
    (from: .LEFT_WRIST, to: .LEFT_ELBOW),
    (from: .LEFT_ELBOW, to: .LEFT_SHOULDER),
    (from: .LEFT_SHOULDER, to: .RIGHT_SHOULDER),
    (from: .RIGHT_SHOULDER, to: .RIGHT_ELBOW),
    (from: .RIGHT_ELBOW, to: .RIGHT_WRIST),
    // Torso.
    (from: .LEFT_SHOULDER, to: .LEFT_HIP),
    (from: .LEFT_HIP, to: .RIGHT_HIP),
    (from: .RIGHT_HIP, to: .RIGHT_SHOULDER),
    // Legs.
    (from: .LEFT_HIP, to: .LEFT_KNEE),
    (from: .LEFT_KNEE, to: .LEFT_ANKLE),
    (from: .RIGHT_HIP, to: .RIGHT_KNEE),
    (from: .RIGHT_KNEE, to: .RIGHT_ANKLE),
  ]
}
// MARK: - Delegates Enum

/// Hardware back ends the interpreter can be configured with. Raw values follow declaration
/// order (`CPU` = 0, `Metal` = 1, `CoreML` = 2).
enum Delegates: Int, CaseIterable {
  case CPU
  case Metal
  case CoreML

  /// Short display name of the processing unit behind the delegate.
  var description: String {
    let label: String
    switch self {
    case .CPU: label = "CPU"
    case .Metal: label = "GPU"
    case .CoreML: label = "NPU"
    }
    return label
  }
}
// MARK: - Custom Errors

/// Errors that can be raised while postprocessing model output.
enum PostprocessError: Error {
  /// No position is available for the given body part.
  case missingBodyPart(of: BodyPart)
}
// MARK: - Information about the model file.

/// Name and extension of a bundled file.
typealias FileInfo = (name: String, extension: String)

/// Static description of the bundled model: its file and its declared tensor dimensions.
enum Model {
  /// The model file looked up in the app bundle.
  static let file: FileInfo = (name: "model_opt", extension: "tflite")

  /// Declared dimensions of the input tensor.
  static let input: (batchSize: Int, height: Int, width: Int, channelSize: Int) =
    (batchSize: 1, height: 256, width: 256, channelSize: 3)

  /// Declared dimensions of the output tensor.
  static let output: (batchSize: Int, height: Int, width: Int, channelSize: Int) =
    (batchSize: 1, height: 256, width: 256, channelSize: 1)

  /// Whether the model is treated as quantized.
  static let isQuantized: Bool = false
}
extension Array {
  /// Builds an array by reinterpreting the raw bytes of `unsafeData` as consecutive `Element`s.
  ///
  /// - Warning: `Element` must be a trivial type, i.e. one that can be copied bit for bit with
  ///   no indirection or reference-counting operations. For any other type the copied bytes do
  ///   not form valid values.
  /// - Note: Fails with `nil` when `unsafeData.count` is not a multiple of
  ///   `MemoryLayout<Element>.stride`.
  /// - Parameter unsafeData: The data containing the bytes to turn into an array.
  init?(unsafeData: Data) {
    let elementStride = MemoryLayout<Element>.stride
    if unsafeData.count % elementStride != 0 {
      return nil
    }
    #if swift(>=5.0)
      self = unsafeData.withUnsafeBytes { (rawBuffer: UnsafeRawBufferPointer) -> [Element] in
        [Element](rawBuffer.bindMemory(to: Element.self))
      }
    #else
      self = unsafeData.withUnsafeBytes { (pointer: UnsafePointer<Element>) -> [Element] in
        [Element](
          UnsafeBufferPointer<Element>(
            start: pointer,
            count: unsafeData.count / elementStride
          ))
      }
    #endif  // swift(>=5.0)
  }
}