|
|
"""Contains `sharp predict` CLI implementation. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.utils.data |
|
|
|
|
|
from sharp.models import ( |
|
|
PredictorParams, |
|
|
RGBGaussianPredictor, |
|
|
create_predictor, |
|
|
) |
|
|
from sharp.utils import io |
|
|
from sharp.utils import logging as logging_utils |
|
|
from sharp.utils.gaussians import ( |
|
|
Gaussians3D, |
|
|
SceneMetaData, |
|
|
save_ply, |
|
|
unproject_gaussians, |
|
|
) |
|
|
|
|
|
from .render import render_gaussians |
|
|
|
|
|
LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt" |
|
|
|
|
|
|
|
|
@click.command() |
|
|
@click.option( |
|
|
"-i", |
|
|
"--input-path", |
|
|
type=click.Path(path_type=Path, exists=True), |
|
|
help="Path to an image or containing a list of images.", |
|
|
required=True, |
|
|
) |
|
|
@click.option( |
|
|
"-o", |
|
|
"--output-path", |
|
|
type=click.Path(path_type=Path, file_okay=False), |
|
|
help="Path to save the predicted Gaussians and renderings.", |
|
|
required=True, |
|
|
) |
|
|
@click.option( |
|
|
"-c", |
|
|
"--checkpoint-path", |
|
|
type=click.Path(path_type=Path, dir_okay=False), |
|
|
default=None, |
|
|
help="Path to the .pt checkpoint. If not provided, downloads the default model automatically.", |
|
|
required=False, |
|
|
) |
|
|
@click.option( |
|
|
"--render/--no-render", |
|
|
"with_rendering", |
|
|
is_flag=True, |
|
|
default=False, |
|
|
help="Whether to render trajectory for checkpoint.", |
|
|
) |
|
|
@click.option( |
|
|
"--device", |
|
|
type=str, |
|
|
default="default", |
|
|
help="Device to run on. ['cpu', 'mps', 'cuda']", |
|
|
) |
|
|
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.") |
|
|
def predict_cli( |
|
|
input_path: Path, |
|
|
output_path: Path, |
|
|
checkpoint_path: Path, |
|
|
with_rendering: bool, |
|
|
device: str, |
|
|
verbose: bool, |
|
|
): |
|
|
"""Predict Gaussians from input images.""" |
|
|
logging_utils.configure(logging.DEBUG if verbose else logging.INFO) |
|
|
|
|
|
extensions = io.get_supported_image_extensions() |
|
|
|
|
|
image_paths = [] |
|
|
if input_path.is_file(): |
|
|
if input_path.suffix in extensions: |
|
|
image_paths = [input_path] |
|
|
else: |
|
|
for ext in extensions: |
|
|
image_paths.extend(list(input_path.glob(f"**/*{ext}"))) |
|
|
|
|
|
if len(image_paths) == 0: |
|
|
LOGGER.info("No valid images found. Input was %s.", input_path) |
|
|
return |
|
|
|
|
|
LOGGER.info("Processing %d valid image files.", len(image_paths)) |
|
|
|
|
|
if device == "default": |
|
|
if torch.cuda.is_available(): |
|
|
device = "cuda" |
|
|
elif torch.mps.is_available(): |
|
|
device = "mps" |
|
|
else: |
|
|
device = "cpu" |
|
|
LOGGER.info("Using device %s", device) |
|
|
|
|
|
if with_rendering and device != "cuda": |
|
|
LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.") |
|
|
with_rendering = False |
|
|
|
|
|
|
|
|
if checkpoint_path is None: |
|
|
LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL) |
|
|
state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True) |
|
|
else: |
|
|
LOGGER.info("Loading checkpoint from %s", checkpoint_path) |
|
|
state_dict = torch.load(checkpoint_path, weights_only=True) |
|
|
|
|
|
gaussian_predictor = create_predictor(PredictorParams()) |
|
|
gaussian_predictor.load_state_dict(state_dict) |
|
|
gaussian_predictor.eval() |
|
|
gaussian_predictor.to(device) |
|
|
|
|
|
output_path.mkdir(exist_ok=True, parents=True) |
|
|
|
|
|
for image_path in image_paths: |
|
|
LOGGER.info("Processing %s", image_path) |
|
|
image, _, f_px = io.load_rgb(image_path) |
|
|
height, width = image.shape[:2] |
|
|
intrinsics = torch.tensor( |
|
|
[ |
|
|
[f_px, 0, (width - 1) / 2.0, 0], |
|
|
[0, f_px, (height - 1) / 2.0, 0], |
|
|
[0, 0, 1, 0], |
|
|
[0, 0, 0, 1], |
|
|
], |
|
|
device=device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device)) |
|
|
|
|
|
LOGGER.info("Saving 3DGS to %s", output_path) |
|
|
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply") |
|
|
|
|
|
if with_rendering: |
|
|
output_video_path = (output_path / image_path.stem).with_suffix(".mp4") |
|
|
LOGGER.info("Rendering trajectory to %s", output_video_path) |
|
|
|
|
|
metadata = SceneMetaData(intrinsics[0, 0].item(), (width, height), "linearRGB") |
|
|
render_gaussians(gaussians, metadata, output_video_path) |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def predict_image( |
|
|
predictor: RGBGaussianPredictor, |
|
|
image: np.ndarray, |
|
|
f_px: float, |
|
|
device: torch.device, |
|
|
) -> Gaussians3D: |
|
|
"""Predict Gaussians from an image.""" |
|
|
internal_shape = (1536, 1536) |
|
|
|
|
|
LOGGER.info("Running preprocessing.") |
|
|
image_pt = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1) / 255.0 |
|
|
_, height, width = image_pt.shape |
|
|
disparity_factor = torch.tensor([f_px / width]).float().to(device) |
|
|
|
|
|
image_resized_pt = F.interpolate( |
|
|
image_pt[None], |
|
|
size=(internal_shape[1], internal_shape[0]), |
|
|
mode="bilinear", |
|
|
align_corners=True, |
|
|
) |
|
|
|
|
|
|
|
|
LOGGER.info("Running inference.") |
|
|
gaussians_ndc = predictor(image_resized_pt, disparity_factor) |
|
|
|
|
|
LOGGER.info("Running postprocessing.") |
|
|
intrinsics = ( |
|
|
torch.tensor( |
|
|
[ |
|
|
[f_px, 0, width / 2, 0], |
|
|
[0, f_px, height / 2, 0], |
|
|
[0, 0, 1, 0], |
|
|
[0, 0, 0, 1], |
|
|
] |
|
|
) |
|
|
.float() |
|
|
.to(device) |
|
|
) |
|
|
intrinsics_resized = intrinsics.clone() |
|
|
intrinsics_resized[0] *= internal_shape[0] / width |
|
|
intrinsics_resized[1] *= internal_shape[1] / height |
|
|
|
|
|
|
|
|
gaussians = unproject_gaussians( |
|
|
gaussians_ndc, torch.eye(4).to(device), intrinsics_resized, internal_shape |
|
|
) |
|
|
|
|
|
return gaussians |
|
|
|