| """SHARP inference utilities (PLY export only). | |
| This module intentionally does *not* implement MP4/video rendering. | |
| It provides a small, Spaces/ZeroGPU-friendly wrapper that: | |
| - Caches model weights and predictor construction across requests. | |
| - Runs SHARP inference and exports a canonical `.ply`. | |
| Public API (used by the Gradio app): | |
| - predict_to_ply_gpu(...) | |
| """ | |
from __future__ import annotations

import os
import threading
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Final

import torch

try:
    import spaces
except Exception:  # pragma: no cover
    spaces = None  # type: ignore[assignment]

try:
    # Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`;
    # see the note after the imports).
    from huggingface_hub import hf_hub_download, try_to_load_from_cache
except Exception:  # pragma: no cover
    hf_hub_download = None  # type: ignore[assignment]
    try_to_load_from_cache = None  # type: ignore[assignment]

from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image
from sharp.models import PredictorParams, create_predictor
from sharp.utils import io
from sharp.utils.gaussians import save_ply
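# Note on the HF-cache path above: on Spaces, the checkpoint can be pre-populated
# into the local cache at build time via the Space's README metadata, roughly
# like this (illustrative; repo id and filename must match the values resolved
# by ModelWrapper below, e.g. SHARP_HF_REPO_ID / SHARP_HF_FILENAME):
#
#     preload_from_hub:
#       - apple/Sharp sharp_2572gikvuh.pt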

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------


def _now_ms() -> int:
    return int(time.time() * 1000)


def _ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def _make_even(x: int) -> int:
    return x if x % 2 == 0 else x + 1


def _select_device(preference: str = "auto") -> torch.device:
    """Select the best available device for inference (CPU/CUDA/MPS)."""
    if preference not in {"auto", "cpu", "cuda", "mps"}:
        raise ValueError("device preference must be one of: auto|cpu|cuda|mps")
    if preference == "cpu":
        return torch.device("cpu")
    if preference == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if preference == "mps":
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # auto
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
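# For example (assuming typical availability): _select_device("auto") prefers
# CUDA, then MPS, then CPU; _select_device("cuda") quietly falls back to CPU
# when CUDA is unavailable.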


# -----------------------------------------------------------------------------
# Prediction outputs
# -----------------------------------------------------------------------------


@dataclass
class PredictionOutputs:
    """Outputs of SHARP inference."""

    ply_path: Path


# -----------------------------------------------------------------------------
# Model wrapper
# -----------------------------------------------------------------------------


class ModelWrapper:
    """Cached SHARP model wrapper for Gradio/Spaces."""

    def __init__(
        self,
        *,
        outputs_dir: str | Path = "outputs",
        checkpoint_url: str = DEFAULT_MODEL_URL,
        checkpoint_path: str | Path | None = None,
        device_preference: str = "auto",
        keep_model_on_device: bool | None = None,
        hf_repo_id: str | None = None,
        hf_filename: str | None = None,
        hf_revision: str | None = None,
    ) -> None:
        self.outputs_dir = _ensure_dir(Path(outputs_dir))
        self.checkpoint_url = checkpoint_url

        env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
        if checkpoint_path:
            self.checkpoint_path = Path(checkpoint_path)
        elif env_ckpt:
            self.checkpoint_path = Path(env_ckpt)
        else:
            self.checkpoint_path = None

        # Optional Hugging Face Hub fallback (useful when direct CDN download fails).
        self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
        self.hf_filename = hf_filename or os.getenv(
            "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
        )
        self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None

        self.device_preference = device_preference
        # For ZeroGPU, it's safer to not keep large tensors on CUDA across calls.
        if keep_model_on_device is None:
            self.keep_model_on_device = os.getenv("SHARP_KEEP_MODEL_ON_DEVICE") == "1"
        else:
            self.keep_model_on_device = keep_model_on_device

        self._lock = threading.RLock()
        self._predictor: torch.nn.Module | None = None
        self._predictor_device: torch.device | None = None
        self._state_dict: dict | None = None

    def has_cuda(self) -> bool:
        return torch.cuda.is_available()

    def _load_state_dict(self) -> dict:
        with self._lock:
            if self._state_dict is not None:
                return self._state_dict

            # 1) Explicit local checkpoint path
            if self.checkpoint_path is not None:
                try:
                    self._state_dict = torch.load(
                        self.checkpoint_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    raise RuntimeError(
                        "Failed to load SHARP checkpoint from local path.\n\n"
                        f"Path:\n {self.checkpoint_path}\n\n"
                        f"Original error:\n {type(e).__name__}: {e}"
                    ) from e

            # 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
            hf_cache_error: Exception | None = None
            if try_to_load_from_cache is not None:
                try:
                    cached = try_to_load_from_cache(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                        repo_type="model",
                    )
                except TypeError:
                    cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)  # type: ignore[misc]
                try:
                    if isinstance(cached, str) and Path(cached).exists():
                        self._state_dict = torch.load(
                            cached, weights_only=True, map_location="cpu"
                        )
                        return self._state_dict
                except Exception as e:
                    hf_cache_error = e

            # 3) HF Hub download (reuse cache when available; may download otherwise).
            hf_error: Exception | None = None
            if hf_hub_download is not None:
                # Attempt "local only" mode if supported (avoids network).
                try:
                    import inspect

                    if "local_files_only" in inspect.signature(hf_hub_download).parameters:
                        ckpt_path = hf_hub_download(
                            repo_id=self.hf_repo_id,
                            filename=self.hf_filename,
                            revision=self.hf_revision,
                            local_files_only=True,
                        )
                        if Path(ckpt_path).exists():
                            self._state_dict = torch.load(
                                ckpt_path, weights_only=True, map_location="cpu"
                            )
                            return self._state_dict
                except Exception:
                    pass

                try:
                    ckpt_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                    )
                    self._state_dict = torch.load(
                        ckpt_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    hf_error = e

            # 4) Default upstream CDN (torch hub cache). Last resort.
            url_error: Exception | None = None
            try:
                self._state_dict = torch.hub.load_state_dict_from_url(
                    self.checkpoint_url,
                    progress=True,
                    map_location="cpu",
                )
                return self._state_dict
            except Exception as e:
                url_error = e

            # If we got here: all options failed.
            hint_lines = [
                "Failed to load SHARP checkpoint.",
                "",
                "Tried (in order):",
                f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 3) URL (torch hub): {self.checkpoint_url}",
                "",
                "If network access is restricted, set a local checkpoint path:",
                " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
                "",
                "Original errors:",
            ]
            if try_to_load_from_cache is None:
                hint_lines.append(" HF cache: huggingface_hub not installed")
            elif hf_cache_error is not None:
                hint_lines.append(
                    f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
                )
            else:
                hint_lines.append(" HF cache: (not found in cache)")
            if hf_hub_download is None:
                hint_lines.append(" HF download: huggingface_hub not installed")
            else:
                hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
            hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
            raise RuntimeError("\n".join(hint_lines))

    def _get_predictor(self, device: torch.device) -> torch.nn.Module:
        with self._lock:
            if self._predictor is None:
                state_dict = self._load_state_dict()
                predictor = create_predictor(PredictorParams())
                predictor.load_state_dict(state_dict)
                predictor.eval()
                self._predictor = predictor
                self._predictor_device = torch.device("cpu")
            assert self._predictor is not None
            assert self._predictor_device is not None
            if self._predictor_device != device:
                self._predictor.to(device)
                self._predictor_device = device
            return self._predictor

    def _maybe_move_model_back_to_cpu(self) -> None:
        if self.keep_model_on_device:
            return
        with self._lock:
            if self._predictor is not None and self._predictor_device is not None:
                if self._predictor_device.type != "cpu":
                    self._predictor.to("cpu")
                    self._predictor_device = torch.device("cpu")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

    def _make_output_stem(self, input_path: Path) -> str:
        return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"

    def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
        """Run SHARP inference and export a .ply file."""
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image does not exist: {image_path}")

        device = _select_device(self.device_preference)
        predictor = self._get_predictor(device)

        image_np, _, f_px = io.load_rgb(image_path)
        height, width = image_np.shape[:2]

        with torch.no_grad():
            gaussians = predict_image(predictor, image_np, f_px, device)

        stem = self._make_output_stem(image_path)
        ply_path = self.outputs_dir / f"{stem}.ply"
        # save_ply expects (height, width).
        save_ply(gaussians, f_px, (height, width), ply_path)

        self._maybe_move_model_back_to_cpu()
        return PredictionOutputs(ply_path=ply_path)


# -----------------------------------------------------------------------------
# ZeroGPU entrypoints
# -----------------------------------------------------------------------------
#
# IMPORTANT: Do NOT decorate bound instance methods with `@spaces.GPU` on ZeroGPU.
# The wrapper uses multiprocessing queues and pickles args/kwargs. If `self` is
# included, Python will try to pickle the whole instance. ModelWrapper contains
# a threading.RLock (not pickleable) and the model itself should not be pickled.
#
# Expose module-level functions that accept only pickleable arguments and
# create/cache the ModelWrapper inside the GPU worker process.
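#
# Illustrative contrast (hypothetical names, not code used by this module):
#
#     class BadWrapper:
#         @spaces.GPU                  # BAD: pickling `self` drags the RLock and
#         def run(self, path): ...     # cached model into the GPU worker.
#
#     @spaces.GPU(duration=180)        # GOOD: only the pickleable `path` argument
#     def run(path): ...               # crosses the process boundary.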

DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")

_GLOBAL_MODEL: ModelWrapper | None = None
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()


def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
    global _GLOBAL_MODEL
    with _GLOBAL_MODEL_INIT_LOCK:
        if _GLOBAL_MODEL is None:
            _GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
        return _GLOBAL_MODEL


def predict_to_ply(image_path: str | Path) -> Path:
    model = get_global_model()
    return model.predict_to_ply(image_path).ply_path


# Export the GPU-wrapped callable (or a no-op wrapper locally).
if spaces is not None:
    predict_to_ply_gpu = spaces.GPU(duration=180)(predict_to_ply)
else:  # pragma: no cover
    predict_to_ply_gpu = predict_to_ply
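

# Minimal local smoke test (a sketch, not used by the Space): runs the full
# pipeline on an image path given on the command line and prints the resulting
# .ply path. Uses the undecorated `predict_to_ply`; the Gradio app calls the
# GPU-wrapped `predict_to_ply_gpu` instead.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        raise SystemExit(f"usage: python {Path(__file__).name} <image_path>")
    print(predict_to_ply(sys.argv[1]))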