// DepthAnyPanorama-coreml / PanoramaSplat.swift
// Kyle Pearson
// - Before: log(pixelFootprint * depth) — depth is normalized (~0–1), so the log compression kept far splats too small
// 2dd97e7
//
// PanoramaSplat.swift
// Convert equirectangular 360° panoramas to 3D Gaussian splat PLY files
//
// Uses a DAP CoreML depth model to estimate per-pixel depth from an
// equirectangular panorama, then projects each pixel onto a sphere
// to produce one Gaussian per pixel.
//
// Usage:
// swiftc -O -o panorama_splat PanoramaSplat.swift \
// -framework CoreML -framework Vision -framework CoreImage \
// -framework CoreGraphics -framework AppKit
// ./panorama_splat -m DAPModel.mlpackage -i panorama.jpg -o scene.ply -r 5.0
import Foundation
import CoreML
import Vision
import CoreImage
import CoreGraphics
import AppKit
// MARK: - Command Line Arguments
/// Command-line options for the panorama → Gaussian-splat converter.
///
/// `parse()` reads `CommandLine.arguments`; it returns `nil` when help was
/// requested, when a flag is missing its value, or when a required path
/// (`-m`, `-i`, `-o`) was not supplied.
struct CLIArgs {
    let modelPath: URL
    let imagePath: URL
    let outputPath: URL
    let radius: Float

    /// Radius used when `-r` is absent or its value is unusable.
    static let defaultRadius: Float = 5.0

    /// Parses the process arguments.
    ///
    /// - Returns: The parsed options, or `nil` if the program should not
    ///   continue (help requested, missing flag value, missing required path).
    static func parse() -> CLIArgs? {
        let argv = CommandLine.arguments
        var modelPath: URL?
        var imagePath: URL?
        var outputPath: URL?
        var radius = defaultRadius

        // Diagnostics go to stderr so they never mix with regular progress output.
        func report(_ message: String) {
            FileHandle.standardError.write(Data("\(message)\n".utf8))
        }

        var i = 1
        while i < argv.count {
            let flag = argv[i]
            switch flag {
            case "-h", "--help":
                printUsage()
                return nil
            case "-m", "--model", "-i", "--input", "-o", "--output", "-r", "--radius":
                i += 1
                guard i < argv.count else {
                    // Previously this returned nil without telling the user why.
                    report("error: missing value for \(flag)")
                    printUsage()
                    return nil
                }
                let value = argv[i]
                switch flag {
                case "-m", "--model":
                    modelPath = URL(fileURLWithPath: value)
                case "-i", "--input":
                    imagePath = URL(fileURLWithPath: value)
                case "-o", "--output":
                    outputPath = URL(fileURLWithPath: value)
                default:
                    // "-r" / "--radius": keep the historical fallback to the default,
                    // but say so, and also reject NaN / infinite / non-positive radii
                    // (the splat scale is a log of a value proportional to the radius).
                    if let parsed = Float(value), parsed.isFinite, parsed > 0 {
                        radius = parsed
                    } else {
                        report("warning: invalid radius '\(value)', using default \(defaultRadius)")
                        radius = defaultRadius
                    }
                }
            default:
                report("warning: ignoring unrecognized argument '\(flag)'")
            }
            i += 1
        }

        guard let m = modelPath, let img = imagePath, let out = outputPath else {
            printUsage()
            return nil
        }
        return CLIArgs(modelPath: m, imagePath: img, outputPath: out, radius: radius)
    }

    /// Prints the usage text to standard output.
    static func printUsage() {
        // `first` instead of `[0]`: never trap if the argument vector is empty.
        let name = CommandLine.arguments.first?.components(separatedBy: "/").last ?? "panorama_splat"
        print("""
        Usage: \(name) -m <model> -i <image> -o <output.ply> [-r radius]
        Convert equirectangular panoramas to 3D Gaussian splat PLY files.
        Options:
        -m, --model PATH Path to DAP CoreML model (.mlpackage)
        -i, --input PATH Path to equirectangular panorama (2:1 ratio)
        -o, --output PATH Output PLY file path
        -r, --radius FLOAT Sphere radius in world units (default: 5.0)
        -h, --help Show this help
        """)
    }
}
// MARK: - CoreML Depth Inference
/// Returns the URL of a compiled (`.mlmodelc`) form of the model at `url`,
/// compiling it and caching the result in the temporary directory when needed.
///
/// - Parameter url: A `.mlpackage`, `.mlmodel` or `.mlmodelc` path. An
///   `.mlmodelc` is returned unchanged; any other extension aborts the process.
/// - Returns: URL of the compiled model.
/// - Throws: Errors from directory creation, `MLModel.compileModel(at:)`, or
///   moving the compiled bundle into the cache.
func compileModelIfNeeded(at url: URL) throws -> URL {
    let ext = url.pathExtension.lowercased()
    guard ext == "mlpackage" || ext == "mlmodel" || ext == "mlmodelc" else {
        fatalError("Unsupported model format: \(ext)")
    }
    guard ext != "mlmodelc" else { return url }

    let fm = FileManager.default
    let cacheDir = fm.temporaryDirectory.appendingPathComponent("PanoramaSplatCache")
    try fm.createDirectory(at: cacheDir, withIntermediateDirectories: true)

    // Key the cache entry on the full source path, not just the basename:
    // two different models that share a file name (e.g. in different folders)
    // previously mapped to the same cache slot, and the "cache is newer" check
    // below could then hand back the wrong compiled model.
    // FNV-1a is used because `hashValue` is seeded per process and therefore
    // unusable for an on-disk name.
    let source = url.standardizedFileURL
    let fnvOffset: UInt64 = 0xcbf29ce484222325
    let fnvPrime: UInt64 = 0x100000001b3
    let pathDigest = source.path.utf8.reduce(fnvOffset) { hash, byte in
        (hash ^ UInt64(byte)) &* fnvPrime
    }
    let stem = source.deletingPathExtension().lastPathComponent
    let compiled = cacheDir.appendingPathComponent("\(stem)-\(String(pathDigest, radix: 16)).mlmodelc")

    if fm.fileExists(atPath: compiled.path) {
        // Reuse the cached build only when it is at least as new as the source.
        if let src = try? fm.attributesOfItem(atPath: url.path)[.modificationDate] as? Date,
           let cch = try? fm.attributesOfItem(atPath: compiled.path)[.modificationDate] as? Date,
           cch >= src {
            return compiled
        }
        try? fm.removeItem(at: compiled)
    }

    print(" Compiling CoreML model ...")
    let t = CFAbsoluteTimeGetCurrent()
    let tmp = try MLModel.compileModel(at: url)
    // Best-effort removal so the move below cannot collide with a stale entry.
    try? fm.removeItem(at: compiled)
    try fm.moveItem(at: tmp, to: compiled)
    print(" Compiled in \(String(format: "%.1fs", CFAbsoluteTimeGetCurrent() - t))")
    return compiled
}
/// Runs the depth model on `image` through Vision and returns the depth map
/// as a tightly packed row-major `Float32` array.
///
/// - Parameters:
///   - modelURL: Path to the model; compiled on demand via `compileModelIfNeeded(at:)`.
///   - image: Input image; Vision scales it with `.scaleFit`.
/// - Returns: `depths` (`width * height` values, row-major) plus the map's
///   `width` and `height`, taken from the last two dimensions of the model output.
/// - Throws: Errors from model compilation/loading or from performing the Vision request.
func runDepthInference(modelURL: URL, image: CGImage) throws -> (depths: [Float32], width: Int, height: Int) {
    let compiled = try compileModelIfNeeded(at: modelURL)
    let config = MLModelConfiguration()
    config.computeUnits = .all
    let model = try MLModel(contentsOf: compiled, configuration: config)
    let vnModel = try VNCoreMLModel(for: model)
    let request = VNCoreMLRequest(model: vnModel) { _, error in
        if let error { fatalError("Inference error: \(error)") }
    }
    request.imageCropAndScaleOption = .scaleFit
    let handler = VNImageRequestHandler(cgImage: image, options: [:])
    try handler.perform([request])
    guard let observations = request.results as? [VNCoreMLFeatureValueObservation],
          let ma = observations.first?.featureValue.multiArrayValue else {
        fatalError("No depth output from model")
    }

    // Use the trailing two dimensions as (height, width) so that outputs shaped
    // [1, 1, H, W], [1, H, W] or [H, W] all work; indexing shape[2]/shape[3]
    // directly trapped on anything that was not rank 4.
    let rank = ma.shape.count
    guard rank >= 2, ma.strides.count == rank else {
        fatalError("Unexpected depth output shape: \(ma.shape)")
    }
    let h = ma.shape[rank - 2].intValue
    let w = ma.shape[rank - 1].intValue
    let rowStride = ma.strides[rank - 2].intValue
    let colStride = ma.strides[rank - 1].intValue

    var depths = [Float32](repeating: 0, count: h * w)
    guard h > 0, w > 0 else { return (depths, w, h) }

    if ma.dataType == .float32 {
        // Number of elements spanned by the first H×W plane, including row padding.
        let span = (h - 1) * rowStride + (w - 1) * colStride + 1
        let src = ma.dataPointer.bindMemory(to: Float32.self, capacity: span)
        // All pointer use stays inside the closure: a pointer obtained from
        // `withUnsafeMutableBufferPointer` is only valid for the closure's duration.
        depths.withUnsafeMutableBufferPointer { dst in
            guard let base = dst.baseAddress else { return }
            for row in 0..<h {
                let rowSrc = src.advanced(by: row * rowStride)
                let rowDst = base.advanced(by: row * w)
                if colStride == 1 {
                    memcpy(rowDst, rowSrc, w * MemoryLayout<Float32>.stride)
                } else {
                    for col in 0..<w { rowDst[col] = rowSrc[col * colStride] }
                }
            }
        }
    } else {
        // Non-Float32 outputs (e.g. Float16 or Double) cannot be reinterpreted as
        // Float32 bytes; read them element-wise through NSNumber instead. Slower,
        // but correct. Leading dimensions stay at index 0 (first plane).
        var key = [NSNumber](repeating: 0, count: rank)
        for row in 0..<h {
            key[rank - 2] = NSNumber(value: row)
            for col in 0..<w {
                key[rank - 1] = NSNumber(value: col)
                depths[row * w + col] = ma[key].floatValue
            }
        }
    }
    return (depths, w, h)
}
// MARK: - Image Pixel Loading
/// Loads `image` as 8-bit RGBA pixels resized to the target dimensions.
///
/// The image is scaled with Core Image and then rendered into a
/// premultiplied-alpha RGBA bitmap in the device RGB colour space.
///
/// - Parameters:
///   - image: Source image.
///   - targetW: Output width in pixels.
///   - targetH: Output height in pixels.
/// - Returns: `targetW * targetH * 4` bytes, row-major, 4 bytes per pixel (R, G, B, A).
func loadImagePixels(_ image: CGImage, targetW: Int, targetH: Int) -> [UInt8] {
    let ci = CIImage(cgImage: image)
    let ctx = CIContext()
    let scaled = ci.transformed(by: CGAffineTransform(scaleX: CGFloat(targetW) / ci.extent.width,
                                                     y: CGFloat(targetH) / ci.extent.height))
    guard let resized = ctx.createCGImage(scaled, from: CGRect(x: 0, y: 0, width: targetW, height: targetH)) else {
        fatalError("Failed to resize image to \(targetW)x\(targetH)")
    }
    let bpp = 4
    let bpr = bpp * targetW
    var pixels = [UInt8](repeating: 0, count: targetH * bpr)
    let cs = CGColorSpaceCreateDeviceRGB()
    // The bitmap context keeps writing through its data pointer during `draw`,
    // i.e. after `CGContext(...)` has returned. A pointer produced by passing
    // `&pixels` is only valid for that one call, so create and use the context
    // inside `withUnsafeMutableBytes`, where the pointer is guaranteed to live.
    pixels.withUnsafeMutableBytes { buffer in
        guard let gctx = CGContext(data: buffer.baseAddress, width: targetW, height: targetH,
                                   bitsPerComponent: 8, bytesPerRow: bpr, space: cs,
                                   bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue) else {
            fatalError("Failed to create bitmap context")
        }
        gctx.draw(resized, in: CGRect(x: 0, y: 0, width: targetW, height: targetH))
    }
    return pixels
}
// MARK: - Equirectangular to 3D Projection
/// Maps an equirectangular pixel coordinate to a unit direction on the sphere.
///
/// `u` spans one full turn of longitude across `width` (−π at `u = 0`), and
/// `v` spans latitude from +π/2 at `v = 0` to −π/2 at `v = height`.
///
/// - Returns: `(x, y, z)` with `y` = sin(latitude) and `x`/`z` the
///   cosine/sine of longitude scaled by cos(latitude).
func equiToSphereDirection(u: Float, v: Float, width: Int, height: Int) -> (x: Float, y: Float, z: Float) {
    let longitude = (u / Float(width) - 0.5) * 2.0 * Float.pi
    let latitude = (0.5 - v / Float(height)) * Float.pi
    // Radius of the latitude circle on the unit sphere.
    let ringRadius = cos(latitude)
    let x = ringRadius * cos(longitude)
    let y = sin(latitude)
    let z = ringRadius * sin(longitude)
    return (x: x, y: y, z: z)
}
// MARK: - PLY Export (binary_little_endian, matches Sharp format)
/// Writes Gaussians and camera metadata to a binary little-endian PLY file.
///
/// Layout after the ASCII header: 14 floats per vertex, a 4×4 identity
/// extrinsic, a 3×3 intrinsic built from `focalLength` and the image centre,
/// the image size, a frame pair `(1, vertex count)`, the 10th/90th percentile
/// of `1 / z` over vertices with `z > 1e-6`, a colour-space byte and a
/// three-byte version.
///
/// - Parameters:
///   - gaussians: Per-splat position, SH DC colour, opacity, scale and rotation.
///   - focalLength: Value stored on the intrinsic diagonal.
///   - imageW: Image width stored in the intrinsic and `image_size`.
///   - imageH: Image height stored in the intrinsic and `image_size`.
///   - url: Destination file; written atomically.
/// - Throws: Any error from `Data.write(to:options:)`.
func writePLY(gaussians: [(x: Float, y: Float, z: Float,
                           f0: Float, f1: Float, f2: Float,
                           opacity: Float,
                           s0: Float, s1: Float, s2: Float,
                           q0: Float, q1: Float, q2: Float, q3: Float)],
              focalLength: Float, imageW: Int, imageH: Int,
              to url: URL) throws {
    let n = gaussians.count
    var data = Data()
    // 14 floats per vertex plus header/trailer; reserving up front avoids
    // repeated reallocation while appending millions of values.
    data.reserveCapacity(n * 14 * MemoryLayout<Float>.size + 2048)

    func a(_ str: String) {
        data.append(contentsOf: str.utf8)
    }
    // The header promises little-endian, so convert explicitly instead of
    // dumping host-order bytes, and append directly rather than via a
    // temporary `Data` per scalar.
    func f(_ v: Float) {
        withUnsafeBytes(of: v.bitPattern.littleEndian) { data.append(contentsOf: $0) }
    }
    func i32(_ v: Int32) {
        withUnsafeBytes(of: v.littleEndian) { data.append(contentsOf: $0) }
    }
    func u32(_ v: UInt32) {
        withUnsafeBytes(of: v.littleEndian) { data.append(contentsOf: $0) }
    }
    func u8(_ v: UInt8) {
        data.append(v)
    }

    // --- Header ---
    a("ply\n")
    a("format binary_little_endian 1.0\n")
    a("element vertex \(n)\n")
    a("property float x\nproperty float y\nproperty float z\n")
    a("property float f_dc_0\nproperty float f_dc_1\nproperty float f_dc_2\n")
    a("property float opacity\n")
    a("property float scale_0\nproperty float scale_1\nproperty float scale_2\n")
    a("property float rot_0\nproperty float rot_1\nproperty float rot_2\nproperty float rot_3\n")
    a("element extrinsic 16\nproperty float extrinsic\n")
    a("element intrinsic 9\nproperty float intrinsic\n")
    a("element image_size 2\nproperty uint image_size\n")
    a("element frame 2\nproperty int frame\n")
    a("element disparity 2\nproperty float disparity\n")
    a("element color_space 1\nproperty uchar color_space\n")
    a("element version 3\nproperty uchar version\n")
    a("end_header\n")

    // --- Vertex data ---
    var disparities: [Float] = []
    disparities.reserveCapacity(n)
    for g in gaussians {
        f(g.x); f(g.y); f(g.z)
        f(g.f0); f(g.f1); f(g.f2)
        f(g.opacity)
        f(g.s0); f(g.s1); f(g.s2)
        f(g.q0); f(g.q1); f(g.q2); f(g.q3)
        // Only points in front of the z = 0 plane contribute a disparity sample.
        if g.z > 1e-6 { disparities.append(1.0 / g.z) }
    }

    // --- Extrinsic (identity 4x4) ---
    let id: [Float] = [1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1]
    for v in id { f(v) }

    // --- Intrinsic (3x3) ---
    f(focalLength); f(0); f(Float(imageW) * 0.5)
    f(0); f(focalLength); f(Float(imageH) * 0.5)
    f(0); f(0); f(1)

    // --- Image size ---
    u32(UInt32(imageW)); u32(UInt32(imageH))

    // --- Frame ---
    i32(1); i32(Int32(n))

    // --- Disparity quantiles (10th / 90th percentile; 0 and 1 when there are no samples) ---
    disparities.sort()
    let d10 = disparities.isEmpty ? 0.0 : disparities[min(Int(Float(disparities.count) * 0.1), disparities.count - 1)]
    let d90 = disparities.isEmpty ? 1.0 : disparities[min(Int(Float(disparities.count) * 0.9), disparities.count - 1)]
    f(d10); f(d90)

    // --- Color space (sRGB = 1) ---
    u8(1)

    // --- Version ---
    u8(1); u8(5); u8(0)

    try data.write(to: url, options: .atomic)
}
// MARK: - Image Shifting
/// Horizontally rolls a CGImage by `offset` pixels with wrap-around.
///
/// A positive `offset` shifts the content left; a negative one shifts it
/// right. The offset is taken modulo the image width, so any integer is valid.
///
/// - Parameters:
///   - cgImage: Image to roll.
///   - offset: Number of pixels to shift left (negative = right).
/// - Returns: The rolled image, the input itself when the effective shift is
///   zero, or `nil` if no bitmap context could be created.
func shiftImageHorizontally(_ cgImage: CGImage, by offset: Int) -> CGImage? {
    let w = cgImage.width
    let h = cgImage.height
    // Nothing to roll, and `offset % 0` would trap.
    guard w > 0, h > 0 else { return cgImage }
    // Swift's `%` keeps the sign of the dividend, so a negative offset used to
    // produce a negative remainder and the image came back unshifted.
    // Normalise into 0..<w instead.
    let actualOffset = ((offset % w) + w) % w
    guard actualOffset > 0 else { return cgImage }

    let colorSpace = cgImage.colorSpace ?? CGColorSpaceCreateDeviceRGB()
    var bitmapInfoRaw: UInt32 = cgImage.bitmapInfo.rawValue
    var ctx: CGContext?
    ctx = CGContext(data: nil, width: w, height: h, bitsPerComponent: 8,
                    bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfoRaw)
    if ctx == nil {
        // The image's own pixel layout is not usable for an 8-bit context;
        // retry with a plain 32-bit layout that ignores alpha.
        bitmapInfoRaw = CGBitmapInfo.byteOrder32Little.rawValue | CGImageAlphaInfo.noneSkipLast.rawValue
        ctx = CGContext(data: nil, width: w, height: h, bitsPerComponent: 8,
                        bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfoRaw)
    }
    guard let context = ctx else { return nil }

    // First draw: image moved left by the offset (its left part falls off-canvas).
    context.translateBy(x: -CGFloat(actualOffset), y: 0)
    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: w, height: h))
    // Second draw: one full width to the right, filling the vacated strip with
    // the part that fell off.
    context.translateBy(x: CGFloat(w), y: 0)
    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: w, height: h))
    return context.makeImage()
}
// MARK: - Depth Map Seam Fix (dual-inference approach)
/// Repairs the left/right seam of a depth map using a second depth map that
/// was inferred on a horizontally rolled copy of the same image.
///
/// In the rolled map the seam sits at column `width / 2`. A strip around that
/// column is copied back into the original map; on both edges of the strip the
/// two maps are blended linearly so that small differences in absolute depth
/// between the two inference passes do not show up as a step.
///
/// Layout in *shifted* coordinates (centred at `width / 2`):
///
///     [ original ][ feather ][ shifted ][ feather ][ original ]
///                 ^          ^          ^          ^
///             patchLeft   coreLeft  coreRight  patchRight
///
/// - Parameters:
///   - original: Row-major depth map of the unshifted image (`width * height`).
///   - shifted: Row-major depth map of the rolled image, same dimensions.
///   - width: Depth map width.
///   - height: Depth map height.
///   - depthHalf: Column offset that maps an original column to its column in `shifted`.
///   - patchHalfWidth: Half-width of the strip taken from `shifted` (clamped to `width / 2`).
///   - featherWidth: Width of the blend band on each side (clamped to `0 ... patchHalfWidth - 1`).
/// - Returns: A new `width * height` depth map with the seam region patched.
func stitchSeamFromShiftedDepth(
    original: [Float32],
    shifted: [Float32],
    width: Int,
    height: Int,
    depthHalf: Int,
    patchHalfWidth: Int = 25,
    featherWidth: Int = 12
) -> [Float32] {
    let seamColumn = width / 2
    let halfSpan = min(patchHalfWidth, seamColumn)
    let patchLower = seamColumn - halfSpan
    let patchUpper = seamColumn + halfSpan

    // The blend band may never swallow the whole patch.
    let band = min(max(0, featherWidth), max(0, halfSpan - 1))
    let coreLower = patchLower + band
    let coreUpper = patchUpper - band
    let weightPerStep: Float = band > 0 ? 1.0 / Float(band) : 0.0

    var stitched = [Float32](repeating: 0, count: width * height)
    for row in 0..<height {
        let rowBase = row * width
        for col in 0..<width {
            let target = rowBase + col
            let seamCol = (col + depthHalf) % width

            // Outside the patch: keep the original value untouched.
            guard seamCol >= patchLower, seamCol < patchUpper else {
                stitched[target] = original[target]
                continue
            }

            let fromShifted = shifted[rowBase + seamCol]

            // Core of the patch: take the shifted value as is.
            if seamCol >= coreLower && seamCol < coreUpper {
                stitched[target] = fromShifted
                continue
            }

            // Feather band: weight ramps from 0 at the patch edge towards the core.
            let stepsFromEdge = seamCol < coreLower
                ? seamCol - patchLower
                : patchUpper - 1 - seamCol
            let rawWeight = Float(stepsFromEdge) * weightPerStep
            let weight = max(0.0, min(1.0, rawWeight))
            let fromOriginal = original[target]
            stitched[target] = fromOriginal + (fromShifted - fromOriginal) * weight
        }
    }
    return stitched
}
/// Produces a seam-corrected depth map by running the depth model twice:
/// once on `image` and once on a copy rolled left by half its width, then
/// stitching the rolled result's centre strip into the first result.
///
/// - Parameters:
///   - modelURL: Depth model passed through to `runDepthInference`.
///   - image: Equirectangular source image.
///   - patchHalfWidth: Half-width of the strip copied from the rolled depth map.
/// - Returns: The stitched depth values with the depth map's width and height.
/// - Throws: Any error thrown by either inference pass.
func fixSeamWithDualInference(
    modelURL: URL,
    image: CGImage,
    patchHalfWidth: Int = 25
) throws -> (depths: [Float32], width: Int, height: Int) {
    // Rolling by half the width moves the wrap-around seam to the image centre.
    guard let rolledImage = shiftImageHorizontally(image, by: image.width / 2) else {
        fatalError("Failed to shift image for seam fix")
    }

    // Pass 1: the image as given. Pass 2: the rolled copy (only its depths are needed).
    let reference = try runDepthInference(modelURL: modelURL, image: image)
    let rolledDepths = try runDepthInference(modelURL: modelURL, image: rolledImage).depths

    let depthWidth = reference.width
    let depthHeight = reference.height
    let patched = stitchSeamFromShiftedDepth(
        original: reference.depths,
        shifted: rolledDepths,
        width: depthWidth,
        height: depthHeight,
        depthHalf: depthWidth / 2,
        patchHalfWidth: patchHalfWidth
    )
    return (patched, depthWidth, depthHeight)
}
// MARK: - Main Pipeline
/// Entry point: parse arguments, estimate depth for the panorama, turn every
/// valid depth pixel into one Gaussian and write the result as a PLY file.
///
/// User-facing failures (bad radius, unreadable image, inference or write
/// errors) are reported on stderr and end the process with status 1 instead
/// of trapping.
func main() {
    // Print a diagnostic and terminate; used for errors the user can fix.
    func fail(_ message: String) -> Never {
        FileHandle.standardError.write(Data("error: \(message)\n".utf8))
        exit(1)
    }

    guard let args = CLIArgs.parse() else { exit(1) }

    // The splat scale below is log(k * radius * depth): a non-positive or
    // non-finite radius would silently fill the file with NaN / -inf scales.
    let radius = args.radius
    guard radius.isFinite, radius > 0 else {
        fail("radius must be a positive, finite number (got \(radius))")
    }

    print("Loading image ...")
    guard let nsImg = NSImage(contentsOf: args.imagePath) else {
        fail("Cannot load image: \(args.imagePath.path)")
    }
    guard let cgImg = nsImg.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
        fail("Cannot convert image to CGImage")
    }
    print(" Image: \(cgImg.width)x\(cgImg.height)")

    print("Running depth inference (with dual-inference seam fix) ...")
    let t0 = CFAbsoluteTimeGetCurrent()
    let inference: (depths: [Float32], width: Int, height: Int)
    do {
        inference = try fixSeamWithDualInference(modelURL: args.modelPath, image: cgImg)
    } catch {
        fail("Depth inference failed: \(error.localizedDescription)")
    }
    let (depths, dW, dH) = inference
    let dt = CFAbsoluteTimeGetCurrent() - t0
    print(" Depth: \(dW)x\(dH) in \(String(format: "%.2fs", dt))")

    print("Loading image pixels ...")
    let pixels = loadImagePixels(cgImg, targetW: dW, targetH: dH)

    let coeffSH0 = sqrt(1.0 / (4.0 * Float.pi))
    // Base angular footprint of one pixel (used as scale factor per-splat)
    let pixelFootprint = radius * Float.pi / Float(max(dW, dH))
    let uniformOpacity = Float(log(0.85 / (1.0 - 0.85))) // logit(0.85)

    print("Generating \(dW * dH) Gaussians ...")
    var gaussians: [(x: Float, y: Float, z: Float,
                     f0: Float, f1: Float, f2: Float,
                     opacity: Float,
                     s0: Float, s1: Float, s2: Float,
                     q0: Float, q1: Float, q2: Float, q3: Float)] = []
    gaussians.reserveCapacity(dW * dH)

    for v in 0..<dH {
        for u in 0..<dW {
            let idx = v * dW + u
            let depth = depths[idx]
            // Skip zero-depth pixels (invalid / background)
            guard depth > 0.01 else { continue }

            var dir = equiToSphereDirection(u: Float(u), v: Float(v), width: dW, height: dH)
            // Flip 180° (panorama was upside down — invert Y axis)
            dir.y = -dir.y

            let r = depth * radius
            let px = dir.x * r
            let py = dir.y * r
            let pz = dir.z * r

            // Scale proportional to world distance — far splats grow linearly to avoid holes.
            // The PLY stores the natural log of the scale.
            let linearScale = pixelFootprint * (r / radius) * 1.5
            let splatScale = Float(log(linearScale))

            // Color from image pixel (RGBA)
            let pidx = idx * 4
            let rr = Float(pixels[pidx]) / 255.0
            let gg = Float(pixels[pidx + 1]) / 255.0
            let bb = Float(pixels[pidx + 2]) / 255.0
            // RGB -> SH0
            let f0 = (rr - 0.5) / coeffSH0
            let f1 = (gg - 0.5) / coeffSH0
            let f2 = (bb - 0.5) / coeffSH0

            gaussians.append((
                x: px, y: py, z: pz,
                f0: f0, f1: f1, f2: f2,
                opacity: uniformOpacity,
                s0: splatScale, s1: splatScale, s2: splatScale,
                q0: 1.0, q1: 0.0, q2: 0.0, q3: 0.0
            ))
        }
    }
    print(" Valid Gaussians: \(gaussians.count) (filtered \(dW * dH - gaussians.count) zero-depth pixels)")

    print("Saving PLY ...")
    let focal = Float(dW) // panoramic focal ≈ image width
    do {
        try writePLY(gaussians: gaussians, focalLength: focal, imageW: dW, imageH: dH, to: args.outputPath)
    } catch {
        fail("Cannot write \(args.outputPath.path): \(error.localizedDescription)")
    }
    let size = (try? FileManager.default.attributesOfItem(atPath: args.outputPath.path)[.size] as? UInt)?.description ?? "?"
    print(" Saved \(args.outputPath.path) (\(size) bytes)")
    print("Done!")
}
main()