|
|
"""SHARP inference + optional CUDA video rendering utilities. |
|
|
|
|
|
Design goals: |
|
|
- Reuse SHARP's own predict/render pipeline (no subprocess calls). |
|
|
- Be robust on Hugging Face Spaces + ZeroGPU. |
|
|
- Cache model weights and predictor construction across requests. |
|
|
|
|
|
Public API (used by the Gradio app): |
|
|
- TrajectoryType |
|
|
- predict_and_maybe_render_gpu(...) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import threading |
|
|
import time |
|
|
import uuid |
|
|
from contextlib import contextmanager |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Final, Literal |
|
|
|
|
|
import torch |
|
|
|
|
|
try: |
|
|
import spaces |
|
|
except Exception: |
|
|
spaces = None |
|
|
|
|
|
try: |
|
|
|
|
|
from huggingface_hub import hf_hub_download, try_to_load_from_cache |
|
|
except Exception: |
|
|
hf_hub_download = None |
|
|
try_to_load_from_cache = None |
|
|
|
|
|
from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image |
|
|
from sharp.cli.render import render_gaussians as sharp_render_gaussians |
|
|
from sharp.models import PredictorParams, create_predictor |
|
|
from sharp.utils import camera, io |
|
|
from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply |
|
|
from sharp.utils.gsplat import GSplatRenderer |
|
|
|
|
|
TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _now_ms() -> int: |
|
|
return int(time.time() * 1000) |
|
|
|
|
|
|
|
|
def _ensure_dir(path: Path) -> Path: |
|
|
path.mkdir(parents=True, exist_ok=True) |
|
|
return path |
|
|
|
|
|
|
|
|
def _make_even(x: int) -> int: |
|
|
return x if x % 2 == 0 else x + 1 |
|
|
|
|
|
|
|
|
def _select_device(preference: str = "auto") -> torch.device: |
|
|
"""Select the best available device for inference (CPU/CUDA/MPS).""" |
|
|
if preference not in {"auto", "cpu", "cuda", "mps"}: |
|
|
raise ValueError("device preference must be one of: auto|cpu|cuda|mps") |
|
|
|
|
|
if preference == "cpu": |
|
|
return torch.device("cpu") |
|
|
if preference == "cuda": |
|
|
return torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
if preference == "mps": |
|
|
return torch.device("mps" if torch.backends.mps.is_available() else "cpu") |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
return torch.device("cuda") |
|
|
if torch.backends.mps.is_available(): |
|
|
return torch.device("mps") |
|
|
return torch.device("cpu") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True, slots=True) |
|
|
class PredictionOutputs: |
|
|
"""Outputs of SHARP inference (plus derived metadata for rendering).""" |
|
|
|
|
|
ply_path: Path |
|
|
gaussians: Gaussians3D |
|
|
metadata_for_render: SceneMetaData |
|
|
input_resolution_hw: tuple[int, int] |
|
|
focal_length_px: float |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _PatchedVideoWriter(io.VideoWriter): |
|
|
"""Ensure depth writer is closed so files can be safely cleaned up.""" |
|
|
|
|
|
def __init__( |
|
|
self, output_path: Path, fps: float = 30.0, render_depth: bool = True |
|
|
) -> None: |
|
|
super().__init__(output_path, fps=fps, render_depth=render_depth) |
|
|
|
|
|
if not hasattr(self, "depth_writer"): |
|
|
self.depth_writer = None |
|
|
|
|
|
def close(self): |
|
|
super().close() |
|
|
depth_writer = getattr(self, "depth_writer", None) |
|
|
try: |
|
|
if depth_writer is not None: |
|
|
depth_writer.close() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def _patched_sharp_videowriter(): |
|
|
"""Temporarily patch `sharp.utils.io.VideoWriter` used by `sharp.cli.render`.""" |
|
|
original = io.VideoWriter |
|
|
io.VideoWriter = _PatchedVideoWriter |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
io.VideoWriter = original |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelWrapper: |
|
|
"""Cached SHARP model wrapper for Gradio/Spaces.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
*, |
|
|
outputs_dir: str | Path = "outputs", |
|
|
checkpoint_url: str = DEFAULT_MODEL_URL, |
|
|
checkpoint_path: str | Path | None = None, |
|
|
device_preference: str = "auto", |
|
|
keep_model_on_device: bool | None = None, |
|
|
hf_repo_id: str | None = None, |
|
|
hf_filename: str | None = None, |
|
|
hf_revision: str | None = None, |
|
|
) -> None: |
|
|
self.outputs_dir = _ensure_dir(Path(outputs_dir)) |
|
|
self.checkpoint_url = checkpoint_url |
|
|
|
|
|
env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT") |
|
|
if checkpoint_path: |
|
|
self.checkpoint_path = Path(checkpoint_path) |
|
|
elif env_ckpt: |
|
|
self.checkpoint_path = Path(env_ckpt) |
|
|
else: |
|
|
self.checkpoint_path = None |
|
|
|
|
|
|
|
|
self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp") |
|
|
self.hf_filename = hf_filename or os.getenv( |
|
|
"SHARP_HF_FILENAME", "sharp_2572gikvuh.pt" |
|
|
) |
|
|
self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None |
|
|
|
|
|
self.device_preference = device_preference |
|
|
|
|
|
|
|
|
if keep_model_on_device is None: |
|
|
keep_env = ( |
|
|
os.getenv("SHARP_KEEP_MODEL_ON_DEVICE") |
|
|
) |
|
|
self.keep_model_on_device = keep_env == "1" |
|
|
else: |
|
|
self.keep_model_on_device = keep_model_on_device |
|
|
|
|
|
self._lock = threading.RLock() |
|
|
self._predictor: torch.nn.Module | None = None |
|
|
self._predictor_device: torch.device | None = None |
|
|
self._state_dict: dict | None = None |
|
|
|
|
|
def has_cuda(self) -> bool: |
|
|
return torch.cuda.is_available() |
|
|
|
|
|
def _load_state_dict(self) -> dict: |
|
|
with self._lock: |
|
|
if self._state_dict is not None: |
|
|
return self._state_dict |
|
|
|
|
|
|
|
|
if self.checkpoint_path is not None: |
|
|
try: |
|
|
self._state_dict = torch.load( |
|
|
self.checkpoint_path, |
|
|
weights_only=True, |
|
|
map_location="cpu", |
|
|
) |
|
|
return self._state_dict |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
"Failed to load SHARP checkpoint from local path.\n\n" |
|
|
f"Path:\n {self.checkpoint_path}\n\n" |
|
|
f"Original error:\n {type(e).__name__}: {e}" |
|
|
) from e |
|
|
|
|
|
|
|
|
hf_cache_error: Exception | None = None |
|
|
if try_to_load_from_cache is not None: |
|
|
try: |
|
|
cached = try_to_load_from_cache( |
|
|
repo_id=self.hf_repo_id, |
|
|
filename=self.hf_filename, |
|
|
revision=self.hf_revision, |
|
|
repo_type="model", |
|
|
) |
|
|
except TypeError: |
|
|
cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename) |
|
|
|
|
|
try: |
|
|
if isinstance(cached, str) and Path(cached).exists(): |
|
|
self._state_dict = torch.load( |
|
|
cached, weights_only=True, map_location="cpu" |
|
|
) |
|
|
return self._state_dict |
|
|
except Exception as e: |
|
|
hf_cache_error = e |
|
|
|
|
|
|
|
|
hf_error: Exception | None = None |
|
|
if hf_hub_download is not None: |
|
|
|
|
|
try: |
|
|
import inspect |
|
|
|
|
|
if "local_files_only" in inspect.signature(hf_hub_download).parameters: |
|
|
ckpt_path = hf_hub_download( |
|
|
repo_id=self.hf_repo_id, |
|
|
filename=self.hf_filename, |
|
|
revision=self.hf_revision, |
|
|
local_files_only=True, |
|
|
) |
|
|
if Path(ckpt_path).exists(): |
|
|
self._state_dict = torch.load( |
|
|
ckpt_path, weights_only=True, map_location="cpu" |
|
|
) |
|
|
return self._state_dict |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
try: |
|
|
ckpt_path = hf_hub_download( |
|
|
repo_id=self.hf_repo_id, |
|
|
filename=self.hf_filename, |
|
|
revision=self.hf_revision, |
|
|
) |
|
|
self._state_dict = torch.load( |
|
|
ckpt_path, |
|
|
weights_only=True, |
|
|
map_location="cpu", |
|
|
) |
|
|
return self._state_dict |
|
|
except Exception as e: |
|
|
hf_error = e |
|
|
|
|
|
|
|
|
url_error: Exception | None = None |
|
|
try: |
|
|
self._state_dict = torch.hub.load_state_dict_from_url( |
|
|
self.checkpoint_url, |
|
|
progress=True, |
|
|
map_location="cpu", |
|
|
) |
|
|
return self._state_dict |
|
|
except Exception as e: |
|
|
url_error = e |
|
|
|
|
|
|
|
|
hint_lines = [ |
|
|
"Failed to load SHARP checkpoint.", |
|
|
"", |
|
|
"Tried (in order):", |
|
|
f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}", |
|
|
f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}", |
|
|
f" 3) URL (torch hub): {self.checkpoint_url}", |
|
|
"", |
|
|
"If network access is restricted, set a local checkpoint path:", |
|
|
" - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt", |
|
|
"", |
|
|
"Original errors:", |
|
|
] |
|
|
if try_to_load_from_cache is None: |
|
|
hint_lines.append(" HF cache: huggingface_hub not installed") |
|
|
elif hf_cache_error is not None: |
|
|
hint_lines.append( |
|
|
f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}" |
|
|
) |
|
|
else: |
|
|
hint_lines.append(" HF cache: (not found in cache)") |
|
|
|
|
|
if hf_hub_download is None: |
|
|
hint_lines.append(" HF download: huggingface_hub not installed") |
|
|
else: |
|
|
hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}") |
|
|
|
|
|
hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}") |
|
|
|
|
|
raise RuntimeError("\n".join(hint_lines)) |
|
|
|
|
|
def _get_predictor(self, device: torch.device) -> torch.nn.Module: |
|
|
with self._lock: |
|
|
if self._predictor is None: |
|
|
state_dict = self._load_state_dict() |
|
|
predictor = create_predictor(PredictorParams()) |
|
|
predictor.load_state_dict(state_dict) |
|
|
predictor.eval() |
|
|
self._predictor = predictor |
|
|
self._predictor_device = torch.device("cpu") |
|
|
|
|
|
assert self._predictor is not None |
|
|
assert self._predictor_device is not None |
|
|
|
|
|
if self._predictor_device != device: |
|
|
self._predictor.to(device) |
|
|
self._predictor_device = device |
|
|
|
|
|
return self._predictor |
|
|
|
|
|
def _maybe_move_model_back_to_cpu(self) -> None: |
|
|
if self.keep_model_on_device: |
|
|
return |
|
|
with self._lock: |
|
|
if self._predictor is not None and self._predictor_device is not None: |
|
|
if self._predictor_device.type != "cpu": |
|
|
self._predictor.to("cpu") |
|
|
self._predictor_device = torch.device("cpu") |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def _make_output_stem(self, input_path: Path) -> str: |
|
|
return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}" |
|
|
|
|
|
def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs: |
|
|
"""Run SHARP inference and export a .ply file.""" |
|
|
image_path = Path(image_path) |
|
|
if not image_path.exists(): |
|
|
raise FileNotFoundError(f"Image does not exist: {image_path}") |
|
|
|
|
|
device = _select_device(self.device_preference) |
|
|
predictor = self._get_predictor(device) |
|
|
|
|
|
image_np, _, f_px = io.load_rgb(image_path) |
|
|
height, width = image_np.shape[:2] |
|
|
|
|
|
with torch.no_grad(): |
|
|
gaussians = predict_image(predictor, image_np, f_px, device) |
|
|
|
|
|
stem = self._make_output_stem(image_path) |
|
|
ply_path = self.outputs_dir / f"{stem}.ply" |
|
|
|
|
|
|
|
|
save_ply(gaussians, f_px, (height, width), ply_path) |
|
|
|
|
|
|
|
|
metadata_for_render = SceneMetaData( |
|
|
focal_length_px=float(f_px), |
|
|
resolution_px=(int(width), int(height)), |
|
|
color_space="linearRGB", |
|
|
) |
|
|
|
|
|
self._maybe_move_model_back_to_cpu() |
|
|
|
|
|
return PredictionOutputs( |
|
|
ply_path=ply_path, |
|
|
gaussians=gaussians, |
|
|
metadata_for_render=metadata_for_render, |
|
|
input_resolution_hw=(int(height), int(width)), |
|
|
focal_length_px=float(f_px), |
|
|
) |
|
|
|
|
|
def _render_video_impl( |
|
|
self, |
|
|
*, |
|
|
gaussians: Gaussians3D, |
|
|
metadata: SceneMetaData, |
|
|
output_path: Path, |
|
|
trajectory_type: TrajectoryType, |
|
|
num_frames: int, |
|
|
fps: int, |
|
|
output_long_side: int | None, |
|
|
) -> Path: |
|
|
if not torch.cuda.is_available(): |
|
|
raise RuntimeError("Rendering requires CUDA (gsplat).") |
|
|
|
|
|
if num_frames < 2: |
|
|
raise ValueError("num_frames must be >= 2") |
|
|
if fps < 1: |
|
|
raise ValueError("fps must be >= 1") |
|
|
|
|
|
|
|
|
if output_long_side is None and int(fps) == 30: |
|
|
params = camera.TrajectoryParams( |
|
|
type=trajectory_type, |
|
|
num_steps=int(num_frames), |
|
|
num_repeats=1, |
|
|
) |
|
|
with _patched_sharp_videowriter(): |
|
|
sharp_render_gaussians( |
|
|
gaussians=gaussians, |
|
|
metadata=metadata, |
|
|
params=params, |
|
|
output_path=output_path, |
|
|
) |
|
|
depth_path = output_path.with_suffix(".depth.mp4") |
|
|
try: |
|
|
if depth_path.exists(): |
|
|
depth_path.unlink() |
|
|
except Exception: |
|
|
pass |
|
|
return output_path |
|
|
|
|
|
|
|
|
src_w, src_h = metadata.resolution_px |
|
|
src_f = float(metadata.focal_length_px) |
|
|
|
|
|
if output_long_side is None: |
|
|
out_w, out_h, out_f = src_w, src_h, src_f |
|
|
else: |
|
|
long_side = max(src_w, src_h) |
|
|
scale = float(output_long_side) / float(long_side) |
|
|
out_w = _make_even(max(2, int(round(src_w * scale)))) |
|
|
out_h = _make_even(max(2, int(round(src_h * scale)))) |
|
|
out_f = src_f * scale |
|
|
|
|
|
traj_params = camera.TrajectoryParams( |
|
|
type=trajectory_type, |
|
|
num_steps=int(num_frames), |
|
|
num_repeats=1, |
|
|
) |
|
|
|
|
|
device = torch.device("cuda") |
|
|
gaussians_cuda = gaussians.to(device) |
|
|
|
|
|
intrinsics = torch.tensor( |
|
|
[ |
|
|
[out_f, 0.0, (out_w - 1) / 2.0, 0.0], |
|
|
[0.0, out_f, (out_h - 1) / 2.0, 0.0], |
|
|
[0.0, 0.0, 1.0, 0.0], |
|
|
[0.0, 0.0, 0.0, 1.0], |
|
|
], |
|
|
device=device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
|
|
|
cam_model = camera.create_camera_model( |
|
|
gaussians_cuda, |
|
|
intrinsics, |
|
|
resolution_px=(out_w, out_h), |
|
|
lookat_mode=traj_params.lookat_mode, |
|
|
) |
|
|
|
|
|
trajectory = camera.create_eye_trajectory( |
|
|
gaussians_cuda, |
|
|
traj_params, |
|
|
resolution_px=(out_w, out_h), |
|
|
f_px=out_f, |
|
|
) |
|
|
|
|
|
renderer = GSplatRenderer(color_space=metadata.color_space) |
|
|
|
|
|
|
|
|
video_writer = _PatchedVideoWriter(output_path, fps=float(fps), render_depth=True) |
|
|
|
|
|
for eye_position in trajectory: |
|
|
cam_info = cam_model.compute(eye_position) |
|
|
rendering = renderer( |
|
|
gaussians_cuda, |
|
|
extrinsics=cam_info.extrinsics[None].to(device), |
|
|
intrinsics=cam_info.intrinsics[None].to(device), |
|
|
image_width=cam_info.width, |
|
|
image_height=cam_info.height, |
|
|
) |
|
|
color = (rendering.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8) |
|
|
depth = rendering.depth[0] |
|
|
video_writer.add_frame(color, depth) |
|
|
|
|
|
video_writer.close() |
|
|
|
|
|
depth_path = output_path.with_suffix(".depth.mp4") |
|
|
try: |
|
|
if depth_path.exists(): |
|
|
depth_path.unlink() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
return output_path |
|
|
|
|
|
def render_video( |
|
|
self, |
|
|
*, |
|
|
gaussians: Gaussians3D, |
|
|
metadata: SceneMetaData, |
|
|
output_stem: str, |
|
|
trajectory_type: TrajectoryType = "rotate_forward", |
|
|
num_frames: int = 60, |
|
|
fps: int = 30, |
|
|
output_long_side: int | None = None, |
|
|
) -> Path: |
|
|
"""Render a camera trajectory as an MP4 (CUDA-only).""" |
|
|
output_path = self.outputs_dir / f"{output_stem}.mp4" |
|
|
return self._render_video_impl( |
|
|
gaussians=gaussians, |
|
|
metadata=metadata, |
|
|
output_path=output_path, |
|
|
trajectory_type=trajectory_type, |
|
|
num_frames=num_frames, |
|
|
fps=fps, |
|
|
output_long_side=output_long_side, |
|
|
) |
|
|
|
|
|
def predict_and_maybe_render( |
|
|
self, |
|
|
image_path: str | Path, |
|
|
*, |
|
|
trajectory_type: TrajectoryType, |
|
|
num_frames: int, |
|
|
fps: int, |
|
|
output_long_side: int | None, |
|
|
render_video: bool = True, |
|
|
) -> tuple[Path | None, Path]: |
|
|
"""One-shot helper for the UI: returns (video_path, ply_path).""" |
|
|
pred = self.predict_to_ply(image_path) |
|
|
|
|
|
if not render_video: |
|
|
return None, pred.ply_path |
|
|
|
|
|
if not torch.cuda.is_available(): |
|
|
return None, pred.ply_path |
|
|
|
|
|
output_stem = pred.ply_path.with_suffix("").name |
|
|
video_path = self.render_video( |
|
|
gaussians=pred.gaussians, |
|
|
metadata=pred.metadata_for_render, |
|
|
output_stem=output_stem, |
|
|
trajectory_type=trajectory_type, |
|
|
num_frames=num_frames, |
|
|
fps=fps, |
|
|
output_long_side=output_long_side, |
|
|
) |
|
|
return video_path, pred.ply_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs") |
|
|
|
|
|
_GLOBAL_MODEL: ModelWrapper | None = None |
|
|
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock() |
|
|
|
|
|
|
|
|
def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper: |
|
|
global _GLOBAL_MODEL |
|
|
with _GLOBAL_MODEL_INIT_LOCK: |
|
|
if _GLOBAL_MODEL is None: |
|
|
_GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir) |
|
|
return _GLOBAL_MODEL |
|
|
|
|
|
|
|
|
def predict_and_maybe_render( |
|
|
image_path: str | Path, |
|
|
*, |
|
|
trajectory_type: TrajectoryType, |
|
|
num_frames: int, |
|
|
fps: int, |
|
|
output_long_side: int | None, |
|
|
render_video: bool = True, |
|
|
) -> tuple[Path | None, Path]: |
|
|
model = get_global_model() |
|
|
return model.predict_and_maybe_render( |
|
|
image_path, |
|
|
trajectory_type=trajectory_type, |
|
|
num_frames=num_frames, |
|
|
fps=fps, |
|
|
output_long_side=output_long_side, |
|
|
render_video=render_video, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if spaces is not None: |
|
|
predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render) |
|
|
else: |
|
|
predict_and_maybe_render_gpu = predict_and_maybe_render |
|
|
|