|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import Foundation |
|
|
import CoreML |
|
|
import CoreImage |
|
|
import AppKit |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// The five per-Gaussian attribute tensors returned by the model, plus helpers
/// for ranking and decimating them.
///
/// NOTE(review): the helpers read the tensors through raw `Float` pointers with
/// flat `i * 3 + c` indexing, which presumes contiguous Float32 storage laid out
/// as `[1, N, C]` — confirm against the exported model.
struct Gaussians3D {
    let meanVectors: MLMultiArray
    let singularValues: MLMultiArray
    let quaternions: MLMultiArray
    let colors: MLMultiArray
    let opacities: MLMultiArray

    /// Number of Gaussians, taken from dimension 1 of `meanVectors.shape`.
    ///
    /// Returns 0 (instead of trapping on an out-of-range subscript) when the
    /// shape has fewer than two dimensions.
    var count: Int {
        let shape = meanVectors.shape
        guard shape.count > 1 else { return 0 }
        return max(0, shape[1].intValue)
    }

    /// Computes one importance score per Gaussian: the product of its three
    /// scale values multiplied by its opacity.
    ///
    /// - Returns: `count` scores, in Gaussian order (empty when `count` is 0).
    func computeImportanceScores() -> [Float] {
        let n = count
        guard n > 0 else { return [] }

        // The raw reads below are only meaningful for Float32 storage that is
        // large enough; fail loudly rather than read garbage or out of bounds.
        precondition(singularValues.dataType == .float32 && opacities.dataType == .float32,
                     "Gaussians3D expects Float32 scale and opacity tensors")
        precondition(singularValues.count >= n * 3 && opacities.count >= n,
                     "Scale/opacity tensors are smaller than the Gaussian count implies")

        let scalePtr = singularValues.dataPointer.assumingMemoryBound(to: Float.self)
        let opacityPtr = opacities.dataPointer.assumingMemoryBound(to: Float.self)

        return (0..<n).map { i -> Float in
            let base = i * 3
            let scaleProduct = scalePtr[base] * scalePtr[base + 1] * scalePtr[base + 2]
            return scaleProduct * opacityPtr[i]
        }
    }

    /// Selects the indices of the highest-scoring Gaussians.
    ///
    /// - Parameter keepRatio: Fraction of Gaussians to keep. Values are clamped
    ///   to `0...1`; NaN is treated as 1 (keep everything). At least one index
    ///   is always kept when any Gaussians exist.
    /// - Returns: The kept indices in ascending order.
    func decimationIndices(keepRatio: Float) -> [Int] {
        let n = count
        guard n > 0 else { return [] }

        // `Int(_:)` traps on NaN/infinite input, so sanitise the ratio first.
        let ratio: Float = keepRatio.isNaN ? 1.0 : min(max(keepRatio, 0.0), 1.0)
        let keepCount = min(n, max(1, Int(Float(n) * ratio)))

        // Keeping everything needs no scoring or sorting.
        guard keepCount < n else { return Array(0..<n) }

        // NaN scores rank last so the comparator below is a valid total order.
        let sortKeys = computeImportanceScores().map { $0.isNaN ? -Float.infinity : $0 }

        // Highest score first; ties resolved by index so the result is deterministic.
        let ranked = (0..<n).sorted { lhs, rhs in
            if sortKeys[lhs] != sortKeys[rhs] {
                return sortKeys[lhs] > sortKeys[rhs]
            }
            return lhs < rhs
        }

        // Restore ascending index order for the caller.
        return ranked.prefix(keepCount).sorted()
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Converts one linear-light colour component to its sRGB-encoded value.
///
/// Inputs at or below 0.0031308 use the linear segment (`12.92 * linear`);
/// larger inputs use the power segment (`1.055 * linear^(1/2.4) - 0.055`).
///
/// - Parameter linear: Linear colour component.
/// - Returns: The sRGB-encoded component. No clamping is applied.
func linearRGBToSRGB(_ linear: Float) -> Float {
    let cutoff: Float = 0.0031308
    if linear <= cutoff {
        return 12.92 * linear
    }

    let gammaExponent: Float = 1.0 / 2.4
    return pow(linear, gammaExponent) * 1.055 - 0.055
}
|
|
|
|
|
|
|
|
/// Converts a colour component to a degree-0 spherical-harmonics coefficient.
///
/// The value is centred around 0.5 and divided by `sqrt(1 / (4π))`.
///
/// - Parameter rgb: Colour component to convert.
/// - Returns: `(rgb - 0.5) / sqrt(1 / (4π))`.
func rgbToSphericalHarmonics(_ rgb: Float) -> Float {
    let centered = rgb - 0.5
    let degree0Basis = (1.0 / (4.0 * Float.pi)).squareRoot()
    return centered / degree0Basis
}
|
|
|
|
|
|
|
|
/// Computes the logit (inverse sigmoid) of a probability-like value.
///
/// The input is first clamped to `[1e-6, 1 - 1e-6]` so the logarithm stays finite.
///
/// - Parameter x: Value to convert.
/// - Returns: `log(p / (1 - p))` where `p` is the clamped input.
func inverseSigmoid(_ x: Float) -> Float {
    let epsilon: Float = 1e-6
    let p = min(max(x, epsilon), 1.0 - epsilon)
    let odds = p / (1.0 - p)
    return log(odds)
}
|
|
|
|
|
|
|
|
|
|
|
/// Runs the SHARP Core ML model on a single image and exports the predicted
/// 3D Gaussians as a binary PLY file.
class SHARPModelRunner {
    private let model: MLModel
    private let inputHeight: Int
    private let inputWidth: Int

    /// Builds an `NSError` in this class's error domain with a localized description.
    private static func makeError(code: Int, message: String) -> NSError {
        return NSError(domain: "SHARPModelRunner", code: code,
                       userInfo: [NSLocalizedDescriptionKey: message])
    }

    /// Loads (compiling and caching if necessary) the model at `modelPath`.
    ///
    /// - Parameters:
    ///   - modelPath: `.mlpackage`, `.mlmodel` or `.mlmodelc` location.
    ///   - inputHeight: Height in pixels of the tensor fed to the model.
    ///   - inputWidth: Width in pixels of the tensor fed to the model.
    /// - Throws: An error when the dimensions are not positive, the model format
    ///   is unsupported, or compilation/loading fails.
    init(modelPath: URL, inputHeight: Int = 1536, inputWidth: Int = 1536) throws {
        // Non-positive sizes would later produce empty buffers and divisions by zero.
        guard inputHeight > 0, inputWidth > 0 else {
            throw SHARPModelRunner.makeError(
                code: 11,
                message: "Model input size must be positive (got \(inputWidth)x\(inputHeight))")
        }

        let config = MLModelConfiguration()
        config.computeUnits = .all

        let compiledModelURL = try SHARPModelRunner.compileModelIfNeeded(at: modelPath)

        self.model = try MLModel(contentsOf: compiledModelURL, configuration: config)
        self.inputHeight = inputHeight
        self.inputWidth = inputWidth

        // Dictionary key order is unspecified; sort so the log is stable between runs.
        let inputNames = model.modelDescription.inputDescriptionsByName.keys.sorted()
        let outputNames = model.modelDescription.outputDescriptionsByName.keys.sorted()
        print("Model inputs: \(inputNames.joined(separator: ", "))")
        print("Model outputs: \(outputNames.joined(separator: ", "))")
    }

    /// Returns the URL of a compiled (`.mlmodelc`) version of the model.
    ///
    /// An already compiled model is returned unchanged. Otherwise the model is
    /// compiled and stored under `<temporary directory>/SHARPModelCache`, keyed
    /// by the model's file name; a cached copy is reused when its modification
    /// date is not older than the source's.
    ///
    /// - Throws: Error code 10 for an unsupported extension, or whatever
    ///   `MLModel.compileModel(at:)` throws.
    private static func compileModelIfNeeded(at modelPath: URL) throws -> URL {
        let fileManager = FileManager.default
        let pathExtension = modelPath.pathExtension.lowercased()

        if pathExtension == "mlmodelc" {
            print("Model is already compiled.")
            return modelPath
        }

        guard pathExtension == "mlpackage" || pathExtension == "mlmodel" else {
            throw makeError(
                code: 10,
                message: "Unsupported model format: \(pathExtension). Use .mlpackage, .mlmodel, or .mlmodelc")
        }

        let cacheDir = fileManager.temporaryDirectory.appendingPathComponent("SHARPModelCache")
        let modelName = modelPath.deletingPathExtension().lastPathComponent
        let compiledPath = cacheDir.appendingPathComponent("\(modelName).mlmodelc")

        if fileManager.fileExists(atPath: compiledPath.path) {
            // Unreadable attributes are treated as "stale" so we recompile
            // instead of failing on a damaged cache entry.
            let sourceDate = (try? fileManager.attributesOfItem(atPath: modelPath.path))?[.modificationDate] as? Date
            let cachedDate = (try? fileManager.attributesOfItem(atPath: compiledPath.path))?[.modificationDate] as? Date

            if let sourceDate = sourceDate, let cachedDate = cachedDate, cachedDate >= sourceDate {
                print("Using cached compiled model at \(compiledPath.path)")
                return compiledPath
            }

            // Stale entry: best-effort removal, retried below before moving.
            try? fileManager.removeItem(at: compiledPath)
        }

        print("Compiling model (this may take a moment)...")
        let startTime = CFAbsoluteTimeGetCurrent()

        let temporaryCompiledURL = try MLModel.compileModel(at: modelPath)

        let compileTime = CFAbsoluteTimeGetCurrent() - startTime
        print("✓ Model compiled in \(String(format: "%.1f", compileTime))s")

        // Caching is an optimisation: if it fails, still use the freshly
        // compiled model rather than discarding the work.
        do {
            try fileManager.createDirectory(at: cacheDir, withIntermediateDirectories: true)
            if fileManager.fileExists(atPath: compiledPath.path) {
                try fileManager.removeItem(at: compiledPath)
            }
            try fileManager.moveItem(at: temporaryCompiledURL, to: compiledPath)
        } catch {
            print("Warning: Could not cache compiled model (\(error.localizedDescription)). Using \(temporaryCompiledURL.path)")
            return temporaryCompiledURL
        }

        print("Compiled model cached at \(compiledPath.path)")
        return compiledPath
    }

    /// Loads an image, stretches it to the model's input size and converts it
    /// to a `[1, 3, inputHeight, inputWidth]` Float32 tensor with channel
    /// values divided by 255.
    ///
    /// - Parameter imagePath: Location of the image file.
    /// - Returns: The planar (channel-first) RGB tensor.
    /// - Throws: Error codes 1–4 when loading, converting, resizing or
    ///   rasterising the image fails.
    func preprocessImage(at imagePath: URL) throws -> MLMultiArray {
        guard let nsImage = NSImage(contentsOf: imagePath) else {
            throw SHARPModelRunner.makeError(
                code: 1, message: "Failed to load image from \(imagePath.path)")
        }

        guard let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
            throw SHARPModelRunner.makeError(code: 2, message: "Failed to convert to CGImage")
        }

        let ciImage = CIImage(cgImage: cgImage)
        let sourceExtent = ciImage.extent

        // A zero-sized image would make the scale factors infinite.
        guard sourceExtent.width > 0, sourceExtent.height > 0 else {
            throw SHARPModelRunner.makeError(
                code: 3, message: "Failed to resize image: source image has zero width or height")
        }

        // Independent X/Y scale factors: the image is stretched, not letterboxed.
        let scaleX = CGFloat(inputWidth) / sourceExtent.width
        let scaleY = CGFloat(inputHeight) / sourceExtent.height
        let scaledImage = ciImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY))

        let context = CIContext()
        let targetRect = CGRect(x: 0, y: 0, width: inputWidth, height: inputHeight)
        guard let resizedCGImage = context.createCGImage(scaledImage, from: targetRect) else {
            throw SHARPModelRunner.makeError(code: 3, message: "Failed to resize image")
        }

        // Rasterise into an RGBA8 buffer that has exactly the model's input
        // size, so the indexing below can never run past either buffer.
        let width = inputWidth
        let height = inputHeight
        let bytesPerPixel = 4
        let bytesPerRow = bytesPerPixel * width
        var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow)
        let colorSpace = CGColorSpaceCreateDeviceRGB()

        // The bitmap context keeps the data pointer for its whole lifetime, so
        // both creating it and drawing must happen while the array's storage is
        // pinned. Passing `&pixelData` to the initializer is only valid for the
        // duration of that one call.
        let didDraw: Bool = pixelData.withUnsafeMutableBytes { buffer -> Bool in
            guard let baseAddress = buffer.baseAddress,
                  let cgContext = CGContext(data: baseAddress,
                                            width: width,
                                            height: height,
                                            bitsPerComponent: 8,
                                            bytesPerRow: bytesPerRow,
                                            space: colorSpace,
                                            bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue) else {
                return false
            }
            cgContext.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: width, height: height))
            return true
        }

        guard didDraw else {
            throw SHARPModelRunner.makeError(code: 4, message: "Failed to create bitmap context")
        }

        let imageArray = try MLMultiArray(
            shape: [1, 3, NSNumber(value: inputHeight), NSNumber(value: inputWidth)],
            dataType: .float32)

        // Convert interleaved RGBA bytes to planar float channels (R, G, B);
        // the alpha byte is skipped.
        let ptr = imageArray.dataPointer.assumingMemoryBound(to: Float.self)
        let channelStride = inputHeight * inputWidth

        for y in 0..<height {
            let rowByteOffset = y * bytesPerRow
            let rowPixelOffset = y * width
            for x in 0..<width {
                let pixelIndex = rowByteOffset + x * bytesPerPixel
                let spatialIndex = rowPixelOffset + x

                ptr[spatialIndex] = Float(pixelData[pixelIndex]) / 255.0
                ptr[channelStride + spatialIndex] = Float(pixelData[pixelIndex + 1]) / 255.0
                ptr[2 * channelStride + spatialIndex] = Float(pixelData[pixelIndex + 2]) / 255.0
            }
        }

        return imageArray
    }

    /// Runs the model and collects its five outputs into a `Gaussians3D`.
    ///
    /// The model is fed `"image"` and `"disparity_factor"`
    /// (`focalLengthPx / inputWidth`). Outputs are looked up by their expected
    /// names first, then by keyword, and finally by sorted name order.
    ///
    /// - Parameters:
    ///   - image: Input tensor, e.g. from `preprocessImage(at:)`.
    ///   - focalLengthPx: Focal length in pixels.
    /// - Throws: Error code 5 when the outputs cannot be extracted, or any
    ///   error raised by Core ML.
    func predict(image: MLMultiArray, focalLengthPx: Float) throws -> Gaussians3D {
        let disparityFactor = focalLengthPx / Float(inputWidth)

        let disparityArray = try MLMultiArray(shape: [1], dataType: .float32)
        disparityArray[0] = NSNumber(value: disparityFactor)

        let inputFeatures = try MLDictionaryFeatureProvider(dictionary: [
            "image": MLFeatureValue(multiArray: image),
            "disparity_factor": MLFeatureValue(multiArray: disparityArray)
        ])

        let output = try model.prediction(from: inputFeatures)

        // Sorted so keyword matching and the positional fallback are
        // deterministic; dictionary key order is unspecified.
        let outputNames = model.modelDescription.outputDescriptionsByName.keys.sorted()

        func multiArray(named name: String) -> MLMultiArray? {
            return output.featureValue(for: name)?.multiArrayValue
        }

        // Exact name first, then the first output whose name contains any keyword.
        func findOutput(named exactName: String, containing keywords: [String]) -> MLMultiArray? {
            if let exact = multiArray(named: exactName) {
                return exact
            }
            let lowercasedKeywords = keywords.map { $0.lowercased() }
            for name in outputNames {
                let lowercaseName = name.lowercased()
                let matches = lowercasedKeywords.contains { lowercaseName.contains($0) }
                if matches, let found = multiArray(named: name) {
                    return found
                }
            }
            return nil
        }

        if let meanVectors = findOutput(named: "mean_vectors_3d_positions",
                                        containing: ["mean", "position", "xyz"]),
           let singularValues = findOutput(named: "singular_values_scales",
                                           containing: ["singular", "scale"]),
           let quaternions = findOutput(named: "quaternions_rotations",
                                        containing: ["quaternion", "rotation", "rot"]),
           let colors = findOutput(named: "colors_rgb_linear",
                                   containing: ["color", "rgb"]),
           let opacities = findOutput(named: "opacities_alpha_channel",
                                      containing: ["opacity", "alpha"]) {
            return Gaussians3D(
                meanVectors: meanVectors,
                singularValues: singularValues,
                quaternions: quaternions,
                colors: colors,
                opacities: opacities
            )
        }

        print("Warning: Could not match all outputs by name. Available outputs: \(outputNames)")

        // Last resort: take the first five outputs in sorted-name order.
        // NOTE(review): this presumes the sorted names line up with
        // mean/scale/rotation/colour/opacity — verify for models that reach here.
        guard outputNames.count >= 5,
              let mv = multiArray(named: outputNames[0]),
              let sv = multiArray(named: outputNames[1]),
              let q = multiArray(named: outputNames[2]),
              let c = multiArray(named: outputNames[3]),
              let o = multiArray(named: outputNames[4]) else {
            throw SHARPModelRunner.makeError(
                code: 5,
                message: "Failed to extract model outputs. Available: \(outputNames)")
        }

        print("Using outputs by sorted order: \(outputNames)")
        return Gaussians3D(
            meanVectors: mv,
            singularValues: sv,
            quaternions: q,
            colors: c,
            opacities: o
        )
    }

    /// Verifies that `array` can be read through a flat `Float` pointer:
    /// Float32 elements, standard contiguous strides, and at least
    /// `elementCount` elements.
    private static func requireContiguousFloat32(_ array: MLMultiArray,
                                                 named name: String,
                                                 elementCount: Int) throws {
        guard array.dataType == .float32 else {
            throw makeError(code: 7, message: "Model output '\(name)' is not Float32")
        }
        guard array.count >= elementCount else {
            throw makeError(
                code: 7,
                message: "Model output '\(name)' has \(array.count) elements, expected at least \(elementCount)")
        }

        // Walk dimensions from innermost to outermost. Dimensions of size 1
        // are skipped because their stride never affects addressing.
        var expectedStride = 1
        for (dimension, stride) in zip(array.shape, array.strides).reversed() {
            let extent = dimension.intValue
            if extent > 1 && stride.intValue != expectedStride {
                throw makeError(code: 7, message: "Model output '\(name)' is not stored contiguously")
            }
            expectedStride *= max(extent, 1)
        }
    }

    /// Writes the Gaussians to `outputPath` as a binary little-endian PLY file.
    ///
    /// Each vertex stores position, degree-0 SH colour (derived from the sRGB
    /// conversion of the model's colours), opacity logit, log-scales and
    /// rotation quaternion. The vertex data is followed by `extrinsic`
    /// (identity), `intrinsic`, `image_size`, `frame`, `disparity` (10th/90th
    /// percentile of `1/z` over vertices with `z > 1e-6`), `color_space` and
    /// `version` elements.
    ///
    /// - Parameters:
    ///   - gaussians: Model output to export.
    ///   - focalLengthPx: Focal length written into the intrinsic matrix.
    ///   - imageShape: Image size used for the principal point and `image_size`.
    ///   - outputPath: Destination file.
    ///   - decimation: Fraction of Gaussians to keep; values `>= 1` keep all.
    /// - Throws: Error code 6 for sizes that do not fit the PLY fields, code 7
    ///   for tensors that cannot be read as flat Float32, or a file-write error.
    func savePLY(gaussians: Gaussians3D,
                 focalLengthPx: Float,
                 imageShape: (height: Int, width: Int),
                 to outputPath: URL,
                 decimation: Float = 1.0) throws {
        let imageHeight = imageShape.height
        let imageWidth = imageShape.width

        // `UInt32(_:)` traps on negative/oversized values; validate instead.
        guard let imageWidthField = UInt32(exactly: imageWidth),
              let imageHeightField = UInt32(exactly: imageHeight) else {
            throw SHARPModelRunner.makeError(
                code: 6, message: "Invalid image size \(imageWidth)x\(imageHeight)")
        }

        let originalCount = gaussians.count

        // Everything below reads the tensors through raw Float pointers.
        try SHARPModelRunner.requireContiguousFloat32(
            gaussians.meanVectors, named: "mean vectors", elementCount: originalCount * 3)
        try SHARPModelRunner.requireContiguousFloat32(
            gaussians.singularValues, named: "singular values", elementCount: originalCount * 3)
        try SHARPModelRunner.requireContiguousFloat32(
            gaussians.quaternions, named: "quaternions", elementCount: originalCount * 4)
        try SHARPModelRunner.requireContiguousFloat32(
            gaussians.colors, named: "colors", elementCount: originalCount * 3)
        try SHARPModelRunner.requireContiguousFloat32(
            gaussians.opacities, named: "opacities", elementCount: originalCount)

        let keepIndices: [Int]
        if decimation < 1.0 {
            keepIndices = gaussians.decimationIndices(keepRatio: decimation)
            print("Decimating: keeping \(keepIndices.count) of \(originalCount) Gaussians (\(String(format: "%.1f", decimation * 100))%)")
        } else {
            keepIndices = Array(0..<originalCount)
        }

        let numGaussians = keepIndices.count
        guard let frameCountField = Int32(exactly: numGaussians) else {
            throw SHARPModelRunner.makeError(
                code: 6, message: "Too many Gaussians for the PLY frame field: \(numGaussians)")
        }

        // 14 floats per vertex, plus headroom for the header and trailing elements.
        let floatsPerVertex = 14
        var fileContent = Data()
        fileContent.reserveCapacity(2048 + numGaussians * floatsPerVertex * MemoryLayout<Float>.size)

        func appendString(_ str: String) {
            fileContent.append(Data(str.utf8))
        }

        // All multi-byte values are written explicitly little-endian to match
        // the `binary_little_endian` header.
        func appendUInt32(_ value: UInt32) {
            var v = value.littleEndian
            fileContent.append(Data(bytes: &v, count: MemoryLayout<UInt32>.size))
        }

        func appendFloat32(_ value: Float) {
            appendUInt32(value.bitPattern)
        }

        func appendInt32(_ value: Int32) {
            appendUInt32(UInt32(bitPattern: value))
        }

        func appendUInt8(_ value: UInt8) {
            fileContent.append(value)
        }

        // ---- Header ----
        var header = "ply\nformat binary_little_endian 1.0\n"
        header += "element vertex \(numGaussians)\n"

        let vertexProperties = [
            "x", "y", "z",
            "f_dc_0", "f_dc_1", "f_dc_2",
            "opacity",
            "scale_0", "scale_1", "scale_2",
            "rot_0", "rot_1", "rot_2", "rot_3"
        ]
        for property in vertexProperties {
            header += "property float \(property)\n"
        }

        // Single-property metadata elements; their order must match the order
        // in which the values are appended after the vertex data.
        let metadataElements: [(name: String, count: Int, type: String)] = [
            ("extrinsic", 16, "float"),
            ("intrinsic", 9, "float"),
            ("image_size", 2, "uint"),
            ("frame", 2, "int"),
            ("disparity", 2, "float"),
            ("color_space", 1, "uchar"),
            ("version", 3, "uchar")
        ]
        for element in metadataElements {
            header += "element \(element.name) \(element.count)\n"
            header += "property \(element.type) \(element.name)\n"
        }
        header += "end_header\n"
        appendString(header)

        // ---- Vertex data ----
        var disparities: [Float] = []
        disparities.reserveCapacity(numGaussians)

        let meanPtr = gaussians.meanVectors.dataPointer.assumingMemoryBound(to: Float.self)
        let scalePtr = gaussians.singularValues.dataPointer.assumingMemoryBound(to: Float.self)
        let quatPtr = gaussians.quaternions.dataPointer.assumingMemoryBound(to: Float.self)
        let colorPtr = gaussians.colors.dataPointer.assumingMemoryBound(to: Float.self)
        let opacityPtr = gaussians.opacities.dataPointer.assumingMemoryBound(to: Float.self)

        for i in keepIndices {
            let base3 = i * 3
            let base4 = i * 4

            // Position.
            let z = meanPtr[base3 + 2]
            appendFloat32(meanPtr[base3])
            appendFloat32(meanPtr[base3 + 1])
            appendFloat32(z)

            // Only points in front of the camera contribute to the disparity range.
            if z > 1e-6 {
                disparities.append(1.0 / z)
            }

            // Colour: linear RGB -> sRGB -> degree-0 SH coefficient.
            for channel in 0..<3 {
                let srgb = linearRGBToSRGB(colorPtr[base3 + channel])
                appendFloat32(rgbToSphericalHarmonics(srgb))
            }

            // Opacity is stored as a logit.
            appendFloat32(inverseSigmoid(opacityPtr[i]))

            // Scales are stored as logarithms; the floor avoids log(0).
            for axis in 0..<3 {
                appendFloat32(log(max(scalePtr[base3 + axis], 1e-10)))
            }

            // Rotation quaternion, written in the model's component order.
            for component in 0..<4 {
                appendFloat32(quatPtr[base4 + component])
            }
        }

        // ---- Extrinsic: 4x4 identity ----
        let identity: [Float] = [
            1, 0, 0, 0,
            0, 1, 0, 0,
            0, 0, 1, 0,
            0, 0, 0, 1
        ]
        for value in identity {
            appendFloat32(value)
        }

        // ---- Intrinsic: 3x3 with the principal point at the image centre ----
        let intrinsic: [Float] = [
            focalLengthPx, 0, Float(imageWidth) * 0.5,
            0, focalLengthPx, Float(imageHeight) * 0.5,
            0, 0, 1
        ]
        for value in intrinsic {
            appendFloat32(value)
        }

        // ---- Image size ----
        appendUInt32(imageWidthField)
        appendUInt32(imageHeightField)

        // ---- Frame ----
        appendInt32(1)
        appendInt32(frameCountField)

        // ---- Disparity percentiles (defaults 0 and 1 when no point qualifies) ----
        disparities.sort()
        let q10Index = Int(Float(disparities.count) * 0.1)
        let q90Index = Int(Float(disparities.count) * 0.9)
        let disparity10: Float = disparities.isEmpty ? 0.0 : disparities[min(q10Index, disparities.count - 1)]
        let disparity90: Float = disparities.isEmpty ? 1.0 : disparities[min(q90Index, disparities.count - 1)]
        appendFloat32(disparity10)
        appendFloat32(disparity90)

        // ---- Colour space ----
        appendUInt8(1)

        // ---- Version (1, 5, 0) ----
        appendUInt8(1)
        appendUInt8(5)
        appendUInt8(0)

        // Atomic write so a failure never leaves a truncated PLY behind.
        try fileContent.write(to: outputPath, options: .atomic)

        print("✓ Saved PLY with \(numGaussians) Gaussians to \(outputPath.path)")
    }
}
|
|
|
|
|
|
|
|
|
|
|
/// Parsed command-line options for the SHARP runner.
struct CommandLineArgs {
    let modelPath: URL
    let imagePath: URL
    let outputPath: URL
    let focalLength: Float
    let decimation: Float

    /// Parses `CommandLine.arguments`.
    ///
    /// Paths may be given positionally (model, input image, output, then an
    /// optional focal length) or through `-m/-i/-o/-f`. `-d` accepts a ratio
    /// (`<= 1`) or a percentage (`> 1`); the result is clamped to `0.01...1.0`.
    ///
    /// - Returns: The parsed options, or `nil` after printing usage when help
    ///   was requested, a required path is missing, an option has no value,
    ///   a numeric value is invalid, or an unknown option/extra argument is found.
    static func parse() -> CommandLineArgs? {
        let args = CommandLine.arguments
        let defaultFocalLength: Float = 1536.0

        var modelPath: URL?
        var imagePath: URL?
        var outputPath: URL?
        // Optional instead of comparing against the default value, so an
        // explicit "1536" is distinguishable from "not given".
        var focalLength: Float?
        var decimation: Float = 1.0

        // Reports a problem, shows usage and yields the failure result.
        func fail(_ message: String) -> CommandLineArgs? {
            print("Error: \(message)\n")
            printUsage()
            return nil
        }

        // A usable focal length is a finite, strictly positive number.
        func parseFocalLength(_ text: String) -> Float? {
            guard let value = Float(text), value.isFinite, value > 0 else {
                return nil
            }
            return value
        }

        var i = 1
        while i < args.count {
            let arg = args[i]

            switch arg {
            case "-h", "--help":
                printUsage()
                return nil

            case "-m", "--model",
                 "-i", "--input",
                 "-o", "--output",
                 "-f", "--focal-length",
                 "-d", "--decimation":
                guard i + 1 < args.count else {
                    return fail("Missing value for option \(arg)")
                }
                i += 1
                let value = args[i]

                switch arg {
                case "-m", "--model":
                    modelPath = URL(fileURLWithPath: value)
                case "-i", "--input":
                    imagePath = URL(fileURLWithPath: value)
                case "-o", "--output":
                    outputPath = URL(fileURLWithPath: value)
                case "-f", "--focal-length":
                    guard let parsed = parseFocalLength(value) else {
                        return fail("Invalid focal length: \(value)")
                    }
                    focalLength = parsed
                default:
                    // "-d" / "--decimation"
                    guard let raw = Float(value), raw.isFinite else {
                        return fail("Invalid decimation value: \(value)")
                    }
                    // Values above 1 are interpreted as percentages.
                    let ratio = raw > 1.0 ? raw / 100.0 : raw
                    decimation = max(0.01, min(1.0, ratio))
                }

            default:
                // Anything dash-prefixed that is not a number is an unknown
                // option, not a path.
                if arg.hasPrefix("-") && arg.count > 1 && Float(arg) == nil {
                    return fail("Unknown option \(arg)")
                }

                if modelPath == nil {
                    modelPath = URL(fileURLWithPath: arg)
                } else if imagePath == nil {
                    imagePath = URL(fileURLWithPath: arg)
                } else if outputPath == nil {
                    outputPath = URL(fileURLWithPath: arg)
                } else if focalLength == nil {
                    guard let parsed = parseFocalLength(arg) else {
                        return fail("Invalid focal length: \(arg)")
                    }
                    focalLength = parsed
                } else {
                    return fail("Unexpected argument: \(arg)")
                }
            }

            i += 1
        }

        guard let model = modelPath, let image = imagePath, let output = outputPath else {
            printUsage()
            return nil
        }

        return CommandLineArgs(
            modelPath: model,
            imagePath: image,
            outputPath: output,
            focalLength: focalLength ?? defaultFocalLength,
            decimation: decimation
        )
    }

    /// Prints the usage text, using the executable's file name when available.
    static func printUsage() {
        // `first` avoids trapping should the argument list ever be empty.
        let execName = CommandLine.arguments.first?.components(separatedBy: "/").last ?? "sharp_runner"
        print("""
        Usage: \(execName) [OPTIONS] <model> <input_image> <output.ply>

        SHARP Model Inference - Generate 3D Gaussian Splats from a single image

        Arguments:
          model              Path to the SHARP Core ML model (.mlpackage, .mlmodel, or .mlmodelc)
          input_image        Path to input image (PNG, JPEG, etc.)
          output.ply         Path for output PLY file

        Options:
          -m, --model PATH           Path to Core ML model
          -i, --input PATH           Path to input image
          -o, --output PATH          Path for output PLY file
          -f, --focal-length FLOAT   Focal length in pixels (default: 1536)
          -d, --decimation FLOAT     Decimation ratio 0.0-1.0 or percentage 1-100 (default: 1.0 = keep all)
                                     Example: 0.5 or 50 keeps 50% of Gaussians
          -h, --help                 Show this help message

        Examples:
          # Basic usage
          \(execName) sharp.mlpackage photo.jpg output.ply

          # With focal length
          \(execName) sharp.mlpackage photo.jpg output.ply 768

          # With decimation (keep 50% of points)
          \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 0.5

          # With decimation as percentage
          \(execName) -m sharp.mlpackage -i photo.jpg -o output.ply -d 25

        The model will be automatically compiled on first use and cached for subsequent runs.
        Decimation keeps the most important Gaussians based on scale and opacity.
        """)
    }
}
|
|
|
|
|
|
|
|
|
|
|
/// Program entry point: parses the command line, loads the model, runs
/// inference on the input image and writes the resulting PLY file.
///
/// Exits with status 1 when argument parsing yields nothing or when any
/// step throws.
func main() {
    guard let options = CommandLineArgs.parse() else {
        exit(1)
    }

    // Image shape recorded in the PLY metadata.
    let plyImageSide = 1536

    do {
        print("Loading SHARP model from \(options.modelPath.path)...")
        let runner = try SHARPModelRunner(modelPath: options.modelPath)

        print("Preprocessing image \(options.imagePath.path)...")
        let inputTensor = try runner.preprocessImage(at: options.imagePath)

        print("Running inference...")
        let inferenceStart = CFAbsoluteTimeGetCurrent()
        let gaussians = try runner.predict(image: inputTensor, focalLengthPx: options.focalLength)
        let elapsedSeconds = CFAbsoluteTimeGetCurrent() - inferenceStart
        let elapsedText = String(format: "%.2f", elapsedSeconds)

        print("✓ Generated \(gaussians.count) Gaussians in \(elapsedText)s")

        print("Saving PLY file...")
        try runner.savePLY(
            gaussians: gaussians,
            focalLengthPx: options.focalLength,
            imageShape: (height: plyImageSide, width: plyImageSide),
            to: options.outputPath,
            decimation: options.decimation
        )

        print("✓ Complete!")
    } catch {
        print("Error: \(error.localizedDescription)")

        // Every Swift error bridges to NSError, so this cast cannot fail.
        let bridged = error as NSError
        print("Domain: \(bridged.domain), Code: \(bridged.code)")
        if let underlyingError = bridged.userInfo[NSUnderlyingErrorKey] as? Error {
            print("Underlying error: \(underlyingError)")
        }

        exit(1)
    }
}

main()