"""SHARP inference utilities (PLY export only). This module intentionally does *not* implement MP4/video rendering. It provides a small, Spaces/ZeroGPU-friendly wrapper that: - Caches model weights and predictor construction across requests. - Runs SHARP inference and exports a canonical `.ply`. Public API (used by the Gradio app): - predict_to_ply_gpu(...) """ from __future__ import annotations import os import threading import time import uuid from dataclasses import dataclass from pathlib import Path from typing import Final import torch try: import spaces except Exception: # pragma: no cover spaces = None # type: ignore[assignment] try: # Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`). from huggingface_hub import hf_hub_download, try_to_load_from_cache except Exception: # pragma: no cover hf_hub_download = None # type: ignore[assignment] try_to_load_from_cache = None # type: ignore[assignment] from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image from sharp.models import PredictorParams, create_predictor from sharp.utils import io from sharp.utils.gaussians import save_ply # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- def _now_ms() -> int: return int(time.time() * 1000) def _ensure_dir(path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) return path def _make_even(x: int) -> int: return x if x % 2 == 0 else x + 1 def _select_device(preference: str = "auto") -> torch.device: """Select the best available device for inference (CPU/CUDA/MPS).""" if preference not in {"auto", "cpu", "cuda", "mps"}: raise ValueError("device preference must be one of: auto|cpu|cuda|mps") if preference == "cpu": return torch.device("cpu") if preference == "cuda": return torch.device("cuda" if torch.cuda.is_available() else "cpu") if preference == "mps": return torch.device("mps" if torch.backends.mps.is_available() else "cpu") # auto if torch.cuda.is_available(): return torch.device("cuda") if torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") # ----------------------------------------------------------------------------- # Prediction outputs # ----------------------------------------------------------------------------- @dataclass(frozen=True, slots=True) class PredictionOutputs: """Outputs of SHARP inference.""" ply_path: Path # ----------------------------------------------------------------------------- # Model wrapper # ----------------------------------------------------------------------------- class ModelWrapper: """Cached SHARP model wrapper for Gradio/Spaces.""" def __init__( self, *, outputs_dir: str | Path = "outputs", checkpoint_url: str = DEFAULT_MODEL_URL, checkpoint_path: str | Path | None = None, device_preference: str = "auto", keep_model_on_device: bool | None = None, hf_repo_id: str | None = None, hf_filename: str | None = None, hf_revision: str | None = None, ) -> None: self.outputs_dir = _ensure_dir(Path(outputs_dir)) self.checkpoint_url = checkpoint_url env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT") if checkpoint_path: self.checkpoint_path = Path(checkpoint_path) elif env_ckpt: self.checkpoint_path = Path(env_ckpt) else: self.checkpoint_path = None # Optional Hugging Face Hub fallback (useful when direct CDN download fails). 
        self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
        self.hf_filename = hf_filename or os.getenv(
            "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
        )
        self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None

        self.device_preference = device_preference

        # For ZeroGPU, it's safer to not keep large tensors on CUDA across calls.
        if keep_model_on_device is None:
            keep_env = os.getenv("SHARP_KEEP_MODEL_ON_DEVICE")
            self.keep_model_on_device = keep_env == "1"
        else:
            self.keep_model_on_device = keep_model_on_device

        self._lock = threading.RLock()
        self._predictor: torch.nn.Module | None = None
        self._predictor_device: torch.device | None = None
        self._state_dict: dict | None = None

    def has_cuda(self) -> bool:
        return torch.cuda.is_available()

    def _load_state_dict(self) -> dict:
        with self._lock:
            if self._state_dict is not None:
                return self._state_dict

            # 1) Explicit local checkpoint path
            if self.checkpoint_path is not None:
                try:
                    self._state_dict = torch.load(
                        self.checkpoint_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    raise RuntimeError(
                        "Failed to load SHARP checkpoint from local path.\n\n"
                        f"Path:\n {self.checkpoint_path}\n\n"
                        f"Original error:\n {type(e).__name__}: {e}"
                    ) from e

            # 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
            hf_cache_error: Exception | None = None
            if try_to_load_from_cache is not None:
                try:
                    cached = try_to_load_from_cache(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                        repo_type="model",
                    )
                except TypeError:
                    cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)  # type: ignore[misc]
                try:
                    if isinstance(cached, str) and Path(cached).exists():
                        self._state_dict = torch.load(
                            cached, weights_only=True, map_location="cpu"
                        )
                        return self._state_dict
                except Exception as e:
                    hf_cache_error = e

            # 3) HF Hub download (reuse cache when available; may download otherwise).
            hf_error: Exception | None = None
            if hf_hub_download is not None:
                # Attempt "local only" mode if supported (avoids network).
                try:
                    import inspect

                    if "local_files_only" in inspect.signature(hf_hub_download).parameters:
                        ckpt_path = hf_hub_download(
                            repo_id=self.hf_repo_id,
                            filename=self.hf_filename,
                            revision=self.hf_revision,
                            local_files_only=True,
                        )
                        if Path(ckpt_path).exists():
                            self._state_dict = torch.load(
                                ckpt_path, weights_only=True, map_location="cpu"
                            )
                            return self._state_dict
                except Exception:
                    pass

                try:
                    ckpt_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                    )
                    self._state_dict = torch.load(
                        ckpt_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    hf_error = e

            # 4) Default upstream CDN (torch hub cache). Last resort.
            url_error: Exception | None = None
            try:
                self._state_dict = torch.hub.load_state_dict_from_url(
                    self.checkpoint_url,
                    progress=True,
                    map_location="cpu",
                )
                return self._state_dict
            except Exception as e:
                url_error = e

            # If we got here: all options failed.
            hint_lines = [
                "Failed to load SHARP checkpoint.",
                "",
                "Tried (in order):",
                f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 3) URL (torch hub): {self.checkpoint_url}",
                "",
                "If network access is restricted, set a local checkpoint path:",
                " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
                "",
                "Original errors:",
            ]
            if try_to_load_from_cache is None:
                hint_lines.append(" HF cache: huggingface_hub not installed")
            elif hf_cache_error is not None:
                hint_lines.append(
                    f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
                )
            else:
                hint_lines.append(" HF cache: (not found in cache)")
            if hf_hub_download is None:
                hint_lines.append(" HF download: huggingface_hub not installed")
            else:
                hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
            hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
            raise RuntimeError("\n".join(hint_lines))

    def _get_predictor(self, device: torch.device) -> torch.nn.Module:
        with self._lock:
            if self._predictor is None:
                state_dict = self._load_state_dict()
                predictor = create_predictor(PredictorParams())
                predictor.load_state_dict(state_dict)
                predictor.eval()
                self._predictor = predictor
                self._predictor_device = torch.device("cpu")
            assert self._predictor is not None
            assert self._predictor_device is not None
            if self._predictor_device != device:
                self._predictor.to(device)
                self._predictor_device = device
            return self._predictor

    def _maybe_move_model_back_to_cpu(self) -> None:
        if self.keep_model_on_device:
            return
        with self._lock:
            if self._predictor is not None and self._predictor_device is not None:
                if self._predictor_device.type != "cpu":
                    self._predictor.to("cpu")
                    self._predictor_device = torch.device("cpu")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

    def _make_output_stem(self, input_path: Path) -> str:
        return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"

    def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
        """Run SHARP inference and export a .ply file."""
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image does not exist: {image_path}")

        device = _select_device(self.device_preference)
        predictor = self._get_predictor(device)

        image_np, _, f_px = io.load_rgb(image_path)
        height, width = image_np.shape[:2]

        with torch.no_grad():
            gaussians = predict_image(predictor, image_np, f_px, device)

        stem = self._make_output_stem(image_path)
        ply_path = self.outputs_dir / f"{stem}.ply"
        # save_ply expects (height, width).
        save_ply(gaussians, f_px, (height, width), ply_path)

        self._maybe_move_model_back_to_cpu()
        return PredictionOutputs(ply_path=ply_path)


# -----------------------------------------------------------------------------
# ZeroGPU entrypoints
# -----------------------------------------------------------------------------
#
# IMPORTANT: Do NOT decorate bound instance methods with `@spaces.GPU` on ZeroGPU.
# The wrapper uses multiprocessing queues and pickles args/kwargs. If `self` is
# included, Python will try to pickle the whole instance. ModelWrapper contains
# a threading.RLock (not pickleable) and the model itself should not be pickled.
#
# Expose module-level functions that accept only pickleable arguments and
# create/cache the ModelWrapper inside the GPU worker process.
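#
# Comment-only sketch of the contrast described above (illustrative, not executed):
#
#     @spaces.GPU(duration=180)
#     def predict_to_ply(image_path: str | Path) -> Path: ...   # OK: pickleable args only
#
#     class ModelWrapper:
#         @spaces.GPU
#         def predict_to_ply(self, image_path): ...              # NOT OK: pickles `self`
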
DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")

_GLOBAL_MODEL: ModelWrapper | None = None
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()


def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
    global _GLOBAL_MODEL
    with _GLOBAL_MODEL_INIT_LOCK:
        if _GLOBAL_MODEL is None:
            _GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
        return _GLOBAL_MODEL


def predict_to_ply(image_path: str | Path) -> Path:
    model = get_global_model()
    return model.predict_to_ply(image_path).ply_path


# Export the GPU-wrapped callable (or a no-op wrapper locally).
if spaces is not None:
    predict_to_ply_gpu = spaces.GPU(duration=180)(predict_to_ply)
else:  # pragma: no cover
    predict_to_ply_gpu = predict_to_ply
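

# Minimal local smoke-test sketch for the public API. The `example.jpg` path is
# an assumption for illustration only; the Gradio app calls `predict_to_ply_gpu`
# the same way, with a filepath from an image upload.
if __name__ == "__main__":  # pragma: no cover
    _example_image = Path(__file__).resolve().parent / "example.jpg"
    if _example_image.exists():
        _ply = predict_to_ply_gpu(_example_image)
        print(f"Wrote Gaussian splat to: {_ply}")
    else:
        print(f"Place a test image at {_example_image} to run this sketch.")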