// Commit 2806f00 — Kyle Pearson: pre-resize images to exact model dimensions,
// implement feathered blending to eliminate seam artifacts, cache model
// constraints to unify coordinate space.
//
//  DepthPredictor.swift
//  Equirectangular Depth Map Inference via DAP CoreML Model
//
//  Loads a DAP CoreML model, runs depth inference on an equirectangular
//  panorama image, and saves the depth map as a PNG file.
//
//  Usage:
//    swiftc -O -o depth_predictor DepthPredictor.swift \
//      -framework CoreML -framework Vision -framework CoreImage \
//      -framework CoreGraphics -framework AppKit
//    ./depth_predictor -m DAPModel.mlpackage -i panorama.jpg -o depth.png -c jet
| import Foundation | |
| import CoreML | |
| import Vision | |
| import CoreImage | |
| import CoreGraphics | |
| import AppKit | |
// MARK: - Colormap LUTs (computed once, cached)

/// Packed RGB colormap entry — stored contiguously for cache-friendly LUT access.
struct RGB {
    let r: UInt8
    let g: UInt8
    let b: UInt8
}

/// Precomputed jet colormap lookup table (256 entries, built once).
///
/// Piecewise-linear approximation over three equal thirds of [0, 1]:
/// blue ramp, then green-up/blue-down, then red-up/green-down.
let jetLUT: [RGB] = {
    // Clamp to [0, 1], scale to [0, 255], round to nearest byte.
    func byte(_ v: Float) -> UInt8 {
        UInt8(round(max(0, min(1, v)) * 255))
    }
    return (0...255).map { step -> RGB in
        let t = Float(step) / 255.0
        let red: Float
        let green: Float
        let blue: Float
        switch t {
        case ..<(1.0 / 3.0):
            red = 0
            green = 0
            blue = 0.5 + 0.5 * (t * 3.0)
        case ..<(2.0 / 3.0):
            let u = (t - 1.0 / 3.0) * 3.0
            red = 0
            green = 0.5 + 0.5 * u
            blue = 1.0 - u * 0.5
        default:
            let u = (t - 2.0 / 3.0) * 3.0
            red = 0.5 + 0.5 * u
            green = 1.0 - u * 0.5
            blue = 0
        }
        return RGB(r: byte(red), g: byte(green), b: byte(blue))
    }
}()

/// Turbo colormap (Google's perceptually-uniform alternative to jet, built once).
///
/// Each channel is a clamped quintic polynomial in t.
/// NOTE(review): these coefficients do not match the commonly published turbo
/// polynomial fit (e.g. they yield blue at t = 1) — confirm their provenance
/// before relying on exact colors.
let turboLUT: [RGB] = {
    typealias Quintic = (Float, Float, Float, Float, Float, Float)
    // Evaluate the quintic at t and clamp the result to [0, 1].
    func channel(_ t: Float, _ c: Quintic) -> Float {
        let t2 = t * t, t3 = t2 * t, t4 = t3 * t, t5 = t4 * t
        return max(0, min(1, c.0 * t5 + c.1 * t4 + c.2 * t3 + c.3 * t2 + c.4 * t + c.5))
    }
    let rCoeffs: Quintic = (-6.3733615, 15.04266179, -13.85162213,
                            5.08578778, -0.83861766, 0.16457028)
    let gCoeffs: Quintic = ( 2.25531523, -11.37426878, 21.82122831,
                            -18.71443039, 6.26060447, -0.68049933)
    let bCoeffs: Quintic = (-4.13513668, 6.56872416, 4.79961124,
                            -4.01387798, 1.33503302, 0.0088154)
    return (0...255).map { step -> RGB in
        let t = Float(step) / 255.0
        return RGB(
            r: UInt8(round(channel(t, rCoeffs) * 255)),
            g: UInt8(round(channel(t, gCoeffs) * 255)),
            b: UInt8(round(channel(t, bCoeffs) * 255))
        )
    }
}()
// MARK: - MLMultiArray Helpers

/// Provides direct, strided read access to an MLMultiArray's Float32 data
/// without copying. The caller must keep the source MLMultiArray alive for
/// the lifetime of this wrapper.
struct DepthArrayView {
    let ptr: UnsafeMutablePointer<Float32>
    let width: Int
    let height: Int
    /// Stride between consecutive rows, in Float32 elements. May exceed
    /// `width` when the array's backing storage is padded.
    let rowStride: Int

    /// Wrap a depth multi-array laid out as [1, 1, H, W].
    /// NOTE(review): shape rank and Float32 payload are assumed, not checked —
    /// indices 2/3 must exist; confirm against the model's output description.
    init(_ multiArray: MLMultiArray) {
        width = multiArray.shape[3].intValue
        height = multiArray.shape[2].intValue
        rowStride = multiArray.strides[2].intValue
        ptr = multiArray.dataPointer.bindMemory(to: Float32.self, capacity: height * rowStride)
    }

    /// Read a single value at (row, col).
    @inline(__always)
    func value(row: Int, col: Int) -> Float32 {
        ptr[row * rowStride + col]
    }

    /// Compute min/max across all strictly positive values; non-positive
    /// entries are treated as invalid depth and skipped.
    ///
    /// If no positive value exists, returns the sentinel pair
    /// (.greatestFiniteMagnitude, -.greatestFiniteMagnitude); callers guard
    /// against this by checking `max - min > 0` before normalizing.
    func minMax() -> (min: Float32, max: Float32) {
        var lo: Float32 = .greatestFiniteMagnitude
        var hi: Float32 = -.greatestFiniteMagnitude
        for row in 0..<height {
            let base = row * rowStride
            for col in 0..<width {
                let v = ptr[base + col]
                if v > 0 {
                    if v < lo { lo = v }
                    if v > hi { hi = v }
                }
            }
        }
        return (lo, hi)
    }
}
// MARK: - Depth Result

/// Holds the raw depth multi-array alongside the CIImage for rendering.
struct DepthResult {
    let ciImage: CIImage
    let multiArray: MLMultiArray // [1, 1, H, W] Float32

    /// Zero-copy view into the underlying depth data.
    var view: DepthArrayView { DepthArrayView(multiArray) }

    /// Depth-map width in pixels (last shape dimension).
    var width: Int { multiArray.shape[3].intValue }

    /// Depth-map height in pixels (third shape dimension).
    var height: Int { multiArray.shape[2].intValue }
}
// MARK: - Depth Predictor

final class DepthPredictor {
    private var visionModel: VNCoreMLModel?
    private var _outputHeight = 512
    private var _outputWidth = 1024
    private var _modelInputWidth: Int = 0
    private var _modelInputHeight: Int = 0

    /// Dimensions of the most recently produced depth output
    /// (updated by `multiArrayToCIImage`; defaults shown until then).
    var outputHeight: Int { _outputHeight }
    var outputWidth: Int { _outputWidth }

    /// Model's expected input dimensions, read from the CoreML model's image
    /// constraints at load time. Used to manually resize source images so that
    /// Vision's `.scaleFit` becomes a no-op (no letterboxing, no implicit
    /// bilinear downscale). Zero if the model isn't loaded.
    var modelInputWidth: Int { _modelInputWidth }
    var modelInputHeight: Int { _modelInputHeight }

    var isLoaded: Bool { visionModel != nil }

    /// Load model dynamically from a .mlpackage or .mlmodelc URL.
    init(modelURL: URL, computeUnits: MLComputeUnits = .all) {
        setupModel(modelURL: modelURL, computeUnits: computeUnits)
    }

    // MARK: Inference

    /// Predict depth from a CGImage. Completion receives a ``DepthResult`` with
    /// both a renderable CIImage and the raw Float32 depth multi-array.
    ///
    /// - Parameter fixSeam: When true, runs dual-inference seam fix: infers depth
    ///   on both the original and a half-shifted copy, then patches the seam region
    ///   from the shifted result into the original to eliminate edge artifacts.
    /// - Parameter debugDir: When provided, intermediate depth maps are saved here
    ///   for debugging (depth_original.png, depth_shifted.png, depth_stitched.png).
    func predictDepth(
        from cgImage: CGImage,
        fixSeam: Bool = true,
        debugDir: URL? = nil,
        completion: @escaping (DepthResult?) -> Void
    ) {
        if fixSeam {
            fixSeamWithDualInference(on: cgImage, debugDir: debugDir, completion: completion)
        } else {
            runSingleInference(on: cgImage, completion: completion)
        }
    }

    /// Resize the source to the model's expected input dimensions when they are
    /// known; otherwise return the image unchanged and let Vision scale it.
    /// Shared by both the single-pass and dual-pass (seam fix) code paths so
    /// they operate in the same pixel grid.
    private func prepareForModel(_ cgImage: CGImage) -> CGImage {
        guard _modelInputWidth > 0, _modelInputHeight > 0,
              let resized = DepthPredictor.resizeImage(cgImage,
                                                       toWidth: _modelInputWidth,
                                                       height: _modelInputHeight)
        else {
            // Fallback: model dims unknown — let Vision handle scaling.
            return cgImage
        }
        return resized
    }

    /// Run a single pass of depth inference on a CGImage.
    ///
    /// The image is resized to the model's expected input dimensions using
    /// high-quality interpolation *before* being handed to Vision. This makes
    /// `imageCropAndScaleOption = .scaleFit` effectively a no-op and avoids
    /// two failure modes of letting Vision do the resize:
    ///   1. Letterboxing on inputs whose aspect ratio doesn't exactly match
    ///      the model (Vision pads with black, polluting depth predictions).
    ///   2. Implicit bilinear downscale, which loses high-frequency detail
    ///      compared to PIL's Lanczos resize used in the Python export script.
    private func runSingleInference(on cgImage: CGImage, completion: @escaping (DepthResult?) -> Void) {
        guard let visionModel else {
            print("[DepthPredictor] Model not loaded")
            completion(nil)
            return
        }
        // Pre-resize to exact model input dims (matches Python's PIL resize).
        let prepared = prepareForModel(cgImage)
        let request = VNCoreMLRequest(model: visionModel) { [weak self] request, error in
            if let error {
                print("[DepthPredictor] Inference error: \(error)")
                completion(nil)
                return
            }
            guard let observations = request.results as? [VNCoreMLFeatureValueObservation],
                  let observation = observations.first,
                  let multiArray = observation.featureValue.multiArrayValue
            else {
                print("[DepthPredictor] No depth output in results")
                completion(nil)
                return
            }
            guard let ciImage = self?.multiArrayToCIImage(multiArray) else {
                completion(nil)
                return
            }
            completion(DepthResult(ciImage: ciImage, multiArray: multiArray))
        }
        request.imageCropAndScaleOption = .scaleFit
        let handler = VNImageRequestHandler(cgImage: prepared, options: [:])
        do {
            try handler.perform([request])
        } catch {
            print("[DepthPredictor] Vision request failed: \(error)")
            completion(nil)
        }
    }

    /// Fix the left/right seam by running depth inference on both the original
    /// and a half-shifted copy, then stitching the shifted seam region into the
    /// original depth map.
    ///
    /// Strategy (mirrors the Python approach):
    ///   1. Run depth inference on the original equirectangular image.
    ///   2. Roll the image left by half its width so the seam moves to the center.
    ///   3. Run depth inference on the shifted image — the center of this result
    ///      covers what was the original seam, artifact-free.
    ///   4. Roll the original depth left by half (matching the shifted coordinate
    ///      space), paste a strip from the shifted depth over the center, then
    ///      roll the result back to the original orientation.
    ///
    /// - Parameter patchHalfWidth: Half-width of the strip (in depth-map pixels)
    ///   to paste from the shifted depth. The total patch width is 2× this value.
    ///   Defaults to 25 px, which works well for 1024-wide depth outputs. Scale
    ///   proportionally for other resolutions.
    private func fixSeamWithDualInference(
        on cgImage: CGImage,
        debugDir: URL?,
        patchHalfWidth: Int = 25,
        completion: @escaping (DepthResult?) -> Void
    ) {
        // Resize source to model input dims *once*, so both inference passes
        // and the horizontal shift all happen in the same coordinate space.
        // This avoids resampling twice and keeps the shift offset exact in
        // the same pixel grid as the depth output.
        let prepared = prepareForModel(cgImage)
        let imageWidth = prepared.width
        let half = imageWidth / 2
        // Shift the source image left by half — the seam moves to the center.
        guard let shiftedImage = DepthPredictor.shiftImageHorizontally(prepared, by: half) else {
            print("[DepthPredictor] Failed to shift image for seam fix")
            completion(nil)
            return
        }
        // Debug: save shifted input.
        if let debugDir {
            try? DepthPredictor.saveImage(
                CIImage(cgImage: shiftedImage),
                to: debugDir.appendingPathComponent("input_shifted.png")
            )
        }
        // 1. Infer depth on the (resized) original image.
        runSingleInference(on: prepared) { [weak self] originalDepth in
            guard let self, let originalDepth else {
                completion(nil)
                return
            }
            if let debugDir {
                try? DepthPredictor.saveDepthAsGrayscale(
                    originalDepth,
                    to: debugDir.appendingPathComponent("depth_original.png")
                )
            }
            // 2. Infer depth on the shifted image.
            self.runSingleInference(on: shiftedImage) { shiftedDepth in
                guard let shiftedDepth else {
                    completion(nil)
                    return
                }
                let w = originalDepth.width
                let h = originalDepth.height
                if let debugDir {
                    try? DepthPredictor.saveDepthAsGrayscale(
                        shiftedDepth,
                        to: debugDir.appendingPathComponent("depth_shifted.png")
                    )
                }
                // 3. Stitch: roll original depth, patch center, roll back.
                guard let stitched = self.stitchSeamFromShiftedDepth(
                    original: originalDepth.multiArray,
                    shifted: shiftedDepth.multiArray,
                    width: w,
                    height: h,
                    depthHalf: w / 2,
                    patchHalfWidth: patchHalfWidth
                ) else {
                    completion(nil)
                    return
                }
                let ciImage = self.multiArrayToCIImage(stitched) ?? originalDepth.ciImage
                if let debugDir {
                    let stitchedResult = DepthResult(ciImage: ciImage, multiArray: stitched)
                    try? DepthPredictor.saveDepthAsGrayscale(
                        stitchedResult,
                        to: debugDir.appendingPathComponent("depth_stitched.png")
                    )
                }
                completion(DepthResult(ciImage: ciImage, multiArray: stitched))
            }
        }
    }

    /// Stitch the seam region using a single output buffer with **feathered**
    /// blending at the patch boundaries — no intermediate copies.
    ///
    /// The two inference passes (original and half-shifted) produce slightly
    /// different absolute depth values even where they agree on geometry,
    /// because they're independent forward passes through a non-linear model.
    /// A hard cutover at the patch boundary therefore leaves a visible step.
    /// To avoid this, we linearly blend from original→shifted as the column
    /// enters the patch zone and from shifted→original as it leaves, using a
    /// transition band of `featherWidth` pixels on each side.
    ///
    /// Layout in *shifted* coordinate space (centered at width/2):
    ///
    ///   [ original ][ feather ][ shifted ][ feather ][ original ]
    ///              ^          ^          ^          ^
    ///          patchLeft   coreLeft   coreRight  patchRight
    ///
    /// - Outside `[patchLeft, patchRight)`: pure original.
    /// - Inside `[coreLeft, coreRight)`: pure shifted.
    /// - In the two feather bands: linear blend, weight 0→1 across the band.
    ///
    /// `featherWidth` is clamped so the feather bands never overlap the core.
    private func stitchSeamFromShiftedDepth(
        original: MLMultiArray,
        shifted: MLMultiArray,
        width: Int,
        height: Int,
        depthHalf: Int,
        patchHalfWidth: Int,
        featherWidth: Int = 12
    ) -> MLMultiArray? {
        let origView = DepthArrayView(original)
        let shiftView = DepthArrayView(shifted)
        // Patch zone in the *shifted* coordinate space is centered at width/2.
        let centerX = width / 2
        let dx = min(patchHalfWidth, centerX)
        let patchLeft = centerX - dx
        let patchRight = centerX + dx // exclusive
        // Clamp feather so the two bands don't overlap (each band must fit
        // within half the patch width, leaving at least one pure-shifted col).
        let maxFeather = max(0, dx - 1)
        let feather = min(max(0, featherWidth), maxFeather)
        let coreLeft = patchLeft + feather
        let coreRight = patchRight - feather // exclusive
        // Create output MLMultiArray matching the input's shape and dtype.
        let output: MLMultiArray
        do {
            output = try MLMultiArray(shape: original.shape, dataType: original.dataType)
        } catch {
            print("[DepthPredictor] Failed to create MLMultiArray for stitch: \(error)")
            return nil
        }
        let outStride = output.strides[2].intValue
        // Bind with the full strided extent: rows may be padded, so the
        // backing store can be larger than width * height elements.
        let outPtr = output.dataPointer.bindMemory(to: Float32.self, capacity: height * outStride)
        // Precompute reciprocal once (avoid div-by-zero when feather == 0).
        let invFeather: Float32 = feather > 0 ? 1.0 / Float32(feather) : 0.0
        for row in 0..<height {
            let outBase = row * outStride
            for col in 0..<width {
                // Map this output col into the shifted coordinate space:
                // shifting left by depthHalf means shiftedCol = (col + depthHalf) % width
                let shiftedCol = (col + depthHalf) % width
                if shiftedCol < patchLeft || shiftedCol >= patchRight {
                    // Outside patch zone — pure original (identity mapping).
                    outPtr[outBase + col] = origView.value(row: row, col: col)
                } else if shiftedCol >= coreLeft && shiftedCol < coreRight {
                    // Core patch zone — pure shifted.
                    outPtr[outBase + col] = shiftView.value(row: row, col: shiftedCol)
                } else {
                    // Feather band — linear blend.
                    // Weight w: 0 at the outer patch edge, 1 at the core edge.
                    let w: Float32
                    if shiftedCol < coreLeft {
                        // Left feather: ramp up as we move right toward coreLeft.
                        w = Float32(shiftedCol - patchLeft) * invFeather
                    } else {
                        // Right feather: ramp down as we move right toward patchRight.
                        w = Float32(patchRight - 1 - shiftedCol) * invFeather
                    }
                    let wClamped = max(0.0 as Float32, min(1.0 as Float32, w))
                    let origVal = origView.value(row: row, col: col)
                    let shiftVal = shiftView.value(row: row, col: shiftedCol)
                    outPtr[outBase + col] = origVal + (shiftVal - origVal) * wClamped
                }
            }
        }
        return output
    }

    // MARK: Colormap

    /// Apply a jet colormap to depth values -> 8-bit RGB CIImage.
    func applyJetColormap(to depth: DepthResult) -> CIImage? {
        applyColormap(to: depth, lut: jetLUT)
    }

    /// Apply a turbo colormap to depth values -> 8-bit RGB CIImage.
    func applyTurboColormap(to depth: DepthResult) -> CIImage? {
        applyColormap(to: depth, lut: turboLUT)
    }

    /// Apply a grayscale visualization with optional contrast.
    func applyGrayscale(to ciImage: CIImage, contrast: CGFloat = 1.0) -> CIImage {
        guard let filter = CIFilter(name: "CIColorControls") else { return ciImage }
        filter.setDefaults()
        filter.setValue(ciImage, forKey: kCIInputImageKey)
        filter.setValue(contrast, forKey: kCIInputContrastKey)
        filter.setValue(0.0, forKey: kCIInputBrightnessKey)
        filter.setValue(1.0, forKey: kCIInputSaturationKey)
        return filter.outputImage ?? ciImage
    }

    /// Apply a colormap LUT to depth values, reading directly from the
    /// MLMultiArray without copying into an intermediate Swift array.
    /// Depth is normalized to [0, 1] over the positive min/max range, then
    /// mapped through the 256-entry LUT into an RGBA8 bitmap.
    private func applyColormap(to depth: DepthResult, lut: [RGB]) -> CIImage? {
        let dv = depth.view
        let (minDepth, maxDepth) = dv.minMax()
        let range = maxDepth - minDepth
        let invRange: Float32 = range > 0 ? 1.0 / range : 1.0
        let outputBufferSize = dv.width * dv.height * 4
        guard let outputBuffer = malloc(outputBufferSize) else { return nil }
        defer { free(outputBuffer) }
        let outPtr = outputBuffer.bindMemory(to: UInt8.self, capacity: outputBufferSize)
        for row in 0..<dv.height {
            let rowBase = row * dv.rowStride
            let outRowBase = row * dv.width * 4
            for col in 0..<dv.width {
                let normalized = max(0, min(1, (dv.ptr[rowBase + col] - minDepth) * invRange))
                let index = min(Int(normalized * 255), 255)
                let color = lut[index]
                let px = outRowBase + col * 4
                outPtr[px] = color.r
                outPtr[px + 1] = color.g
                outPtr[px + 2] = color.b
                outPtr[px + 3] = 255
            }
        }
        let colorSpace = CGColorSpaceCreateDeviceRGB()
        guard let bitmapContext = CGContext(
            data: outPtr,
            width: dv.width,
            height: dv.height,
            bitsPerComponent: 8,
            bytesPerRow: dv.width * 4,
            space: colorSpace,
            bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
        ) else { return nil }
        // makeImage() copies the backing data, so freeing outputBuffer after
        // this point (via defer) is safe.
        guard let cgImage = bitmapContext.makeImage() else { return nil }
        return CIImage(cgImage: cgImage)
    }

    // MARK: Save

    /// Save depth values as a 16-bit grayscale PNG (normalized to [0, 65535]).
    /// Reads directly from the MLMultiArray — no intermediate Float32 copy.
    static func saveDepthAsGrayscale(_ depth: DepthResult, to path: URL) throws {
        let dv = depth.view
        let (minDepth, maxDepth) = dv.minMax()
        let range = maxDepth - minDepth
        let invRange: Float32 = range > 0 ? 1.0 / range : 1.0
        // Create 16-bit grayscale buffer (big-endian, as PNG expects).
        let bufferSize = dv.width * dv.height * 2
        guard let buffer = malloc(bufferSize) else {
            throw NSError(domain: "DepthPredictor", code: 7,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to allocate buffer"])
        }
        defer { free(buffer) }
        let outPtr = buffer.bindMemory(to: UInt8.self, capacity: bufferSize)
        for row in 0..<dv.height {
            let rowBase = row * dv.rowStride
            let outRowBase = row * dv.width * 2
            for col in 0..<dv.width {
                let normalized = (dv.ptr[rowBase + col] - minDepth) * invRange
                let value = UInt16(max(0, min(65535, normalized * 65535)))
                let px = outRowBase + col * 2
                outPtr[px] = UInt8(value >> 8)
                outPtr[px + 1] = UInt8(value & 0xFF)
            }
        }
        let colorSpace = CGColorSpaceCreateDeviceGray()
        guard let bitmapContext = CGContext(
            data: outPtr,
            width: dv.width,
            height: dv.height,
            bitsPerComponent: 16,
            bytesPerRow: dv.width * 2,
            space: colorSpace,
            bitmapInfo: CGImageAlphaInfo.none.rawValue | CGBitmapInfo.byteOrder16Big.rawValue
        ) else {
            throw NSError(domain: "DepthPredictor", code: 8,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to create 16-bit grayscale context"])
        }
        guard let cgImage = bitmapContext.makeImage() else {
            throw NSError(domain: "DepthPredictor", code: 9,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to create CGImage"])
        }
        try writePNG(cgImage, to: path)
    }

    /// Save any CGImage as a PNG file.
    static func writePNG(_ cgImage: CGImage, to path: URL) throws {
        let bitmapRep = NSBitmapImageRep(cgImage: cgImage)
        guard let pngData = bitmapRep.representation(
            using: .png,
            properties: [NSBitmapImageRep.PropertyKey.compressionFactor: 1.0]
        ) else {
            throw NSError(domain: "DepthPredictor", code: 5,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to encode PNG"])
        }
        try pngData.write(to: path)
    }

    // MARK: Private

    /// Compile (if needed) and load the CoreML model, caching the model's
    /// expected input dimensions. On failure, `visionModel` stays nil and
    /// `isLoaded` reports false.
    private func setupModel(modelURL: URL, computeUnits: MLComputeUnits) {
        do {
            let config = MLModelConfiguration()
            config.computeUnits = computeUnits
            let compiledURL = try compileModelIfNeeded(at: modelURL)
            let model = try MLModel(contentsOf: compiledURL, configuration: config)
            // Capture the model's expected input dimensions so we can resize
            // source images ourselves (avoiding Vision's letterboxing + implicit
            // bilinear downscale). DAP exports use a single ImageType input.
            if let imageInput = model.modelDescription.inputDescriptionsByName.values
                .first(where: { $0.imageConstraint != nil }),
               let constraint = imageInput.imageConstraint {
                _modelInputWidth = constraint.pixelsWide
                _modelInputHeight = constraint.pixelsHigh
                print("[DepthPredictor] Model input: \(_modelInputWidth)x\(_modelInputHeight)")
            } else {
                print("[DepthPredictor] Warning: could not read model input image constraint; manual resize disabled")
            }
            visionModel = try VNCoreMLModel(for: model)
            print("[DepthPredictor] Model loaded from \(modelURL.path)")
        } catch {
            print("[DepthPredictor] Failed to load model: \(error)")
            visionModel = nil
        }
    }

    /// Return a compiled .mlmodelc URL for the given model, compiling
    /// .mlpackage/.mlmodel sources into a temp-directory cache. The cache is
    /// invalidated when the source's modification date is newer.
    private func compileModelIfNeeded(at url: URL) throws -> URL {
        let ext = url.pathExtension.lowercased()
        if ext == "mlmodelc" { return url }
        guard ext == "mlpackage" || ext == "mlmodel" else {
            throw NSError(domain: "DepthPredictor", code: 1,
                          userInfo: [NSLocalizedDescriptionKey: "Unsupported model format: \(ext)"])
        }
        let cacheDir = FileManager.default.temporaryDirectory
            .appendingPathComponent("DepthPredictorCache")
        try? FileManager.default.createDirectory(at: cacheDir, withIntermediateDirectories: true)
        let modelName = url.deletingPathExtension().lastPathComponent
        let compiledPath = cacheDir.appendingPathComponent("\(modelName).mlmodelc")
        if FileManager.default.fileExists(atPath: compiledPath.path) {
            // Reuse the cached compile only if it is at least as new as the source.
            if let sourceDate = try? FileManager.default.attributesOfItem(atPath: url.path)[.modificationDate] as? Date,
               let cachedDate = try? FileManager.default.attributesOfItem(atPath: compiledPath.path)[.modificationDate] as? Date,
               cachedDate >= sourceDate {
                return compiledPath
            }
            try? FileManager.default.removeItem(at: compiledPath)
        }
        print("[DepthPredictor] Compiling model (this may take a moment)...")
        let startTime = CFAbsoluteTimeGetCurrent()
        let tempURL = try MLModel.compileModel(at: url)
        let elapsed = CFAbsoluteTimeGetCurrent() - startTime
        try? FileManager.default.removeItem(at: compiledPath)
        try FileManager.default.moveItem(at: tempURL, to: compiledPath)
        print("[DepthPredictor] Model compiled in \(String(format: "%.1f", elapsed))s")
        return compiledPath
    }

    /// Wrap a [1, 1, H, W] Float32 depth multi-array in a single-component
    /// float CVPixelBuffer and return it as a CIImage. Also records the
    /// output dimensions for `outputWidth`/`outputHeight`.
    private func multiArrayToCIImage(_ multiArray: MLMultiArray) -> CIImage? {
        guard multiArray.shape.count >= 4 else {
            print("[DepthPredictor] Unexpected depth output rank: \(multiArray.shape.count)")
            return nil
        }
        let height = multiArray.shape[2].intValue
        let width = multiArray.shape[3].intValue
        _outputHeight = height
        _outputWidth = width
        var pixelBuffer: CVPixelBuffer?
        let status = CVPixelBufferCreate(
            kCFAllocatorDefault,
            width,
            height,
            kCVPixelFormatType_OneComponent32Float,
            nil,
            &pixelBuffer
        )
        guard status == kCVReturnSuccess, let buffer = pixelBuffer else {
            print("[DepthPredictor] Failed to create CVPixelBuffer")
            return nil
        }
        CVPixelBufferLockBaseAddress(buffer, [])
        defer { CVPixelBufferUnlockBaseAddress(buffer, []) }
        guard let destination = CVPixelBufferGetBaseAddress(buffer) else { return nil }
        let planeStride = multiArray.strides[2].intValue
        let srcBase = multiArray.dataPointer.bindMemory(to: Float32.self, capacity: height * planeStride)
        // Use the pixel buffer's own row stride: CoreVideo may pad rows for
        // alignment, so assuming `width * 4` destination bytes per row would
        // corrupt the image for widths that aren't alignment-friendly.
        let dstBytesPerRow = CVPixelBufferGetBytesPerRow(buffer)
        let copyBytes = width * MemoryLayout<Float32>.stride
        for h in 0..<height {
            let srcRow = srcBase.advanced(by: h * planeStride)
            let dstRow = destination.advanced(by: h * dstBytesPerRow)
            memcpy(dstRow, srcRow, copyBytes)
        }
        return CIImage(cvPixelBuffer: buffer)
    }
}
// MARK: - Image Shifting

extension DepthPredictor {
    /// Horizontally roll a CGImage by `offset` pixels (positive = shift left,
    /// wrapping around; negative offsets wrap the other way).
    ///
    /// Draws the source image twice into a CGContext with horizontal translations
    /// so the pixels wrap around correctly.
    static func shiftImageHorizontally(_ cgImage: CGImage, by offset: Int) -> CGImage? {
        let w = cgImage.width
        let h = cgImage.height
        guard w > 0, h > 0 else { return nil }
        // Euclidean modulo: previously a negative or >= w offset collapsed to
        // a no-op (Swift's % can return negatives), silently returning the
        // unshifted image. Normalize into [0, w) so any offset rolls correctly.
        let actualOffset = ((offset % w) + w) % w
        guard actualOffset > 0 else { return cgImage }
        let colorSpace = cgImage.colorSpace ?? CGColorSpaceCreateDeviceRGB()
        // Try with the source bitmapInfo first, fall back to explicit RGBA.
        var bitmapInfoRaw: UInt32 = cgImage.bitmapInfo.rawValue
        var ctx: CGContext?
        ctx = CGContext(data: nil, width: w, height: h, bitsPerComponent: 8,
                        bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfoRaw)
        if ctx == nil {
            bitmapInfoRaw = CGBitmapInfo.byteOrder32Little.rawValue | CGImageAlphaInfo.noneSkipLast.rawValue
            ctx = CGContext(data: nil, width: w, height: h, bitsPerComponent: 8,
                            bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfoRaw)
        }
        guard let context = ctx else {
            print("[DepthPredictor] shiftImageHorizontally: CGContext creation failed (source bitmapInfo=0x\(String(cgImage.bitmapInfo.rawValue, radix: 16)))")
            return nil
        }
        // Draw source shifted left by actualOffset (wraps: right portion appears on left).
        context.translateBy(x: -CGFloat(actualOffset), y: 0)
        context.draw(cgImage, in: CGRect(x: 0, y: 0, width: w, height: h))
        // Draw again at +w to fill the wrap-around on the right.
        context.translateBy(x: CGFloat(w), y: 0)
        context.draw(cgImage, in: CGRect(x: 0, y: 0, width: w, height: h))
        guard let result = context.makeImage() else {
            print("[DepthPredictor] shiftImageHorizontally: makeImage() returned nil")
            return nil
        }
        return result
    }
}
// MARK: - Image Loading

extension DepthPredictor {
    /// Load an image from a file path and return a CGImage.
    /// - Throws: an NSError when the file cannot be read or converted.
    static func loadImage(at path: URL) throws -> CGImage {
        guard let loaded = NSImage(contentsOf: path) else {
            throw NSError(domain: "DepthPredictor", code: 2,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to load image from \(path.path)"])
        }
        guard let converted = loaded.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
            throw NSError(domain: "DepthPredictor", code: 3,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to convert image to CGImage"])
        }
        return converted
    }

    /// Resize a CGImage to exact `(width, height)` using CoreGraphics'
    /// `.high` interpolation quality (the exact resampling kernel is not
    /// specified by CoreGraphics). Returns nil if context creation fails.
    ///
    /// This is used to pre-resize the source image to the model's expected
    /// input dimensions *before* handing off to Vision. Doing so makes
    /// `imageCropAndScaleOption = .scaleFit` a no-op — no letterboxing on
    /// non-matching aspect ratios, and no implicit bilinear downscale.
    static func resizeImage(_ cgImage: CGImage, toWidth width: Int, height: Int) -> CGImage? {
        guard width > 0, height > 0 else { return nil }
        // Already the requested size — nothing to do.
        guard cgImage.width != width || cgImage.height != height else {
            return cgImage
        }
        let rgbSpace = CGColorSpaceCreateDeviceRGB()
        let info = CGBitmapInfo.byteOrder32Little.rawValue
            | CGImageAlphaInfo.noneSkipLast.rawValue
        guard let context = CGContext(
            data: nil,
            width: width,
            height: height,
            bitsPerComponent: 8,
            bytesPerRow: 0,
            space: rgbSpace,
            bitmapInfo: info
        ) else {
            print("[DepthPredictor] resizeImage: CGContext creation failed")
            return nil
        }
        context.interpolationQuality = .high
        context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
        return context.makeImage()
    }

    /// Save a CIImage as a PNG file (renders via CIContext first).
    /// - Throws: an NSError when rendering or writing fails.
    static func saveImage(_ ciImage: CIImage, to path: URL) throws {
        let renderer = CIContext()
        guard let rendered = renderer.createCGImage(ciImage, from: ciImage.extent) else {
            throw NSError(domain: "DepthPredictor", code: 4,
                          userInfo: [NSLocalizedDescriptionKey: "Failed to create CGImage from CIImage"])
        }
        try writePNG(rendered, to: path)
    }
}
| // MARK: - Command Line Arguments | |
/// Parsed command-line configuration for a single depth-prediction run.
struct CommandLineArgs {
    let modelPath: URL
    let imagePath: URL
    let outputPath: URL
    let colormap: String // "grayscale", "jet", "turbo"
    let fixSeam: Bool
    let debugSeamDir: URL? // directory for intermediate seam-fix outputs

    /// Colormap names accepted by `-c/--colormap`.
    private static let supportedColormaps = ["grayscale", "jet", "turbo"]

    /// Parse this process's own arguments.
    ///
    /// - Returns: A populated `CommandLineArgs`, or nil when the invocation
    ///   is invalid or help was requested (usage/error text is printed first).
    static func parse() -> CommandLineArgs? {
        parse(CommandLine.arguments)
    }

    /// Parse an explicit argument vector (element 0 is the executable path).
    /// Split out from `parse()` so argument handling is unit-testable.
    ///
    /// - Parameter args: Full argument vector, including the program name.
    /// - Returns: A populated `CommandLineArgs`, or nil on invalid input,
    ///   an unknown option, an unknown colormap, or `-h/--help`.
    static func parse(_ args: [String]) -> CommandLineArgs? {
        var modelPath: URL?
        var imagePath: URL?
        var outputPath: URL?
        var colormap = "grayscale"
        var fixSeam = true
        var debugSeamDir: URL?
        var i = 1

        // Consume the value following an option flag; nil when the flag is
        // the last argument (the option is then simply ignored, matching
        // the prior behavior, and the missing-path guard below reports it).
        func nextValue() -> String? {
            i += 1
            return i < args.count ? args[i] : nil
        }

        while i < args.count {
            let arg = args[i]
            switch arg {
            case "-m", "--model":
                if let v = nextValue() { modelPath = URL(fileURLWithPath: v) }
            case "-i", "--input":
                if let v = nextValue() { imagePath = URL(fileURLWithPath: v) }
            case "-o", "--output":
                if let v = nextValue() { outputPath = URL(fileURLWithPath: v) }
            case "-c", "--colormap":
                if let v = nextValue() { colormap = v.lowercased() }
            case "-f", "--fix-seam":
                fixSeam = true
            case "--no-fix-seam":
                fixSeam = false
            case "--debug-seam":
                if let v = nextValue() { debugSeamDir = URL(fileURLWithPath: v) }
            case "-h", "--help":
                printUsage()
                return nil
            default:
                // Reject unrecognized flags explicitly; previously a typo
                // like "--colourmap" was silently consumed as a positional
                // path, producing a confusing downstream failure.
                if arg.hasPrefix("-") {
                    print("Error: Unknown option '\(arg)'")
                    printUsage()
                    return nil
                }
                // Positional fallback: model, input, output — in that order.
                if modelPath == nil {
                    modelPath = URL(fileURLWithPath: arg)
                } else if imagePath == nil {
                    imagePath = URL(fileURLWithPath: arg)
                } else if outputPath == nil {
                    outputPath = URL(fileURLWithPath: arg)
                }
            }
            i += 1
        }

        guard let m = modelPath, let image = imagePath, let output = outputPath else {
            printUsage()
            return nil
        }
        guard supportedColormaps.contains(colormap) else {
            print("Error: Unknown colormap '\(colormap)'. Use: grayscale, jet, turbo")
            return nil
        }
        return CommandLineArgs(
            modelPath: m,
            imagePath: image,
            outputPath: output,
            colormap: colormap,
            fixSeam: fixSeam,
            debugSeamDir: debugSeamDir
        )
    }

    /// Print the CLI help text to stdout.
    static func printUsage() {
        let execName = CommandLine.arguments[0].components(separatedBy: "/").last ?? "depth_predictor"
        print("""
        Usage: \(execName) [OPTIONS] <model> <input_image> <output.png>
        Depth Map Predictor - Generate depth maps from equirectangular panoramas
        Arguments:
        model Path to DAP CoreML model (.mlpackage or .mlmodelc)
        input_image Path to input equirectangular panorama (2:1 aspect ratio)
        output.png Path for output depth map PNG
        Options:
        -m, --model PATH Path to CoreML model
        -i, --input PATH Path to input image
        -o, --output PATH Path for output PNG
        -c, --colormap STYLE Colormap: grayscale (default), jet, turbo
        grayscale = 16-bit depth values
        jet/turbo = 8-bit colorized visualization
        -f, --fix-seam Fix left/right seam artifact via dual-inference stitch (default: on)
        --no-fix-seam Disable seam fixing
        --debug-seam DIR Save intermediate seam-fix outputs to DIR/
        (depth_original.png, depth_shifted.png, depth_stitched.png)
        -h, --help Show this help message
        Examples:
        # Grayscale depth map (16-bit)
        \(execName) DAPModel.mlpackage panorama.jpg depth.png
        # Colorized with jet colormap
        \(execName) -m DAPModel.mlpackage -i panorama.jpg -o depth.png -c jet
        # Debug seam fix intermediates
        \(execName) -m DAPModel.mlpackage -i panorama.jpg -o depth.png --debug-seam /tmp/seam_debug
        The model is automatically compiled on first use and cached for subsequent runs.
        """)
    }
}
// MARK: - Main

/// Program entry point: parse arguments, load the model and panorama, run
/// depth inference (bridging the async callback to this synchronous CLI via
/// a semaphore), then colorize and save the result. Exits nonzero on error.
func main() {
    guard let args = CommandLineArgs.parse() else {
        exit(1)
    }
    do {
        // Load model
        print("Loading model from \(args.modelPath.path)...")
        let predictor = DepthPredictor(modelURL: args.modelPath)
        guard predictor.isLoaded else {
            print("Error: Model failed to load")
            exit(1)
        }

        // Load input panorama
        print("Loading image from \(args.imagePath.path)...")
        let panorama = try DepthPredictor.loadImage(at: args.imagePath)
        print(" Image size: \(panorama.width)x\(panorama.height)")

        // Prepare optional debug output directory before inference starts.
        if let debugDir = args.debugSeamDir {
            try FileManager.default.createDirectory(at: debugDir, withIntermediateDirectories: true)
            print("Seam debug outputs will be saved to \(debugDir.path)")
        }

        // Run inference: block on a semaphore until the callback fires.
        print("Running inference...")
        let t0 = CFAbsoluteTimeGetCurrent()
        var depthResult: DepthResult?
        let done = DispatchSemaphore(value: 0)
        predictor.predictDepth(from: panorama, fixSeam: args.fixSeam, debugDir: args.debugSeamDir) { result in
            depthResult = result
            done.signal()
        }
        done.wait()
        let inferenceTime = CFAbsoluteTimeGetCurrent() - t0

        guard let depth = depthResult else {
            print("Error: Inference returned nil depth map")
            exit(1)
        }
        print("Depth map: \(depth.width)x\(depth.height) in \(String(format: "%.2f", inferenceTime))s")

        // Colorize (if requested) and write the output PNG.
        print("Saving output...")
        switch args.colormap {
        case "grayscale":
            try DepthPredictor.saveDepthAsGrayscale(depth, to: args.outputPath)
            print("Saved 16-bit grayscale depth map to \(args.outputPath.path)")
        case "jet":
            guard let colorized = predictor.applyJetColormap(to: depth) else {
                print("Error: Jet colormap failed")
                exit(1)
            }
            try DepthPredictor.saveImage(colorized, to: args.outputPath)
            print("Saved jet colormap depth map to \(args.outputPath.path)")
        case "turbo":
            guard let colorized = predictor.applyTurboColormap(to: depth) else {
                print("Error: Turbo colormap failed")
                exit(1)
            }
            try DepthPredictor.saveImage(colorized, to: args.outputPath)
            print("Saved turbo colormap depth map to \(args.outputPath.path)")
        default:
            // Unreachable: colormap is validated during argument parsing.
            break
        }
        print("Complete!")
    } catch {
        print("Error: \(error.localizedDescription)")
        // Bridging to NSError always succeeds; report domain/code for diagnostics.
        let nsError = error as NSError
        print("Domain: \(nsError.domain), Code: \(nsError.code)")
        exit(1)
    }
}

main()