# ml-sharp-3d-viewer / model_utils.py — initial release (commit a7ee00a, by notaneimu)
"""SHARP inference utilities (PLY export only).
This module intentionally does *not* implement MP4/video rendering.
It provides a small, Spaces/ZeroGPU-friendly wrapper that:
- Caches model weights and predictor construction across requests.
- Runs SHARP inference and exports a canonical `.ply`.
Public API (used by the Gradio app):
- predict_to_ply_gpu(...)
"""
from __future__ import annotations
import os
import threading
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Final
import torch
try:
import spaces
except Exception: # pragma: no cover
spaces = None # type: ignore[assignment]
try:
# Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`).
from huggingface_hub import hf_hub_download, try_to_load_from_cache
except Exception: # pragma: no cover
hf_hub_download = None # type: ignore[assignment]
try_to_load_from_cache = None # type: ignore[assignment]
from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image
from sharp.models import PredictorParams, create_predictor
from sharp.utils import io
from sharp.utils.gaussians import save_ply
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _now_ms() -> int:
return int(time.time() * 1000)
def _ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def _make_even(x: int) -> int:
return x if x % 2 == 0 else x + 1
def _select_device(preference: str = "auto") -> torch.device:
"""Select the best available device for inference (CPU/CUDA/MPS)."""
if preference not in {"auto", "cpu", "cuda", "mps"}:
raise ValueError("device preference must be one of: auto|cpu|cuda|mps")
if preference == "cpu":
return torch.device("cpu")
if preference == "cuda":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
if preference == "mps":
return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# auto
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# -----------------------------------------------------------------------------
# Prediction outputs
# -----------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class PredictionOutputs:
"""Outputs of SHARP inference."""
ply_path: Path
# -----------------------------------------------------------------------------
# Model wrapper
# -----------------------------------------------------------------------------
class ModelWrapper:
    """Cached SHARP model wrapper for Gradio/Spaces.

    Caches the checkpoint state dict and the constructed predictor across
    requests, guarded by an RLock so concurrent Gradio calls are safe.
    Unless ``keep_model_on_device`` (or ``SHARP_KEEP_MODEL_ON_DEVICE=1``)
    says otherwise, the model is moved back to CPU after each request —
    safer on ZeroGPU, where VRAM is shared across requests.
    """

    def __init__(
        self,
        *,
        outputs_dir: str | Path = "outputs",
        checkpoint_url: str = DEFAULT_MODEL_URL,
        checkpoint_path: str | Path | None = None,
        device_preference: str = "auto",
        keep_model_on_device: bool | None = None,
        hf_repo_id: str | None = None,
        hf_filename: str | None = None,
        hf_revision: str | None = None,
    ) -> None:
        """Configure checkpoint sources and runtime behavior.

        Args:
            outputs_dir: Directory where exported ``.ply`` files are written
                (created if missing).
            checkpoint_url: Upstream CDN URL used as the last-resort source.
            checkpoint_path: Optional local checkpoint path; takes precedence
                over the SHARP_CHECKPOINT_PATH / SHARP_CHECKPOINT env vars.
            device_preference: One of ``auto|cpu|cuda|mps``.
            keep_model_on_device: Keep weights on the accelerator between
                requests; ``None`` defers to SHARP_KEEP_MODEL_ON_DEVICE=1.
            hf_repo_id: Hugging Face repo for the Hub fallback
                (default env SHARP_HF_REPO_ID or "apple/Sharp").
            hf_filename: Checkpoint filename on the Hub
                (default env SHARP_HF_FILENAME or "sharp_2572gikvuh.pt").
            hf_revision: Optional Hub revision (default env SHARP_HF_REVISION).
        """
        self.outputs_dir = _ensure_dir(Path(outputs_dir))
        self.checkpoint_url = checkpoint_url
        # Local checkpoint: explicit argument wins, then either env var spelling.
        env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
        if checkpoint_path:
            self.checkpoint_path = Path(checkpoint_path)
        elif env_ckpt:
            self.checkpoint_path = Path(env_ckpt)
        else:
            self.checkpoint_path = None
        # Optional Hugging Face Hub fallback (useful when direct CDN download fails).
        self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
        self.hf_filename = hf_filename or os.getenv(
            "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
        )
        self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None
        self.device_preference = device_preference
        # For ZeroGPU, it's safer to not keep large tensors on CUDA across calls.
        if keep_model_on_device is None:
            self.keep_model_on_device = (
                os.getenv("SHARP_KEEP_MODEL_ON_DEVICE") == "1"
            )
        else:
            self.keep_model_on_device = keep_model_on_device
        # RLock because _get_predictor (holding the lock) calls
        # _load_state_dict (which acquires it again).
        self._lock = threading.RLock()
        self._predictor: torch.nn.Module | None = None
        self._predictor_device: torch.device | None = None
        self._state_dict: dict | None = None

    def has_cuda(self) -> bool:
        """Return True when a CUDA device is currently available."""
        return torch.cuda.is_available()

    def _load_state_dict(self) -> dict:
        """Load (and cache) the checkpoint state dict, trying sources in order.

        Order: explicit local path -> HF cache (offline) -> HF Hub download
        -> upstream CDN URL via torch hub.

        Returns:
            The checkpoint state dict, loaded onto CPU.

        Raises:
            RuntimeError: If a local path was configured but unreadable, or
                when every fallback source fails (with a diagnostic summary).
        """
        with self._lock:
            if self._state_dict is not None:
                return self._state_dict
            # 1) Explicit local checkpoint path
            if self.checkpoint_path is not None:
                try:
                    self._state_dict = torch.load(
                        self.checkpoint_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    raise RuntimeError(
                        "Failed to load SHARP checkpoint from local path.\n\n"
                        f"Path:\n {self.checkpoint_path}\n\n"
                        f"Original error:\n {type(e).__name__}: {e}"
                    ) from e
            # 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
            hf_cache_error: Exception | None = None
            if try_to_load_from_cache is not None:
                try:
                    try:
                        cached = try_to_load_from_cache(
                            repo_id=self.hf_repo_id,
                            filename=self.hf_filename,
                            revision=self.hf_revision,
                            repo_type="model",
                        )
                    except TypeError:
                        # Older huggingface_hub versions: positional signature.
                        cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)  # type: ignore[misc]
                    # `cached` may be None or a non-str sentinel; accept real files only.
                    if isinstance(cached, str) and Path(cached).exists():
                        self._state_dict = torch.load(
                            cached, weights_only=True, map_location="cpu"
                        )
                        return self._state_dict
                except Exception as e:
                    # BUGFIX: a failing cache probe must not abort the whole
                    # fallback chain — record it and continue to the next source.
                    hf_cache_error = e
            # 3) HF Hub download (reuse cache when available; may download otherwise).
            hf_error: Exception | None = None
            if hf_hub_download is not None:
                # Attempt "local only" mode if supported (avoids network).
                try:
                    import inspect

                    if "local_files_only" in inspect.signature(hf_hub_download).parameters:
                        ckpt_path = hf_hub_download(
                            repo_id=self.hf_repo_id,
                            filename=self.hf_filename,
                            revision=self.hf_revision,
                            local_files_only=True,
                        )
                        if Path(ckpt_path).exists():
                            self._state_dict = torch.load(
                                ckpt_path, weights_only=True, map_location="cpu"
                            )
                            return self._state_dict
                except Exception:
                    # Cache miss or unsupported kwarg: fall through to a real download.
                    pass
                try:
                    ckpt_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                    )
                    self._state_dict = torch.load(
                        ckpt_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    hf_error = e
            # 4) Default upstream CDN (torch hub cache). Last resort.
            url_error: Exception | None = None
            try:
                self._state_dict = torch.hub.load_state_dict_from_url(
                    self.checkpoint_url,
                    progress=True,
                    map_location="cpu",
                )
                return self._state_dict
            except Exception as e:
                url_error = e
            # If we got here: all options failed — build a diagnostic summary.
            hint_lines = [
                "Failed to load SHARP checkpoint.",
                "",
                "Tried (in order):",
                f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 3) URL (torch hub): {self.checkpoint_url}",
                "",
                "If network access is restricted, set a local checkpoint path:",
                " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
                "",
                "Original errors:",
            ]
            if try_to_load_from_cache is None:
                hint_lines.append(" HF cache: huggingface_hub not installed")
            elif hf_cache_error is not None:
                hint_lines.append(
                    f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
                )
            else:
                hint_lines.append(" HF cache: (not found in cache)")
            if hf_hub_download is None:
                hint_lines.append(" HF download: huggingface_hub not installed")
            else:
                hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
            hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
            raise RuntimeError("\n".join(hint_lines))

    def _get_predictor(self, device: torch.device) -> torch.nn.Module:
        """Return the cached predictor (built lazily), moved to *device*."""
        with self._lock:
            if self._predictor is None:
                state_dict = self._load_state_dict()
                predictor = create_predictor(PredictorParams())
                predictor.load_state_dict(state_dict)
                predictor.eval()
                self._predictor = predictor
                # Weights are loaded with map_location="cpu", so start there.
                self._predictor_device = torch.device("cpu")
            assert self._predictor is not None
            assert self._predictor_device is not None
            if self._predictor_device != device:
                self._predictor.to(device)
                self._predictor_device = device
            return self._predictor

    def _maybe_move_model_back_to_cpu(self) -> None:
        """Move the model off the accelerator (and free CUDA cache) unless pinned."""
        if self.keep_model_on_device:
            return
        with self._lock:
            if self._predictor is not None and self._predictor_device is not None:
                if self._predictor_device.type != "cpu":
                    self._predictor.to("cpu")
                    self._predictor_device = torch.device("cpu")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

    def _make_output_stem(self, input_path: Path) -> str:
        """Build a collision-resistant output stem: name + timestamp + short uuid."""
        return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"

    def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
        """Run SHARP inference on one image and export a ``.ply`` file.

        Args:
            image_path: Path to the input RGB image.

        Returns:
            PredictionOutputs holding the path of the written ``.ply``.

        Raises:
            FileNotFoundError: If *image_path* does not exist.
        """
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image does not exist: {image_path}")
        device = _select_device(self.device_preference)
        predictor = self._get_predictor(device)
        try:
            image_np, _, f_px = io.load_rgb(image_path)
            height, width = image_np.shape[:2]
            with torch.no_grad():
                gaussians = predict_image(predictor, image_np, f_px, device)
            stem = self._make_output_stem(image_path)
            ply_path = self.outputs_dir / f"{stem}.ply"
            # save_ply expects (height, width).
            save_ply(gaussians, f_px, (height, width), ply_path)
        finally:
            # BUGFIX: always release the accelerator, even when inference fails
            # (previously a raised exception left the model on the GPU).
            self._maybe_move_model_back_to_cpu()
        return PredictionOutputs(ply_path=ply_path)
# -----------------------------------------------------------------------------
# ZeroGPU entrypoints
# -----------------------------------------------------------------------------
#
# IMPORTANT: Do NOT decorate bound instance methods with `@spaces.GPU` on ZeroGPU.
# The wrapper uses multiprocessing queues and pickles args/kwargs. If `self` is
# included, Python will try to pickle the whole instance. ModelWrapper contains
# a threading.RLock (not pickleable) and the model itself should not be pickled.
#
# Expose module-level functions that accept only pickleable arguments and
# create/cache the ModelWrapper inside the GPU worker process.
# Default directory (next to this module) where exported .ply files are written.
DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")
# Process-wide ModelWrapper singleton; created lazily by get_global_model().
_GLOBAL_MODEL: ModelWrapper | None = None
# Guards first-time construction of _GLOBAL_MODEL across threads.
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()
def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
    """Return the process-wide ModelWrapper, constructing it lazily.

    Thread-safe: the wrapper is built at most once per process; *outputs_dir*
    only takes effect on the first call.
    """
    global _GLOBAL_MODEL
    with _GLOBAL_MODEL_INIT_LOCK:
        model = _GLOBAL_MODEL
        if model is None:
            model = ModelWrapper(outputs_dir=outputs_dir)
            _GLOBAL_MODEL = model
    return model
def predict_to_ply(
    image_path: str | Path,
) -> Path:
    """Run SHARP on *image_path* and return the exported ``.ply`` path.

    Module-level wrapper (pickleable args only) suitable for ZeroGPU workers.
    """
    outputs = get_global_model().predict_to_ply(image_path)
    return outputs.ply_path
# Export the GPU-wrapped callable (or a no-op wrapper locally).
if spaces is not None:
    # On Spaces/ZeroGPU: run inference in a GPU worker, allowing up to 180 s.
    predict_to_ply_gpu = spaces.GPU(duration=180)(predict_to_ply)
else:  # pragma: no cover
    # Outside Spaces (no `spaces` package): expose the plain function under
    # the same name so callers need no conditional logic.
    predict_to_ply_gpu = predict_to_ply