# ml-sharp/src/sharp/cli/predict.py (initial commit c20d7cc, by amael-apple)
"""Contains `sharp predict` CLI implementation.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import logging
from pathlib import Path
import click
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from sharp.models import (
PredictorParams,
RGBGaussianPredictor,
create_predictor,
)
from sharp.utils import io
from sharp.utils import logging as logging_utils
from sharp.utils.gaussians import (
Gaussians3D,
SceneMetaData,
save_ply,
unproject_gaussians,
)
from .render import render_gaussians
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Checkpoint that `predict_cli` downloads when no `--checkpoint-path` is supplied.
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
def _collect_image_paths(input_path: Path) -> list[Path]:
    """Return the supported image files designated by ``input_path``.

    A file is accepted only when its suffix is one of the supported extensions. Any other
    path is searched recursively, once per supported extension. The result is de-duplicated
    and sorted so that a file matched by several patterns (e.g. where globbing is
    case-insensitive) is processed once, and runs are reproducible.
    """
    extensions = io.get_supported_image_extensions()
    if input_path.is_file():
        return [input_path] if input_path.suffix in extensions else []

    found: set[Path] = set()
    for ext in extensions:
        found.update(input_path.glob(f"**/*{ext}"))
    return sorted(found)


def _resolve_device(device: str) -> str:
    """Map the ``"default"`` device placeholder to the best available backend.

    Any other value is returned unchanged. Preference order is CUDA, then MPS, then CPU.
    """
    if device != "default":
        return device
    if torch.cuda.is_available():
        return "cuda"
    # `torch.backends.mps` is the long-standing availability check; `torch.mps.is_available`
    # is missing from older PyTorch releases.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_state_dict(checkpoint_path: Path | None):
    """Load the model weights from ``checkpoint_path`` or download the default checkpoint.

    Tensors are always mapped to the CPU so that a checkpoint saved from an accelerator can
    be loaded on a host without one; the caller moves the model to its target device.
    """
    if checkpoint_path is None:
        LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL)
        return torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, map_location="cpu", progress=True
        )

    LOGGER.info("Loading checkpoint from %s", checkpoint_path)
    return torch.load(checkpoint_path, map_location="cpu", weights_only=True)


@click.command()
@click.option(
    "-i",
    "--input-path",
    type=click.Path(path_type=Path, exists=True),
    help="Path to an image or containing a list of images.",
    required=True,
)
@click.option(
    "-o",
    "--output-path",
    type=click.Path(path_type=Path, file_okay=False),
    help="Path to save the predicted Gaussians and renderings.",
    required=True,
)
@click.option(
    "-c",
    "--checkpoint-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=None,
    help="Path to the .pt checkpoint. If not provided, downloads the default model automatically.",
    required=False,
)
@click.option(
    "--render/--no-render",
    "with_rendering",
    is_flag=True,
    default=False,
    help="Whether to render trajectory for checkpoint.",
)
@click.option(
    "--device",
    type=str,
    default="default",
    help="Device to run on. ['cpu', 'mps', 'cuda']",
)
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
def predict_cli(
    input_path: Path,
    output_path: Path,
    checkpoint_path: Path | None,
    with_rendering: bool,
    device: str,
    verbose: bool,
):
    """Predict Gaussians from input images."""
    # NOTE: the docstring above doubles as the command's `--help` text, so the detailed
    # description lives in comments instead.
    #
    # input_path:      an image file, or a directory searched recursively for images.
    # output_path:     directory (created if missing) receiving one `<stem>.ply` per image
    #                  and, when rendering is enabled, one `<stem>.mp4`.
    # checkpoint_path: local checkpoint; `None` downloads `DEFAULT_MODEL_URL`.
    # with_rendering:  render a trajectory video per image (CUDA devices only).
    # device:          torch device string, or "default" to auto-select.
    # verbose:         switch logging from INFO to DEBUG.
    logging_utils.configure(logging.DEBUG if verbose else logging.INFO)

    image_paths = _collect_image_paths(input_path)
    if not image_paths:
        LOGGER.info("No valid images found. Input was %s.", input_path)
        return
    LOGGER.info("Processing %d valid image files.", len(image_paths))

    device = _resolve_device(device)
    LOGGER.info("Using device %s", device)
    torch_device = torch.device(device)

    # Compare the device *type* so that indexed specs such as "cuda:0" keep rendering enabled.
    if with_rendering and torch_device.type != "cuda":
        LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.")
        with_rendering = False

    # Load or download checkpoint
    state_dict = _load_state_dict(checkpoint_path)

    gaussian_predictor = create_predictor(PredictorParams())
    gaussian_predictor.load_state_dict(state_dict)
    gaussian_predictor.eval()
    gaussian_predictor.to(device)

    output_path.mkdir(exist_ok=True, parents=True)

    # NOTE: outputs are named after the image stem only, so two inputs sharing a stem
    # (e.g. in different sub-directories) write to the same output files.
    for image_path in image_paths:
        LOGGER.info("Processing %s", image_path)
        image, _, f_px = io.load_rgb(image_path)
        height, width = image.shape[:2]

        gaussians = predict_image(gaussian_predictor, image, f_px, torch_device)

        LOGGER.info("Saving 3DGS to %s", output_path)
        save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")

        if with_rendering:
            output_video_path = (output_path / image_path.stem).with_suffix(".mp4")
            LOGGER.info("Rendering trajectory to %s", output_video_path)
            # Use the same focal length that was written to the .ply file.
            metadata = SceneMetaData(float(f_px), (width, height), "linearRGB")
            render_gaussians(gaussians, metadata, output_video_path)
@torch.no_grad()
def predict_image(
    predictor: RGBGaussianPredictor,
    image: np.ndarray,
    f_px: float,
    device: torch.device,
) -> Gaussians3D:
    """Run the predictor on one image and return the unprojected Gaussians.

    Args:
        predictor: Model called with the resized image batch and the disparity factor.
        image: Image array with the channel axis last (presumably 0-255 values, since it
            is divided by 255 - confirm against `io.load_rgb`).
        f_px: Focal length in pixels of the original image.
        device: Device on which all tensors are created and the predictor is run.

    Returns:
        The Gaussians produced by `unproject_gaussians` from the predictor's NDC output.
    """
    # Fixed resolution fed to the network, interpreted here as (width, height).
    internal_shape = (1536, 1536)
    net_width, net_height = internal_shape

    LOGGER.info("Running preprocessing.")
    pixels = torch.from_numpy(image.copy()).float().to(device)
    image_pt = pixels.permute(2, 0, 1) / 255.0  # channels-last -> channels-first
    height, width = image_pt.shape[1:]
    disparity_factor = torch.tensor([f_px / width]).float().to(device)

    image_resized_pt = F.interpolate(
        image_pt.unsqueeze(0),
        size=(net_height, net_width),
        mode="bilinear",
        align_corners=True,
    )

    # The predictor outputs Gaussians in NDC space.
    LOGGER.info("Running inference.")
    gaussians_ndc = predictor(image_resized_pt, disparity_factor)

    LOGGER.info("Running postprocessing.")
    # Pinhole intrinsics of the original image, principal point at the image centre.
    rows = [
        [f_px, 0, width / 2, 0],
        [0, f_px, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
    intrinsics_resized = torch.tensor(rows).float().to(device)
    # Rescale the x and y rows to match the resolution the network actually saw.
    intrinsics_resized[0] *= net_width / width
    intrinsics_resized[1] *= net_height / height

    # Convert the NDC Gaussians to metric space using an identity extrinsics matrix.
    identity_extrinsics = torch.eye(4).to(device)
    return unproject_gaussians(
        gaussians_ndc, identity_extrinsics, intrinsics_resized, internal_shape
    )