Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
| def _prepend_cuda_library_path() -> None: | |
| """Expose PyTorch/CUDA shared libraries before gsplat loads CUDA extensions.""" | |
| search_roots = [Path(entry) for entry in sys.path if entry] | |
| rel_candidates = ( | |
| "torch/lib", | |
| "nvidia/cuda_runtime/lib", | |
| "nvidia/cudnn/lib", | |
| "nvidia/cublas/lib", | |
| ) | |
| prepend: list[str] = [] | |
| seen: set[str] = set() | |
| for root in search_roots: | |
| for rel in rel_candidates: | |
| lib_dir = (root / rel).resolve() | |
| if not lib_dir.is_dir() or not any(lib_dir.glob("lib*.so*")): | |
| continue | |
| key = str(lib_dir) | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| prepend.append(key) | |
| if not prepend: | |
| return | |
| current = os.environ.get("LD_LIBRARY_PATH", "") | |
| os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(prepend + ([current] if current else [])) | |
| _prepend_cuda_library_path() | |
| import argparse | |
| import shutil | |
| import traceback | |
| import uuid | |
| from typing import Any, Callable | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageOps | |
ROOT = Path(__file__).resolve().parent

# Make the repo root and the vendored UniK3D checkout importable; UniK3D is
# inserted last and therefore ends up first on sys.path.
for _import_root in (ROOT, ROOT / "UniK3D"):
    sys.path.insert(0, str(_import_root))
del _import_root

# Environment defaults (existing values always win). Caches live under
# ROOT/checkpoints; compiled torch extensions go to /tmp.
_hf_cache_dir = str(ROOT / "checkpoints" / "huggingface")
for _env_name, _env_value in {
    "SHARP_LOAD_UNIK3D_PRETRAINED": "0",
    "TORCH_HOME": str(ROOT / "checkpoints" / "torchhub"),
    "HF_HOME": _hf_cache_dir,
    "HF_HUB_CACHE": _hf_cache_dir,
    "HUGGINGFACE_HUB_CACHE": _hf_cache_dir,
    "TORCH_EXTENSIONS_DIR": "/tmp/unisharp_torch_extensions",
}.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value, _hf_cache_dir
try:
    import spaces  # type: ignore

    # On Hugging Face Spaces, use the real ZeroGPU decorator.
    gpu: Callable[..., Any] = spaces.GPU
except Exception:
    # `spaces` is missing or failed to import (e.g. local runs): fall back to a no-op.

    def gpu(*args: Any, **kwargs: Any) -> Any:
        """No-op stand-in for ``spaces.GPU``.

        Supports both decorator forms so call sites work unchanged either way:
        ``@gpu`` returns the function itself, and ``@gpu(...)`` returns an
        identity decorator. All options are ignored.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            return fn

        return decorator
| from scripts import infer_unisharp as infer # noqa: E402 | |
# Checkpoint on the mounted model volume: either a checkpoint file or a
# directory searched for *.pt files. Override with UNISHARP_MOUNTED_CHECKPOINT.
MOUNTED_CHECKPOINT_PATH = Path(
    os.environ.get("UNISHARP_MOUNTED_CHECKPOINT", "/models/unisharpdemo/checkpoints")
)
# Repo-local fallback checkpoint. Override with UNISHARP_CHECKPOINT.
LOCAL_CHECKPOINT_PATH = Path(
    os.environ.get("UNISHARP_CHECKPOINT", str(ROOT / "checkpoints" / "unisharp.pt"))
)
# Each inference request writes its outputs to its own subdirectory of this root.
OUTPUT_ROOT = Path(os.environ.get("UNISHARP_OUTPUT_DIR", "/tmp/unisharp_outputs"))
# Base path used by the input-persisting helpers to store the selected/uploaded image.
STABLE_INPUT_PATH = Path(os.environ.get("UNISHARP_STABLE_INPUT", "/tmp/unisharp_current_input.png"))
# Bundled example image directories.
EXAMPLE_REPLICA_DIR = ROOT / "examples" / "replica"
EXAMPLE_OMNIROOMS_DIR = ROOT / "examples" / "omnirooms"
EXAMPLE_PERSPECTIVE_DIR = ROOT / "examples" / "perspective"
# Panorama gallery entries, in display order. Names ending in .jpg/.jpeg/.webp
# are files in EXAMPLE_OMNIROOMS_DIR; bare names are stored as "<name>.png" in
# EXAMPLE_REPLICA_DIR.
PANORAMA_EXAMPLE_NAMES = [
    "replica_apartment_0_0004_g01189",
    "replica_office_3_0000_g01465",
    "replica_apartment_0_0004_g00631",
    "AI_vol4_05_middle_source.jpg",
    "replica_apartment_1_0000_g01727",
    "AI_vol4_02_middle_source.jpg",
    "replica_apartment_1_0000_g01186",
    "AI_vol4_03_last_source.jpg",
    "replica_apartment_1_0000_g01402",
    "AI_vol4_02_random_source.jpg",
    "replica_apartment_0_0000_g01456",
    "AI_vol4_04_random_source.jpg",
    "AIUE_V01_004_random_source.jpg",
]
# Perspective gallery entries (file names in EXAMPLE_PERSPECTIVE_DIR), in display order.
PERSPECTIVE_EXAMPLE_NAMES = [
    "insta2.png",
    "dl3dv_9K_g00136.png",
    "Gemini_Generated_Image_8kdfrl8kdfrl8kdf.png",
    "wildrgbd_scene_202_g00205.png",
    "Gemini_Generated_Image_uxu9zwuxu9zwuxu9.png",
    "dl3dv_3K_g00360.png",
    "dl3dv_5K_g00762.png",
    "dl3dv_8K_g01501.png",
    "wildrgbd_scene_084_g00241.png",
    "wildrgbd_scene_272_g00434.png",
    "wildrgbd_scene_282_g00213.png",
    "wildrgbd_scene_284_g00000.png",
    "Gemini_Generated_Image_5bd8lc5bd8lc5bd8.png",
    "dl3dv_2K_g01989.png",
]
# Demo settings pushed onto the `infer` module by _configure_demo_constants.
# NOTE(review): "_M" presumably means metres and "_MS" milliseconds; confirm in infer_unisharp.
DEFAULT_PERSPECTIVE_MAX_LONG_EDGE = 768
DEFAULT_PANORAMA_MAX_LONG_EDGE = 1536
DEFAULT_ORBIT_VIEWS = 10
DEFAULT_FORWARD_VIEWS = 10
DEFAULT_ORBIT_RADIUS_M = 0.10
DEFAULT_FORWARD_DISTANCE_M = 0.20
DEFAULT_GIF_DURATION_MS = 300
# Passed to the GSplat renderer and to the inference args namespace.
DEFAULT_LOW_PASS_FILTER_EPS = 0.0
# Pillow's decompression-bomb threshold, in pixels.
Image.MAX_IMAGE_PIXELS = 50_000_000
# Lazily populated model/renderer cache, managed by _runtime; _MODEL_DEVICE and
# _MODEL_CHECKPOINT record the key the cached objects were built for.
_MODEL: Any = None
_STEP: int = 0
_MODEL_DEVICE: str = ""
_MODEL_CHECKPOINT: str = ""
_RENDERER: Any = None
_TRAIN_RENDERER: Any = None
| def _device() -> torch.device: | |
| return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| def _checkpoint_path() -> Path: | |
| candidates = [ | |
| MOUNTED_CHECKPOINT_PATH, | |
| Path("/models/unisharpdemo/checkpoints"), | |
| Path("/models/unisharpdemo/checkpoints.pt"), | |
| LOCAL_CHECKPOINT_PATH, | |
| ] | |
| checked: list[str] = [] | |
| for candidate in candidates: | |
| if str(candidate) in checked: | |
| continue | |
| checked.append(str(candidate)) | |
| if candidate.is_file(): | |
| return candidate | |
| if candidate.is_dir(): | |
| checkpoint_files = sorted(candidate.glob("*.pt")) | |
| if checkpoint_files: | |
| return checkpoint_files[0] | |
| raise FileNotFoundError( | |
| "UniSHARP checkpoint is unavailable. Expected a mounted checkpoint at " | |
| f"{MOUNTED_CHECKPOINT_PATH} or a local checkpoint at {LOCAL_CHECKPOINT_PATH}." | |
| ) | |
def _runtime() -> tuple[Any, Any, Any, int, torch.device, Path]:
    """Return the cached model and renderers, rebuilding them when the target changes.

    The cache lives in module globals and is keyed on the device string and the
    checkpoint path. When either differs from the cached key (or nothing is
    cached yet), the model is loaded through ``infer._load_model`` and a fresh
    ``GSplatRenderer``/``UnifiedTrainer`` pair is built around it.

    Returns:
        ``(model, renderer, train_renderer, step, device, checkpoint_path)``.
    """
    global _MODEL, _STEP, _MODEL_DEVICE, _MODEL_CHECKPOINT, _RENDERER, _TRAIN_RENDERER
    device = _device()
    checkpoint_path = _checkpoint_path()
    cache_is_valid = (
        _MODEL is not None
        and _MODEL_DEVICE == str(device)
        and _MODEL_CHECKPOINT == str(checkpoint_path)
    )
    if not cache_is_valid:
        infer._configure_torchhub_cache()
        loaded_model, loaded_step = infer._load_model(checkpoint_path, device=device)
        gs_renderer = infer.GSplatRenderer(
            color_space="sRGB",
            background_color="black",
            low_pass_filter_eps=DEFAULT_LOW_PASS_FILTER_EPS,
        ).to(device)
        trainer = infer.UnifiedTrainer(
            model=loaded_model,
            renderer=gs_renderer,
            loss_fn=None,
            device=device,
        )
        # Update the cache entries one by one, in the same order as before.
        _MODEL = loaded_model
        _STEP = int(loaded_step)
        _MODEL_DEVICE = str(device)
        _MODEL_CHECKPOINT = str(checkpoint_path)
        _RENDERER = gs_renderer
        _TRAIN_RENDERER = trainer
    return _MODEL, _RENDERER, _TRAIN_RENDERER, int(_STEP), device, checkpoint_path
def _make_args(checkpoint_path: Path) -> argparse.Namespace:
    """Assemble the ``args`` namespace handed to ``infer._process_one``.

    Uses camera mode ``"auto"`` with no camera JSON, intrinsics or parameter
    overrides. PLY export is enabled, and the demo's low-pass filter epsilon is applied.
    """
    settings: dict[str, Any] = {
        "checkpoint": checkpoint_path,
        "camera": "auto",
        "camera_json": None,
        "_camera_json_data": None,
        "camera_intrinsics": None,
        "camera_params": None,
        "save_ply": True,
        "low_pass_filter_eps": DEFAULT_LOW_PASS_FILTER_EPS,
    }
    return argparse.Namespace(**settings)
def _configure_demo_constants() -> None:
    """Push this demo's tuning values onto the ``infer`` module's globals.

    Sets the ``PANORAMA_ASPECT_*``, ``VIEW_MOTION_*``, ``*_MAX_LONG_EDGE``,
    ``ROTATE_*``, ``FORWARD_*`` and ``GIF_DURATION_MS`` attributes in a fixed order.
    """
    demo_settings: dict[str, Any] = {
        "PANORAMA_ASPECT_MIN": 1.9,
        "PANORAMA_ASPECT_MAX": 2.1,
        "VIEW_MOTION_NEAR_SCENE_DEPTH_M": 2.0,
        "VIEW_MOTION_MIN_SCALE": 0.08,
        "VIEW_MOTION_FORWARD_DEPTH_FRAC": 0.04,
        "VIEW_MOTION_ROTATE_DEPTH_FRAC": 0.02,
        "VIEW_MOTION_FAR_SCENE_MEDIAN_M": 2.5,
        "VIEW_MOTION_FOREGROUND_DEPTH_QUANTILE": 0.20,
        "PERSPECTIVE_MAX_LONG_EDGE": DEFAULT_PERSPECTIVE_MAX_LONG_EDGE,
        "PANORAMA_MAX_LONG_EDGE": DEFAULT_PANORAMA_MAX_LONG_EDGE,
        "ROTATE_VIEWS": DEFAULT_ORBIT_VIEWS,
        "ROTATE_RADIUS_M": DEFAULT_ORBIT_RADIUS_M,
        "FORWARD_VIEWS": DEFAULT_FORWARD_VIEWS,
        "FORWARD_DISTANCE_M": DEFAULT_FORWARD_DISTANCE_M,
        "GIF_DURATION_MS": DEFAULT_GIF_DURATION_MS,
    }
    for attr_name, value in demo_settings.items():
        setattr(infer, attr_name, value)
| def _cuda_library_dirs() -> list[str]: | |
| dirs: list[str] = [] | |
| seen: set[str] = set() | |
| torch_lib = Path(torch.__file__).resolve().parent / "lib" | |
| if torch_lib.is_dir(): | |
| key = str(torch_lib) | |
| seen.add(key) | |
| dirs.append(key) | |
| try: | |
| import nvidia.cuda_runtime # type: ignore | |
| cuda_rt = Path(nvidia.cuda_runtime.__file__).resolve().parent / "lib" | |
| if cuda_rt.is_dir(): | |
| key = str(cuda_rt) | |
| if key not in seen: | |
| seen.add(key) | |
| dirs.append(key) | |
| except ImportError: | |
| pass | |
| return dirs | |
| def _resolve_cuda_home() -> str | None: | |
| try: | |
| import nvidia.cuda_nvcc # type: ignore | |
| nvcc_root = Path(nvidia.cuda_nvcc.__file__).resolve().parent | |
| if nvcc_root.is_dir(): | |
| return str(nvcc_root) | |
| except ImportError: | |
| pass | |
| from torch.utils import cpp_extension | |
| cuda_home = cpp_extension.CUDA_HOME | |
| return str(cuda_home) if cuda_home else None | |
| def _disable_incompatible_gsplat_binary() -> None: | |
| try: | |
| import gsplat # type: ignore | |
| csrc = Path(gsplat.__file__).resolve().parent / "csrc.so" | |
| if csrc.exists(): | |
| backup = csrc.with_suffix(".so.incompatible") | |
| if not backup.exists(): | |
| csrc.rename(backup) | |
| except Exception: | |
| pass | |
| def _prebuilt_gsplat_so_path() -> Path: | |
| return Path(os.environ["TORCH_EXTENSIONS_DIR"]) / "py310_cu128" / "gsplat_cuda" / "gsplat_cuda.so" | |
def _install_prebuilt_gsplat_extension() -> Path | None:
    """Reuse gsplat_cuda.so built for torch 2.8 + cu128 + py310.

    When the mounted build exists, it is copied to ``_prebuilt_gsplat_so_path()``
    if the target is missing or older than the mounted file. The copy is made
    atomically via a temporary sibling plus ``os.replace``.

    Returns:
        The staged target path, or ``None`` when there is neither a mounted
        build nor a previously staged copy.
    """
    mounted = Path("/models/unisharpdemo/gsplat_cuda/py310_cu128/gsplat_cuda.so")
    target = _prebuilt_gsplat_so_path()
    if not mounted.is_file():
        return target if target.is_file() else None
    target.parent.mkdir(parents=True, exist_ok=True)
    if not target.exists() or mounted.stat().st_mtime > target.stat().st_mtime:
        # Copying straight onto the target could leave a truncated .so if
        # interrupted. Its fresh mtime would then stop it from ever being
        # re-copied, so copy aside and swap it in atomically.
        tmp = target.with_name(f"{target.name}.{uuid.uuid4().hex}.tmp")
        try:
            shutil.copy2(mounted, tmp)
            os.replace(tmp, target)
        finally:
            tmp.unlink(missing_ok=True)
    return target
| def _configure_gsplat_cuda() -> None: | |
| """Align gsplat with PyTorch CUDA 12.8 on ZeroGPU (avoid system CUDA 13).""" | |
| _disable_incompatible_gsplat_binary() | |
| cuda_home = _resolve_cuda_home() | |
| if cuda_home: | |
| os.environ["CUDA_HOME"] = cuda_home | |
| lib_dirs = _cuda_library_dirs() | |
| if lib_dirs: | |
| current = os.environ.get("LD_LIBRARY_PATH", "") | |
| merged = os.pathsep.join(lib_dirs + ([current] if current else [])) | |
| os.environ["LD_LIBRARY_PATH"] = merged | |
| _install_prebuilt_gsplat_extension() | |
def _warmup_gsplat_cuda() -> None:
    """Load the prebuilt gsplat CUDA extension and wire it into ``gsplat.cuda._backend``.

    The extension is registered as ``gsplat.csrc`` (and ``gsplat_cuda``) before
    ``gsplat.cuda._backend`` is imported. That module binds its CUDA module at
    import time, and registering first avoids JIT compilation on ZeroGPU. Once
    the backend holds the module loaded here, later calls return immediately
    instead of re-running the setup and re-executing the shared library.

    Raises:
        RuntimeError: if no prebuilt ``gsplat_cuda.so`` is available.
        ImportError: if the extension cannot be loaded or the backend ends up
            without a CUDA module.
    """
    loaded = sys.modules.get("gsplat_cuda")
    backend_module = sys.modules.get("gsplat.cuda._backend")
    if (
        loaded is not None
        and backend_module is not None
        and getattr(backend_module, "_C", None) is loaded
    ):
        return  # Already initialised by an earlier request.
    _configure_gsplat_cuda()
    ext_so = _install_prebuilt_gsplat_extension()
    if ext_so is None or not ext_so.is_file():
        raise RuntimeError(
            "Prebuilt gsplat_cuda.so is unavailable. Expected "
            f"{_prebuilt_gsplat_so_path()} or "
            "/models/unisharpdemo/gsplat_cuda/py310_cu128/gsplat_cuda.so"
        )
    import importlib.util

    spec = importlib.util.spec_from_file_location("gsplat_cuda", ext_so)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load gsplat extension from {ext_so}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # gsplat.cuda._backend does `from gsplat import csrc as _C` at import time.
    # Register the prebuilt module before importing _backend to avoid JIT on ZeroGPU.
    sys.modules["gsplat.csrc"] = module
    sys.modules["gsplat_cuda"] = module
    import gsplat.cuda._backend as backend  # noqa: F401

    if backend._C is None:
        raise ImportError("gsplat CUDA backend failed to initialize.")
def _resolve_image_path(
    image_path: str | dict[str, Any] | np.ndarray | Image.Image | None,
) -> Path:
    """Turn a Gradio image value into a path on disk.

    Dicts are unwrapped through their ``"path"``/``"name"`` keys. In-memory
    images (PIL or numpy) are saved as RGB PNGs under ``OUTPUT_ROOT`` with a
    unique file name, so concurrent requests never overwrite each other's input.

    Raises:
        gr.Error: if no image was given or the referenced file does not exist.
    """
    if isinstance(image_path, dict):
        image_path = image_path.get("path") or image_path.get("name")
    if isinstance(image_path, (Image.Image, np.ndarray)):
        pil = image_path if isinstance(image_path, Image.Image) else Image.fromarray(image_path)
        out_path = OUTPUT_ROOT / f"selected_input_{uuid.uuid4().hex}.png"
        out_path.parent.mkdir(parents=True, exist_ok=True)
        pil.convert("RGB").save(out_path)
        return out_path
    if not image_path:
        raise gr.Error("Please upload an image or select an example.")
    example_path = Path(image_path)
    if not example_path.exists():
        raise gr.Error(f"Selected image is unavailable: {example_path}")
    return example_path
def _run_unisharp_once(
    image_path: str | None,
    progress: gr.Progress | None = None,
) -> tuple[str | None, str | None, str | None]:
    """Run UniSHARP end to end for one input image.

    The input is written (EXIF orientation applied, RGB) to a fresh
    ``OUTPUT_ROOT/<request id>/input.png``. Inference and view rendering then
    run through ``infer._process_one``, and the per-sample output directory is
    inspected for results.

    Args:
        image_path: Value accepted by ``_resolve_image_path`` (normally a file path).
        progress: Optional Gradio progress tracker updated between stages.

    Returns:
        ``(forward_gif, orbit_gif, ply)`` as string paths. An entry is ``None``
        when that output file was not produced.

    Raises:
        gr.Error: when no usable image is given, or when any later stage fails.
            In the latter case the request directory is deleted first.
    """
    example_path = _resolve_image_path(image_path)
    if progress is not None:
        progress(0.03, desc="Preparing input image")
    request_id = uuid.uuid4().hex[:10]
    out_root = OUTPUT_ROOT / request_id
    out_root.mkdir(parents=True, exist_ok=True)
    input_path = out_root / "input.png"
    try:
        # Decode inside the try, closing the source file, so an unreadable
        # image cleans up out_root and is reported as gr.Error.
        with Image.open(example_path) as source:
            ImageOps.exif_transpose(source).convert("RGB").save(input_path)
        if progress is not None:
            progress(0.08, desc="Initializing gsplat CUDA")
        _warmup_gsplat_cuda()
        if progress is not None:
            progress(0.12, desc="Loading UniSHARP")
        _configure_demo_constants()
        model, renderer, train_renderer, step, _, checkpoint_path = _runtime()
        args = _make_args(checkpoint_path)
        if progress is not None:
            progress(0.35, desc="Running inference and rendering views")
        infer._process_one(
            model=model,
            renderer=renderer,
            train_renderer=train_renderer,
            image_path=input_path,
            out_root=out_root,
            step=int(step),
            args=args,
        )
        sample_dir = out_root / infer._slug_from_path(input_path)
        orbit = sample_dir / "rotate.gif"
        forward = sample_dir / "forward.gif"
        ply = sample_dir / "gaussians.ply"
        if progress is not None:
            progress(0.95, desc="Preparing outputs")
        return (
            str(forward) if forward.exists() else None,
            str(orbit) if orbit.exists() else None,
            str(ply) if ply.exists() else None,
        )
    except gr.Error:
        # Already user-facing: clean up and propagate unchanged.
        shutil.rmtree(out_root, ignore_errors=True)
        raise
    except Exception as exc:
        traceback.print_exc()
        shutil.rmtree(out_root, ignore_errors=True)
        raise gr.Error(str(exc)) from exc
    finally:
        # Release cached, unused CUDA memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@gpu()
def run_unisharp(
    image_path: str | None,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Gradio click handler: run UniSHARP on the image path held in the session state.

    Wrapped with ``gpu()``, which is ``spaces.GPU`` on Hugging Face Spaces and a
    no-op elsewhere, so each call is granted a GPU on ZeroGPU hardware. Pass
    ``duration=...`` to ``gpu`` if runs need longer than the default allocation.

    Returns:
        ``(forward_gif, orbit_gif, ply)`` paths, each possibly ``None``.
    """
    return _run_unisharp_once(image_path, progress=progress)
def _panorama_example_path(name: str) -> Path:
    """Resolve a ``PANORAMA_EXAMPLE_NAMES`` entry to its file path.

    Names ending in .jpg/.jpeg/.webp (case-insensitive) are used as-is inside
    the OmniRooms example dir. Any other name gets a ``.png`` suffix and is
    looked up in the Replica example dir.
    """
    has_image_extension = name.lower().endswith((".jpg", ".jpeg", ".webp"))
    if not has_image_extension:
        return EXAMPLE_REPLICA_DIR / f"{name}.png"
    return EXAMPLE_OMNIROOMS_DIR / name
def _panorama_example_paths() -> list[Path]:
    """Return the panorama example files that exist, in ``PANORAMA_EXAMPLE_NAMES`` order."""
    candidates = map(_panorama_example_path, PANORAMA_EXAMPLE_NAMES)
    return [candidate for candidate in candidates if candidate.is_file()]
def _perspective_example_paths() -> list[Path]:
    """Return the perspective example files that exist, in ``PERSPECTIVE_EXAMPLE_NAMES`` order."""
    found: list[Path] = []
    for file_name in PERSPECTIVE_EXAMPLE_NAMES:
        candidate = EXAMPLE_PERSPECTIVE_DIR / file_name
        if candidate.is_file():
            found.append(candidate)
    return found
def _example_paths() -> list[Path]:
    """Return all available examples: panoramas first, then perspective images."""
    combined = _panorama_example_paths()
    combined.extend(_perspective_example_paths())
    return combined
| def _demo_gallery_items(paths: list[Path]) -> list[str]: | |
| return [str(path) for path in paths] | |
| def _normalize_image_input(image: Any) -> str | np.ndarray | Image.Image | None: | |
| if image is None: | |
| return None | |
| if isinstance(image, Path): | |
| return str(image) | |
| if isinstance(image, str): | |
| return image | |
| if isinstance(image, (np.ndarray, Image.Image)): | |
| return image | |
| if isinstance(image, dict): | |
| path = image.get("path") or image.get("url") or image.get("name") | |
| return str(path) if path else None | |
| if isinstance(image, (list, tuple)): | |
| return _normalize_image_input(image[0]) if image else None | |
| if hasattr(image, "model_dump"): | |
| try: | |
| data = image.model_dump() | |
| if isinstance(data, dict): | |
| return _normalize_image_input(data.get("path") or data.get("url")) | |
| except Exception: | |
| pass | |
| path = getattr(image, "path", None) | |
| if path: | |
| return str(path) | |
| url = getattr(image, "url", None) | |
| if url: | |
| return str(url) | |
| return None | |
def _load_example_image(path: Path) -> tuple[np.ndarray | None, str | None]:
    """Persist an example image for the current session and return it for preview.

    The image is written (EXIF orientation applied, RGB) to a uniquely named
    file next to ``STABLE_INPUT_PATH``. A single shared file would let one
    session's selection overwrite another session's input before it clicks Run.

    Returns:
        ``(rgb_array, saved_path)``, or ``(None, None)`` if ``path`` is not a file.
    """
    src = Path(path)
    if not src.is_file():
        return None, None
    dest = STABLE_INPUT_PATH.with_name(
        f"{STABLE_INPUT_PATH.stem}_{uuid.uuid4().hex}{STABLE_INPUT_PATH.suffix}"
    )
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Decode while the file is open, and let the context manager close the handle.
    with Image.open(src) as opened:
        pil = ImageOps.exif_transpose(opened).convert("RGB")
    pil.save(dest)
    return np.array(pil), str(dest)
def _persist_image_to_stable(
    image: str | dict[str, Any] | np.ndarray | Image.Image | None,
) -> tuple[np.ndarray | None, str | None]:
    """Copy uploads/previews to a stable path so Gradio temp files can expire safely.

    The value is first normalised with ``_normalize_image_input``. It is then
    written to a uniquely named file next to ``STABLE_INPUT_PATH``, so
    concurrent sessions never overwrite each other's input. Finally the file is
    re-read with EXIF orientation applied and saved back as RGB.

    Returns:
        ``(rgb_array, saved_path)``, or ``(None, None)`` for an empty value.

    Raises:
        gr.Error: if a referenced file no longer exists or the value type is unsupported.
    """
    image = _normalize_image_input(image)
    if image is None:
        return None, None
    dest = STABLE_INPUT_PATH.with_name(
        f"{STABLE_INPUT_PATH.stem}_{uuid.uuid4().hex}{STABLE_INPUT_PATH.suffix}"
    )
    dest.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(image, str):
        src = Path(image)
        if not src.exists():
            raise gr.Error(f"Selected image is unavailable: {src}")
        shutil.copy2(src, dest)
    elif isinstance(image, np.ndarray):
        Image.fromarray(image).convert("RGB").save(dest)
    elif isinstance(image, Image.Image):
        image.convert("RGB").save(dest)
    else:
        raise gr.Error("Unsupported image input.")
    # Re-open the written file to apply EXIF orientation, closing the handle afterwards.
    with Image.open(dest) as written:
        pil = ImageOps.exif_transpose(written).convert("RGB")
    pil.save(dest)
    return np.array(pil), str(dest)
| def _make_gallery_select_handler(paths: list[Path], *, columns: int = 6): | |
| def _handler(evt: gr.SelectData) -> tuple[np.ndarray | None, str | None]: | |
| if not paths: | |
| return None, None | |
| index = getattr(evt, "index", 0) | |
| if isinstance(index, (list, tuple)): | |
| if len(index) >= 2: | |
| flat_index = int(index[0]) * columns + int(index[1]) | |
| else: | |
| flat_index = int(index[0]) if index else 0 | |
| else: | |
| flat_index = int(index) | |
| flat_index = max(0, min(flat_index, len(paths) - 1)) | |
| return _load_example_image(paths[flat_index]) | |
| return _handler | |
# Custom stylesheet passed to gr.Blocks(css=...). It adds orange accents,
# card-style panels, fixed-height scrollable example galleries, and hides
# fullscreen/upload controls on the input image and the example galleries.
CSS = """
:root {
--unisharp-orange: #f97316;
--unisharp-orange-soft: rgba(249, 115, 22, 0.12);
}
.gradio-container {
max-width: none !important;
width: 100% !important;
padding: 24px 36px 36px !important;
}
.resource-list {
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: 18px;
box-shadow: var(--block-shadow);
margin-bottom: 18px;
padding: 22px 26px;
}
.resource-list h1,
.resource-list h2,
.resource-list strong {
color: var(--unisharp-orange);
}
.resource-list a {
color: var(--unisharp-orange) !important;
}
.resource-list ul { margin-top: 10px; }
.resource-list li { margin: 5px 0; }
#main_row {
align-items: stretch;
gap: 22px;
}
#run_button {
min-height: 48px;
margin: 12px 0 14px 0;
}
.panel-card {
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: 18px;
box-shadow: var(--block-shadow);
padding: 14px;
}
.output-card .label-wrap,
.input-card .label-wrap {
background: var(--unisharp-orange-soft) !important;
color: var(--unisharp-orange) !important;
border-radius: 999px !important;
padding: 4px 10px !important;
}
.example-panel {
background: var(--block-background-fill) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 14px !important;
box-shadow: var(--block-shadow) !important;
padding: 10px 12px 12px 12px !important;
margin-top: 10px !important;
}
.example-panel-title,
.example-panel-title p {
color: var(--unisharp-orange) !important;
font-size: 0.95rem !important;
font-weight: 600 !important;
margin: 0 0 8px 0 !important;
padding: 0 !important;
}
#demo_gallery_panorama .label-wrap,
#demo_gallery_perspective .label-wrap {
display: none !important;
}
#demo_gallery_panorama_scroll,
#demo_gallery_perspective_scroll {
height: 175px !important;
max-height: 175px !important;
overflow-y: auto !important;
overflow-x: hidden !important;
scrollbar-gutter: stable;
}
#demo_gallery_panorama,
#demo_gallery_panorama > .wrap,
#demo_gallery_panorama > .wrap > div,
#demo_gallery_panorama .grid-wrap,
#demo_gallery_panorama .gallery,
#demo_gallery_panorama [data-testid="gallery"],
#demo_gallery_panorama [role="grid"],
#demo_gallery_perspective,
#demo_gallery_perspective > .wrap,
#demo_gallery_perspective > .wrap > div,
#demo_gallery_perspective .grid-wrap,
#demo_gallery_perspective .gallery,
#demo_gallery_perspective [data-testid="gallery"],
#demo_gallery_perspective [role="grid"] {
height: auto !important;
max-height: none !important;
overflow: visible !important;
}
#demo_gallery_panorama .caption,
#demo_gallery_perspective .caption { display: none !important; }
#demo_gallery_panorama [role="gridcell"],
#demo_gallery_panorama .thumbnail-item,
#demo_gallery_perspective [role="gridcell"],
#demo_gallery_perspective .thumbnail-item {
aspect-ratio: 16 / 9 !important;
}
#demo_gallery_panorama img,
#demo_gallery_perspective img {
min-height: 64px !important;
max-height: 82px !important;
object-fit: cover !important;
}
/* Example galleries: interactive=False (select only). Hide any leftover upload UI. */
#demo_gallery_panorama .icon-button,
#demo_gallery_perspective .icon-button,
#demo_gallery_panorama .upload-container,
#demo_gallery_perspective .upload-container,
#demo_gallery_panorama [data-testid="modify-upload"],
#demo_gallery_perspective [data-testid="modify-upload"] {
display: none !important;
}
#main_row > .gr-column {
flex: 1 1 0 !important;
min-width: 0 !important;
}
#outputs_panel .output-card {
margin-bottom: 10px !important;
}
/* Hide fullscreen only; do not hide image buttons (Gradio renders GIF inside them). */
.input-card button[aria-label*="ullscreen"],
.input-card button[aria-label*="Expand"],
.input-card button[aria-label*="expand"] {
display: none !important;
}
"""
# Markdown header rendered at the top of the demo page. The emoji were
# mis-encoded in the previous revision ("π»", "π¦", "π€", ...), and
# paragraphs are now separated so the heading, blurb and link list render as
# distinct blocks.
INTRO = """
# UniSHARP

Predicts a 3D Gaussian point cloud from a single image across diverse camera models, enabling high-quality novel view synthesis.

Here are our resources:

- **💻 Code:** https://github.com/Insta360-Research-Team/UniSHARP
- **🌐 Web Page:** https://insta360-research-team.github.io/Unisharp-website/
- **📄 Paper:** (coming soon)
- **📦 Dataset:** https://huggingface.co/datasets/Insta360-Research/OmniRooms
- **🚀 Demo:** https://huggingface.co/spaces/Insta360-Research/UniSHARP
"""
THEME = gr.themes.Soft(primary_hue="orange", secondary_hue="orange")

# Demo layout: input image, Run button and example galleries on the left;
# rendered views and the Gaussian PLY on the right. Event wiring follows.
with gr.Blocks(title="UniSHARP", css=CSS, theme=THEME) as demo:
    gr.Markdown(INTRO, elem_classes=["resource-list"])
    panorama_paths = _panorama_example_paths()
    perspective_paths = _perspective_example_paths()
    # Preload the first panorama example (when present) so the page opens with an input.
    default_state: str | None = None
    default_example: np.ndarray | None = None
    if panorama_paths:
        default_example, default_state = _load_example_image(panorama_paths[0])
    # Path of the image that "Run" processes; set by uploads and gallery clicks.
    image_state = gr.State(value=default_state)
    with gr.Row(equal_height=False, elem_id="main_row"):
        with gr.Column(scale=1, elem_classes=["panel-card"]):
            image_preview = gr.Image(
                label="Input Image",
                value=default_example,
                type="numpy",
                sources=["upload"],
                height=360,
                interactive=True,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["input-card"],
            )
            run_button = gr.Button("Run", variant="primary", elem_id="run_button")
            with gr.Column(elem_classes=["example-panel"], elem_id="example_panel_panorama"):
                gr.Markdown("Example (Panorama)", elem_classes=["example-panel-title"])
                with gr.Column(elem_id="demo_gallery_panorama_scroll"):
                    demo_gallery_panorama = gr.Gallery(
                        value=_demo_gallery_items(panorama_paths),
                        label="Example (Panorama)",
                        show_label=False,
                        columns=6,
                        rows=None,
                        allow_preview=False,
                        preview=False,
                        object_fit="cover",
                        interactive=False,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        elem_id="demo_gallery_panorama",
                    )
            with gr.Column(elem_classes=["example-panel"], elem_id="example_panel_perspective"):
                gr.Markdown("Example (Perspective)", elem_classes=["example-panel-title"])
                with gr.Column(elem_id="demo_gallery_perspective_scroll"):
                    demo_gallery_perspective = gr.Gallery(
                        value=_demo_gallery_items(perspective_paths),
                        label="Example (Perspective)",
                        show_label=False,
                        columns=6,
                        rows=None,
                        allow_preview=False,
                        preview=False,
                        object_fit="cover",
                        interactive=False,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        elem_id="demo_gallery_perspective",
                    )
        with gr.Column(scale=1, elem_classes=["panel-card"], elem_id="outputs_panel"):
            forward_out = gr.Image(
                label="Forward View",
                type="filepath",
                sources=[],
                height=350,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["output-card"],
            )
            orbit_out = gr.Image(
                label="Orbit View",
                type="filepath",
                sources=[],
                height=350,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["output-card"],
            )
            ply_out = gr.File(label="Gaussian PLY")
    image_preview.upload(
        fn=_persist_image_to_stable,
        inputs=[image_preview],
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    # Clearing the preview must also forget the stored path; otherwise "Run"
    # would silently process the previously shown image.
    image_preview.clear(
        fn=lambda: None,
        inputs=None,
        outputs=[image_state],
        show_progress="hidden",
        queue=False,
    )
    # Gallery handlers use the default columns=6, matching both galleries above.
    demo_gallery_panorama.select(
        fn=_make_gallery_select_handler(panorama_paths),
        inputs=None,
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    demo_gallery_perspective.select(
        fn=_make_gallery_select_handler(perspective_paths),
        inputs=None,
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    run_button.click(
        fn=run_unisharp,
        inputs=[image_state],
        outputs=[forward_out, orbit_out, ply_out],
        show_progress="full",
    )
if __name__ == "__main__":
    # Cap the pending-request queue at 8 events, then start the server.
    app = demo.queue(max_size=8)
    app.launch()