# UniSHARP / app.py
# Insta360-Research's picture
# Move perspective example #2 to position #4
# 8d7b878 verified
from __future__ import annotations
import os
import sys
from pathlib import Path
def _prepend_cuda_library_path() -> None:
"""Expose PyTorch/CUDA shared libraries before gsplat loads CUDA extensions."""
search_roots = [Path(entry) for entry in sys.path if entry]
rel_candidates = (
"torch/lib",
"nvidia/cuda_runtime/lib",
"nvidia/cudnn/lib",
"nvidia/cublas/lib",
)
prepend: list[str] = []
seen: set[str] = set()
for root in search_roots:
for rel in rel_candidates:
lib_dir = (root / rel).resolve()
if not lib_dir.is_dir() or not any(lib_dir.glob("lib*.so*")):
continue
key = str(lib_dir)
if key in seen:
continue
seen.add(key)
prepend.append(key)
if not prepend:
return
current = os.environ.get("LD_LIBRARY_PATH", "")
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(prepend + ([current] if current else []))
# Runs at import time, before torch/gsplat are imported below.
# NOTE(review): glibc's loader reads LD_LIBRARY_PATH at process start, so this
# likely only influences child processes (e.g. extension builds) — confirm.
_prepend_cuda_library_path()
import argparse
import shutil
import traceback
import uuid
from typing import Any, Callable
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageOps
# Repository root; the project and its vendored UniK3D checkout are made importable.
ROOT = Path(__file__).resolve().parent
# Each entry is inserted at index 0, so UniK3D ends up ahead of ROOT on sys.path.
for _extra_import_dir in (ROOT, ROOT / "UniK3D"):
    sys.path.insert(0, str(_extra_import_dir))
del _extra_import_dir

# Environment defaults; values already present in the environment win (setdefault).
_hf_cache_dir = str(ROOT / "checkpoints" / "huggingface")
for _env_key, _env_default in (
    ("SHARP_LOAD_UNIK3D_PRETRAINED", "0"),
    ("TORCH_HOME", str(ROOT / "checkpoints" / "torchhub")),
    ("HF_HOME", _hf_cache_dir),
    ("HF_HUB_CACHE", _hf_cache_dir),
    ("HUGGINGFACE_HUB_CACHE", _hf_cache_dir),
    ("TORCH_EXTENSIONS_DIR", "/tmp/unisharp_torch_extensions"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default, _hf_cache_dir
try:
    import spaces  # type: ignore

    gpu: Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]] = spaces.GPU
except Exception:
    # Without the ZeroGPU ``spaces`` package, provide a decorator factory with
    # the same call shape as ``spaces.GPU(...)`` that returns functions unchanged.
    def gpu(*_: Any, **__: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Ignore all arguments and return an identity decorator."""
        return lambda fn: fn
from scripts import infer_unisharp as infer # noqa: E402
# Checkpoint locations probed by _checkpoint_path(); both can be overridden via env vars.
MOUNTED_CHECKPOINT_PATH = Path(
    os.environ.get("UNISHARP_MOUNTED_CHECKPOINT", "/models/unisharpdemo/checkpoints")
)
LOCAL_CHECKPOINT_PATH = Path(
    os.environ.get("UNISHARP_CHECKPOINT", str(ROOT / "checkpoints" / "unisharp.pt"))
)
# Per-request output directories (and ad-hoc input copies) are written under this root.
OUTPUT_ROOT = Path(os.environ.get("UNISHARP_OUTPUT_DIR", "/tmp/unisharp_outputs"))
# Base path used when persisting the current input image (see _persist_image_to_stable).
STABLE_INPUT_PATH = Path(os.environ.get("UNISHARP_STABLE_INPUT", "/tmp/unisharp_current_input.png"))
# Bundled example image folders.
EXAMPLE_REPLICA_DIR = ROOT / "examples" / "replica"
EXAMPLE_OMNIROOMS_DIR = ROOT / "examples" / "omnirooms"
EXAMPLE_PERSPECTIVE_DIR = ROOT / "examples" / "perspective"
# Panorama examples in gallery display order. Entries with an image extension are
# OmniRooms files; bare names are Replica PNG stems (see _panorama_example_path).
PANORAMA_EXAMPLE_NAMES = [
    "replica_apartment_0_0004_g01189",
    "replica_office_3_0000_g01465",
    "replica_apartment_0_0004_g00631",
    "AI_vol4_05_middle_source.jpg",
    "replica_apartment_1_0000_g01727",
    "AI_vol4_02_middle_source.jpg",
    "replica_apartment_1_0000_g01186",
    "AI_vol4_03_last_source.jpg",
    "replica_apartment_1_0000_g01402",
    "AI_vol4_02_random_source.jpg",
    "replica_apartment_0_0000_g01456",
    "AI_vol4_04_random_source.jpg",
    "AIUE_V01_004_random_source.jpg",
]
# Perspective example file names under EXAMPLE_PERSPECTIVE_DIR, in gallery display order.
PERSPECTIVE_EXAMPLE_NAMES = [
    "insta2.png",
    "dl3dv_9K_g00136.png",
    "Gemini_Generated_Image_8kdfrl8kdfrl8kdf.png",
    "wildrgbd_scene_202_g00205.png",
    "Gemini_Generated_Image_uxu9zwuxu9zwuxu9.png",
    "dl3dv_3K_g00360.png",
    "dl3dv_5K_g00762.png",
    "dl3dv_8K_g01501.png",
    "wildrgbd_scene_084_g00241.png",
    "wildrgbd_scene_272_g00434.png",
    "wildrgbd_scene_282_g00213.png",
    "wildrgbd_scene_284_g00000.png",
    "Gemini_Generated_Image_5bd8lc5bd8lc5bd8.png",
    "dl3dv_2K_g01989.png",
]
# Demo defaults. The size/view/GIF values are pushed into the infer module by
# _configure_demo_constants(); the low-pass epsilon is passed to the renderer
# and to the args namespace built by _make_args().
DEFAULT_PERSPECTIVE_MAX_LONG_EDGE = 768
DEFAULT_PANORAMA_MAX_LONG_EDGE = 1536
DEFAULT_ORBIT_VIEWS = 10
DEFAULT_FORWARD_VIEWS = 10
DEFAULT_ORBIT_RADIUS_M = 0.10
DEFAULT_FORWARD_DISTANCE_M = 0.20
DEFAULT_GIF_DURATION_MS = 300
DEFAULT_LOW_PASS_FILTER_EPS = 0.0
# Pillow's decompression-bomb guard threshold, in pixels.
Image.MAX_IMAGE_PIXELS = 50_000_000
# Lazily built inference runtime, cached and refreshed by _runtime().
_MODEL: Any = None
_STEP: int = 0
_MODEL_DEVICE: str = ""
_MODEL_CHECKPOINT: str = ""
_RENDERER: Any = None
_TRAIN_RENDERER: Any = None
def _device() -> torch.device:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def _checkpoint_path() -> Path:
candidates = [
MOUNTED_CHECKPOINT_PATH,
Path("/models/unisharpdemo/checkpoints"),
Path("/models/unisharpdemo/checkpoints.pt"),
LOCAL_CHECKPOINT_PATH,
]
checked: list[str] = []
for candidate in candidates:
if str(candidate) in checked:
continue
checked.append(str(candidate))
if candidate.is_file():
return candidate
if candidate.is_dir():
checkpoint_files = sorted(candidate.glob("*.pt"))
if checkpoint_files:
return checkpoint_files[0]
raise FileNotFoundError(
"UniSHARP checkpoint is unavailable. Expected a mounted checkpoint at "
f"{MOUNTED_CHECKPOINT_PATH} or a local checkpoint at {LOCAL_CHECKPOINT_PATH}."
)
def _runtime() -> tuple[Any, Any, Any, int, torch.device, Path]:
    """Return the cached inference runtime, building it on a cache miss.

    Returns:
        ``(model, renderer, train_renderer, step, device, checkpoint_path)``.

    The objects cached in the module globals are reused only when they were
    built for the current device and the currently resolved checkpoint path.
    Otherwise the model is reloaded via ``infer._load_model``, and a fresh
    ``GSplatRenderer`` / ``UnifiedTrainer`` pair is created and stored.
    """
    global _MODEL, _STEP, _MODEL_DEVICE, _MODEL_CHECKPOINT, _RENDERER, _TRAIN_RENDERER
    device = _device()
    checkpoint_path = _checkpoint_path()
    cache_hit = (
        _MODEL is not None
        and _MODEL_DEVICE == str(device)
        and _MODEL_CHECKPOINT == str(checkpoint_path)
    )
    if not cache_hit:
        infer._configure_torchhub_cache()
        loaded_model, loaded_step = infer._load_model(checkpoint_path, device=device)
        gs_renderer = infer.GSplatRenderer(
            color_space="sRGB",
            background_color="black",
            low_pass_filter_eps=DEFAULT_LOW_PASS_FILTER_EPS,
        ).to(device)
        trainer = infer.UnifiedTrainer(
            model=loaded_model,
            renderer=gs_renderer,
            loss_fn=None,
            device=device,
        )
        _MODEL, _STEP = loaded_model, int(loaded_step)
        _MODEL_DEVICE, _MODEL_CHECKPOINT = str(device), str(checkpoint_path)
        _RENDERER, _TRAIN_RENDERER = gs_renderer, trainer
    return _MODEL, _RENDERER, _TRAIN_RENDERER, int(_STEP), device, checkpoint_path
def _make_args(checkpoint_path: Path) -> argparse.Namespace:
    """Build the argparse-style namespace passed to ``infer._process_one`` as ``args``.

    Camera selection is left on ``"auto"`` with no explicit camera JSON,
    intrinsics or parameters. PLY export is enabled, and the low-pass epsilon
    matches the renderer's demo default.
    """
    options: dict[str, Any] = {
        "checkpoint": checkpoint_path,
        "camera": "auto",
        "camera_json": None,
        "_camera_json_data": None,
        "camera_intrinsics": None,
        "camera_params": None,
        "save_ply": True,
        "low_pass_filter_eps": DEFAULT_LOW_PASS_FILTER_EPS,
    }
    return argparse.Namespace(**options)
def _configure_demo_constants() -> None:
    """Pin the demo's tuning constants on the ``infer`` module.

    The function sets module attributes on ``infer``:
    - panorama aspect-ratio bounds;
    - view-motion heuristics;
    - input size caps;
    - orbit/forward view counts and distances;
    - GIF frame duration.
    """
    overrides: dict[str, Any] = {
        "PANORAMA_ASPECT_MIN": 1.9,
        "PANORAMA_ASPECT_MAX": 2.1,
        "VIEW_MOTION_NEAR_SCENE_DEPTH_M": 2.0,
        "VIEW_MOTION_MIN_SCALE": 0.08,
        "VIEW_MOTION_FORWARD_DEPTH_FRAC": 0.04,
        "VIEW_MOTION_ROTATE_DEPTH_FRAC": 0.02,
        "VIEW_MOTION_FAR_SCENE_MEDIAN_M": 2.5,
        "VIEW_MOTION_FOREGROUND_DEPTH_QUANTILE": 0.20,
        "PERSPECTIVE_MAX_LONG_EDGE": DEFAULT_PERSPECTIVE_MAX_LONG_EDGE,
        "PANORAMA_MAX_LONG_EDGE": DEFAULT_PANORAMA_MAX_LONG_EDGE,
        "ROTATE_VIEWS": DEFAULT_ORBIT_VIEWS,
        "ROTATE_RADIUS_M": DEFAULT_ORBIT_RADIUS_M,
        "FORWARD_VIEWS": DEFAULT_FORWARD_VIEWS,
        "FORWARD_DISTANCE_M": DEFAULT_FORWARD_DISTANCE_M,
        "GIF_DURATION_MS": DEFAULT_GIF_DURATION_MS,
    }
    for attr_name, attr_value in overrides.items():
        setattr(infer, attr_name, attr_value)
def _cuda_library_dirs() -> list[str]:
dirs: list[str] = []
seen: set[str] = set()
torch_lib = Path(torch.__file__).resolve().parent / "lib"
if torch_lib.is_dir():
key = str(torch_lib)
seen.add(key)
dirs.append(key)
try:
import nvidia.cuda_runtime # type: ignore
cuda_rt = Path(nvidia.cuda_runtime.__file__).resolve().parent / "lib"
if cuda_rt.is_dir():
key = str(cuda_rt)
if key not in seen:
seen.add(key)
dirs.append(key)
except ImportError:
pass
return dirs
def _resolve_cuda_home() -> str | None:
try:
import nvidia.cuda_nvcc # type: ignore
nvcc_root = Path(nvidia.cuda_nvcc.__file__).resolve().parent
if nvcc_root.is_dir():
return str(nvcc_root)
except ImportError:
pass
from torch.utils import cpp_extension
cuda_home = cpp_extension.CUDA_HOME
return str(cuda_home) if cuda_home else None
def _disable_incompatible_gsplat_binary() -> None:
try:
import gsplat # type: ignore
csrc = Path(gsplat.__file__).resolve().parent / "csrc.so"
if csrc.exists():
backup = csrc.with_suffix(".so.incompatible")
if not backup.exists():
csrc.rename(backup)
except Exception:
pass
def _prebuilt_gsplat_so_path() -> Path:
return Path(os.environ["TORCH_EXTENSIONS_DIR"]) / "py310_cu128" / "gsplat_cuda" / "gsplat_cuda.so"
def _install_prebuilt_gsplat_extension() -> Path | None:
    """Reuse gsplat_cuda.so built for torch 2.8 + cu128 + py310.

    Stages the mounted binary at ``_prebuilt_gsplat_so_path()``.

    Returns:
        The staged path. Returns ``None`` when neither the mounted binary nor a
        previously staged copy exists.

    The staged copy is refreshed when the mounted binary is newer or differs in
    size. The copy goes to a temporary sibling file and is then moved into
    place with ``os.replace``, which gives two guarantees:
    - An interrupted copy (e.g. a worker killed mid-request) cannot leave a
      truncated ``.so``. Such a file would carry a fresh mtime and therefore
      never be refreshed.
    - An already-loaded library file is never overwritten in place.
    """
    mounted = Path("/models/unisharpdemo/gsplat_cuda/py310_cu128/gsplat_cuda.so")
    target = _prebuilt_gsplat_so_path()
    if not mounted.is_file():
        return target if target.is_file() else None
    target.parent.mkdir(parents=True, exist_ok=True)
    source_stat = mounted.stat()
    needs_copy = True
    if target.exists():
        target_stat = target.stat()
        needs_copy = (
            source_stat.st_mtime > target_stat.st_mtime
            or source_stat.st_size != target_stat.st_size
        )
    if needs_copy:
        staging = target.with_name(f"{target.name}.tmp-{os.getpid()}")
        try:
            shutil.copy2(mounted, staging)
            os.replace(staging, target)
        finally:
            # No-op after a successful replace; removes a partial copy otherwise.
            staging.unlink(missing_ok=True)
    return target
def _configure_gsplat_cuda() -> None:
    """Align gsplat with PyTorch CUDA 12.8 on ZeroGPU (avoid system CUDA 13).

    Steps:
    1. Move gsplat's bundled ``csrc.so`` aside.
    2. Point ``CUDA_HOME`` at the resolved toolkit, if any.
    3. Put the PyTorch/CUDA runtime library dirs at the front of
       ``LD_LIBRARY_PATH``.
    4. Stage the prebuilt ``gsplat_cuda.so``.

    Safe to call on every request: copies of those library dirs already present
    in ``LD_LIBRARY_PATH`` are removed before prepending. The directories keep
    top priority and the variable does not grow on each call.

    NOTE(review): the loader reads ``LD_LIBRARY_PATH`` at process start, so this
    likely only affects child processes — confirm.
    """
    _disable_incompatible_gsplat_binary()
    cuda_home = _resolve_cuda_home()
    if cuda_home:
        os.environ["CUDA_HOME"] = cuda_home
    lib_dirs = _cuda_library_dirs()
    if lib_dirs:
        current = os.environ.get("LD_LIBRARY_PATH", "")
        # Keep every other entry (including empty ones) in its original order.
        remaining = (
            [entry for entry in current.split(os.pathsep) if entry not in lib_dirs]
            if current
            else []
        )
        os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(lib_dirs + remaining)
    _install_prebuilt_gsplat_extension()
def _warmup_gsplat_cuda() -> None:
    """Load the prebuilt ``gsplat_cuda.so`` and make it gsplat's CUDA backend.

    The function configures the CUDA environment and stages the prebuilt
    binary. It then executes the binary as module ``gsplat_cuda`` and registers
    it as ``gsplat.csrc`` before importing ``gsplat.cuda._backend``, which
    avoids JIT compilation on ZeroGPU.

    A repeat call after a successful warm-up in the same process returns
    immediately: if the backend's ``_C`` is already the module registered here,
    there is nothing left to do. Without this check, every request would
    reconfigure the environment and re-execute the extension module.

    Raises:
        RuntimeError: No prebuilt ``gsplat_cuda.so`` is available.
        ImportError: The binary cannot be loaded, or the backend's ``_C`` is None.
    """
    registered = sys.modules.get("gsplat_cuda")
    active_backend = sys.modules.get("gsplat.cuda._backend")
    if registered is not None and getattr(active_backend, "_C", None) is registered:
        return
    _configure_gsplat_cuda()
    ext_so = _install_prebuilt_gsplat_extension()
    if ext_so is None or not ext_so.is_file():
        raise RuntimeError(
            "Prebuilt gsplat_cuda.so is unavailable. Expected "
            f"{_prebuilt_gsplat_so_path()} or "
            "/models/unisharpdemo/gsplat_cuda/py310_cu128/gsplat_cuda.so"
        )
    import importlib.util

    spec = importlib.util.spec_from_file_location("gsplat_cuda", ext_so)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load gsplat extension from {ext_so}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # gsplat.cuda._backend does `from gsplat import csrc as _C` at import time.
    # Register the prebuilt module before importing _backend to avoid JIT on ZeroGPU.
    sys.modules["gsplat.csrc"] = module
    sys.modules["gsplat_cuda"] = module
    import gsplat.cuda._backend as backend  # noqa: F401

    if backend._C is None:
        raise ImportError("gsplat CUDA backend failed to initialize.")
def _resolve_image_path(
    image_path: str | dict[str, Any] | np.ndarray | Image.Image | None,
) -> Path:
    """Turn a Gradio image value into a path on disk.

    The input is handled as follows:
    - Dicts are reduced to their ``"path"`` (or ``"name"``) entry.
    - In-memory images (PIL or numpy) are saved as RGB to
      ``OUTPUT_ROOT/selected_input.png``, and that path is returned.
    - Strings and paths are returned as a ``Path`` once they are known to exist.

    Raises:
        gr.Error: The value is empty or the referenced file does not exist.
    """
    if isinstance(image_path, dict):
        image_path = image_path.get("path") or image_path.get("name")
    if isinstance(image_path, (Image.Image, np.ndarray)):
        pil_image = (
            image_path if isinstance(image_path, Image.Image) else Image.fromarray(image_path)
        )
        saved = OUTPUT_ROOT / "selected_input.png"
        saved.parent.mkdir(parents=True, exist_ok=True)
        pil_image.convert("RGB").save(saved)
        return saved
    if not image_path:
        raise gr.Error("Please upload an image or select an example.")
    candidate = Path(image_path)
    if candidate.exists():
        return candidate
    raise gr.Error(f"Selected image is unavailable: {candidate}")
def _run_unisharp_once(
    image_path: str | None,
    progress: gr.Progress | None = None,
) -> tuple[str | None, str | None, str | None]:
    """Run UniSHARP on a single image and collect the rendered outputs.

    Args:
        image_path: Input image in any form accepted by ``_resolve_image_path``.
        progress: Optional Gradio progress tracker, updated at fixed milestones.

    Returns:
        ``(forward_gif, orbit_gif, gaussians_ply)`` as path strings. Each item is
        ``None`` when ``infer._process_one`` did not produce that file.

    Raises:
        gr.Error: Raised for any failure. An existing ``gr.Error`` is re-raised
            unchanged; other exceptions are printed and wrapped. On failure the
            per-request output directory is removed.
    """
    example_path = _resolve_image_path(image_path)

    def report(fraction: float, desc: str) -> None:
        # Progress reporting is optional (e.g. when called outside Gradio).
        if progress is not None:
            progress(fraction, desc=desc)

    report(0.03, "Preparing input image")
    request_id = uuid.uuid4().hex[:10]
    out_root = OUTPUT_ROOT / request_id
    out_root.mkdir(parents=True, exist_ok=True)
    input_path = out_root / "input.png"
    try:
        # Decoding happens inside the try so unreadable/corrupt images surface
        # as gr.Error and the request directory is cleaned up; the context
        # manager closes the source file handle.
        with Image.open(example_path) as source:
            ImageOps.exif_transpose(source).convert("RGB").save(input_path)
        report(0.08, "Initializing gsplat CUDA")
        _warmup_gsplat_cuda()
        report(0.12, "Loading UniSHARP")
        _configure_demo_constants()
        model, renderer, train_renderer, step, _, checkpoint_path = _runtime()
        args = _make_args(checkpoint_path)
        report(0.35, "Running inference and rendering views")
        infer._process_one(
            model=model,
            renderer=renderer,
            train_renderer=train_renderer,
            image_path=input_path,
            out_root=out_root,
            step=int(step),
            args=args,
        )
        sample_dir = out_root / infer._slug_from_path(input_path)
        orbit = sample_dir / "rotate.gif"
        forward = sample_dir / "forward.gif"
        ply = sample_dir / "gaussians.ply"
        report(0.95, "Preparing outputs")
        return (
            str(forward) if forward.exists() else None,
            str(orbit) if orbit.exists() else None,
            str(ply) if ply.exists() else None,
        )
    except gr.Error:
        # Already user-facing; keep its title/settings instead of re-wrapping.
        shutil.rmtree(out_root, ignore_errors=True)
        raise
    except Exception as exc:
        traceback.print_exc()
        shutil.rmtree(out_root, ignore_errors=True)
        raise gr.Error(str(exc)) from exc
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@gpu(duration=120)
def run_unisharp(
    image_path: str | None,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Gradio click handler: run UniSHARP on the image path held in ``image_state``.

    ``gpu`` is ``spaces.GPU`` when the ``spaces`` package imports, and a no-op
    decorator otherwise. ``duration=120`` is presumably the GPU time budget in
    seconds — confirm against the ``spaces`` documentation.

    The ``gr.Progress(track_tqdm=True)`` default is Gradio's convention for
    injecting a progress tracker (it also mirrors tqdm bars). It is not a
    shared mutable default in the usual sense.

    Returns:
        ``(forward_gif, orbit_gif, ply)`` paths from ``_run_unisharp_once``;
        each may be ``None``.
    """
    return _run_unisharp_once(image_path, progress=progress)
def _panorama_example_path(name: str) -> Path:
    """Map a panorama example name to its file.

    Names ending in ``.jpg``/``.jpeg``/``.webp`` (case-insensitive) are OmniRooms
    files and are used verbatim. Any other name is a Replica stem that gets a
    ``.png`` suffix.
    """
    is_omnirooms_file = name.lower().endswith((".jpg", ".jpeg", ".webp"))
    if not is_omnirooms_file:
        return EXAMPLE_REPLICA_DIR / f"{name}.png"
    return EXAMPLE_OMNIROOMS_DIR / name
def _panorama_example_paths() -> list[Path]:
    """Return the panorama example files that exist, in ``PANORAMA_EXAMPLE_NAMES`` order."""
    candidates = map(_panorama_example_path, PANORAMA_EXAMPLE_NAMES)
    return [candidate for candidate in candidates if candidate.is_file()]
def _perspective_example_paths() -> list[Path]:
    """Return the perspective example files that exist, in ``PERSPECTIVE_EXAMPLE_NAMES`` order."""
    existing: list[Path] = []
    for file_name in PERSPECTIVE_EXAMPLE_NAMES:
        candidate = EXAMPLE_PERSPECTIVE_DIR / file_name
        if candidate.is_file():
            existing.append(candidate)
    return existing
def _example_paths() -> list[Path]:
    """Return all available examples: panoramas first, then perspective images."""
    combined = _panorama_example_paths()
    combined.extend(_perspective_example_paths())
    return combined
def _demo_gallery_items(paths: list[Path]) -> list[str]:
return [str(path) for path in paths]
def _normalize_image_input(image: Any) -> str | np.ndarray | Image.Image | None:
if image is None:
return None
if isinstance(image, Path):
return str(image)
if isinstance(image, str):
return image
if isinstance(image, (np.ndarray, Image.Image)):
return image
if isinstance(image, dict):
path = image.get("path") or image.get("url") or image.get("name")
return str(path) if path else None
if isinstance(image, (list, tuple)):
return _normalize_image_input(image[0]) if image else None
if hasattr(image, "model_dump"):
try:
data = image.model_dump()
if isinstance(data, dict):
return _normalize_image_input(data.get("path") or data.get("url"))
except Exception:
pass
path = getattr(image, "path", None)
if path:
return str(path)
url = getattr(image, "url", None)
if url:
return str(url)
return None
def _load_example_image(path: Path) -> tuple[np.ndarray | None, str | None]:
    """Load a bundled example for preview and return it with the path to run on.

    Returns:
        ``(rgb_array, path_string)``, or ``(None, None)`` when the file is missing.
        The array has EXIF orientation applied and is converted to RGB.

    The example file itself is returned as the run input. Bundled examples are
    static files, so Gradio temp-file expiry does not apply to them. Copying
    them into the single module-level ``STABLE_INPUT_PATH`` would be unsafe:
    that one file is shared by every session, so concurrent users would
    overwrite each other's input between selecting an example and pressing Run.
    """
    src = Path(path)
    if not src.is_file():
        return None, None
    # Context manager closes the file handle Image.open keeps for lazy loading.
    with Image.open(src) as opened:
        rgb = ImageOps.exif_transpose(opened).convert("RGB")
    return np.array(rgb), str(src)
def _persist_image_to_stable(
    image: str | dict[str, Any] | np.ndarray | Image.Image | None,
) -> tuple[np.ndarray | None, str | None]:
    """Copy uploads/previews to a stable path so Gradio temp files can expire safely.

    The image is normalised with ``_normalize_image_input`` and written to a
    new, uniquely named file next to ``STABLE_INPUT_PATH``. That file is then
    re-opened, EXIF orientation is applied, and it is saved back as RGB.

    A fresh file per call matters. ``STABLE_INPUT_PATH`` is a single
    module-level path, so writing to it directly would let concurrent sessions
    overwrite each other's input between upload and Run.

    NOTE(review): these per-upload files are never deleted; consider a cleanup
    policy for long-running deployments.

    Returns:
        ``(rgb_array, saved_path)``, or ``(None, None)`` for empty input.

    Raises:
        gr.Error: The referenced file is missing or the input type is unsupported.
    """
    image = _normalize_image_input(image)
    if image is None:
        return None, None
    STABLE_INPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    target = STABLE_INPUT_PATH.with_name(
        f"{STABLE_INPUT_PATH.stem}_{uuid.uuid4().hex[:10]}{STABLE_INPUT_PATH.suffix}"
    )
    if isinstance(image, str):
        src = Path(image)
        if not src.exists():
            raise gr.Error(f"Selected image is unavailable: {src}")
        shutil.copy2(src, target)
    elif isinstance(image, np.ndarray):
        Image.fromarray(image).convert("RGB").save(target)
    elif isinstance(image, Image.Image):
        image.convert("RGB").save(target)
    else:
        raise gr.Error("Unsupported image input.")
    # Close the read handle before overwriting the same file.
    with Image.open(target) as opened:
        pil = ImageOps.exif_transpose(opened).convert("RGB")
    pil.save(target)
    return np.array(pil), str(target)
def _make_gallery_select_handler(paths: list[Path], *, columns: int = 6):
def _handler(evt: gr.SelectData) -> tuple[np.ndarray | None, str | None]:
if not paths:
return None, None
index = getattr(evt, "index", 0)
if isinstance(index, (list, tuple)):
if len(index) >= 2:
flat_index = int(index[0]) * columns + int(index[1])
else:
flat_index = int(index[0]) if index else 0
else:
flat_index = int(index)
flat_index = max(0, min(flat_index, len(paths) - 1))
return _load_example_image(paths[flat_index])
return _handler
# Custom stylesheet passed to gr.Blocks(css=...). The ids/classes it targets
# (main_row, run_button, panel-card, example-panel, demo_gallery_*,
# outputs_panel, input-card, output-card, resource-list) are assigned in the
# page layout; keep both in sync when renaming.
CSS = """
:root {
--unisharp-orange: #f97316;
--unisharp-orange-soft: rgba(249, 115, 22, 0.12);
}
.gradio-container {
max-width: none !important;
width: 100% !important;
padding: 24px 36px 36px !important;
}
.resource-list {
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: 18px;
box-shadow: var(--block-shadow);
margin-bottom: 18px;
padding: 22px 26px;
}
.resource-list h1,
.resource-list h2,
.resource-list strong {
color: var(--unisharp-orange);
}
.resource-list a {
color: var(--unisharp-orange) !important;
}
.resource-list ul { margin-top: 10px; }
.resource-list li { margin: 5px 0; }
#main_row {
align-items: stretch;
gap: 22px;
}
#run_button {
min-height: 48px;
margin: 12px 0 14px 0;
}
.panel-card {
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: 18px;
box-shadow: var(--block-shadow);
padding: 14px;
}
.output-card .label-wrap,
.input-card .label-wrap {
background: var(--unisharp-orange-soft) !important;
color: var(--unisharp-orange) !important;
border-radius: 999px !important;
padding: 4px 10px !important;
}
.example-panel {
background: var(--block-background-fill) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 14px !important;
box-shadow: var(--block-shadow) !important;
padding: 10px 12px 12px 12px !important;
margin-top: 10px !important;
}
.example-panel-title,
.example-panel-title p {
color: var(--unisharp-orange) !important;
font-size: 0.95rem !important;
font-weight: 600 !important;
margin: 0 0 8px 0 !important;
padding: 0 !important;
}
#demo_gallery_panorama .label-wrap,
#demo_gallery_perspective .label-wrap {
display: none !important;
}
#demo_gallery_panorama_scroll,
#demo_gallery_perspective_scroll {
height: 175px !important;
max-height: 175px !important;
overflow-y: auto !important;
overflow-x: hidden !important;
scrollbar-gutter: stable;
}
#demo_gallery_panorama,
#demo_gallery_panorama > .wrap,
#demo_gallery_panorama > .wrap > div,
#demo_gallery_panorama .grid-wrap,
#demo_gallery_panorama .gallery,
#demo_gallery_panorama [data-testid="gallery"],
#demo_gallery_panorama [role="grid"],
#demo_gallery_perspective,
#demo_gallery_perspective > .wrap,
#demo_gallery_perspective > .wrap > div,
#demo_gallery_perspective .grid-wrap,
#demo_gallery_perspective .gallery,
#demo_gallery_perspective [data-testid="gallery"],
#demo_gallery_perspective [role="grid"] {
height: auto !important;
max-height: none !important;
overflow: visible !important;
}
#demo_gallery_panorama .caption,
#demo_gallery_perspective .caption { display: none !important; }
#demo_gallery_panorama [role="gridcell"],
#demo_gallery_panorama .thumbnail-item,
#demo_gallery_perspective [role="gridcell"],
#demo_gallery_perspective .thumbnail-item {
aspect-ratio: 16 / 9 !important;
}
#demo_gallery_panorama img,
#demo_gallery_perspective img {
min-height: 64px !important;
max-height: 82px !important;
object-fit: cover !important;
}
/* Example galleries: interactive=False (select only). Hide any leftover upload UI. */
#demo_gallery_panorama .icon-button,
#demo_gallery_perspective .icon-button,
#demo_gallery_panorama .upload-container,
#demo_gallery_perspective .upload-container,
#demo_gallery_panorama [data-testid="modify-upload"],
#demo_gallery_perspective [data-testid="modify-upload"] {
display: none !important;
}
#main_row > .gr-column {
flex: 1 1 0 !important;
min-width: 0 !important;
}
#outputs_panel .output-card {
margin-bottom: 10px !important;
}
/* Hide fullscreen only; do not hide image buttons (Gradio renders GIF inside them). */
.input-card button[aria-label*="ullscreen"],
.input-card button[aria-label*="Expand"],
.input-card button[aria-label*="expand"] {
display: none !important;
}
"""
# Markdown header rendered at the top of the page (gr.Markdown(INTRO, ...)).
# The emoji were previously mis-encoded (UTF-8 bytes decoded with a legacy
# code page, e.g. "πŸ’»"), which showed garbage to users.
INTRO = """
# UniSHARP
Predicts a 3D Gaussian point cloud from a single image across diverse camera models, enabling high-quality novel view synthesis.
Here are our resources:
- **💻 Code:** https://github.com/Insta360-Research-Team/UniSHARP
- **🌐 Web Page:** https://insta360-research-team.github.io/Unisharp-website/
- **📄 Paper:** (coming soon)
- **📦 Dataset:** https://huggingface.co/datasets/Insta360-Research/OmniRooms
- **🤗 Demo:** https://huggingface.co/spaces/Insta360-Research/UniSHARP
"""
# Orange-accented Soft theme used by the whole page.
THEME = gr.themes.Soft(primary_hue="orange", secondary_hue="orange")
# Page layout: intro header, then a two-column row (input, Run button and example
# galleries on the left; rendered outputs on the right), then the event wiring.
with gr.Blocks(title="UniSHARP", css=CSS, theme=THEME) as demo:
    gr.Markdown(INTRO, elem_classes=["resource-list"])
    # Resolved once when the page is built; only files that exist are listed.
    panorama_paths = _panorama_example_paths()
    perspective_paths = _perspective_example_paths()
    # Preload the first panorama example (if any) as the initial preview and run input.
    default_state: str | None = None
    default_example: np.ndarray | None = None
    if panorama_paths:
        default_example, default_state = _load_example_image(panorama_paths[0])
    # Path string that "Run" passes to run_unisharp; the preview component itself
    # holds a numpy array (type="numpy"), so the path is tracked separately.
    image_state = gr.State(value=default_state)
    with gr.Row(equal_height=False, elem_id="main_row"):
        # Left column: input preview, Run button and the two example galleries.
        with gr.Column(scale=1, elem_classes=["panel-card"]):
            image_preview = gr.Image(
                label="Input Image",
                value=default_example,
                type="numpy",
                sources=["upload"],
                height=360,
                interactive=True,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["input-card"],
            )
            run_button = gr.Button("Run", variant="primary", elem_id="run_button")
            with gr.Column(elem_classes=["example-panel"], elem_id="example_panel_panorama"):
                gr.Markdown("Example (Panorama)", elem_classes=["example-panel-title"])
                # Wrapper column is the fixed-height scroll container styled in CSS.
                with gr.Column(elem_id="demo_gallery_panorama_scroll"):
                    # columns=6 must match the default `columns` of
                    # _make_gallery_select_handler used in the select wiring.
                    demo_gallery_panorama = gr.Gallery(
                        value=_demo_gallery_items(panorama_paths),
                        label="Example (Panorama)",
                        show_label=False,
                        columns=6,
                        rows=None,
                        allow_preview=False,
                        preview=False,
                        object_fit="cover",
                        interactive=False,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        elem_id="demo_gallery_panorama",
                    )
            with gr.Column(elem_classes=["example-panel"], elem_id="example_panel_perspective"):
                gr.Markdown("Example (Perspective)", elem_classes=["example-panel-title"])
                with gr.Column(elem_id="demo_gallery_perspective_scroll"):
                    demo_gallery_perspective = gr.Gallery(
                        value=_demo_gallery_items(perspective_paths),
                        label="Example (Perspective)",
                        show_label=False,
                        columns=6,
                        rows=None,
                        allow_preview=False,
                        preview=False,
                        object_fit="cover",
                        interactive=False,
                        type="filepath",
                        show_download_button=False,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        elem_id="demo_gallery_perspective",
                    )
        # Right column: forward/orbit GIFs and the downloadable Gaussian PLY.
        with gr.Column(scale=1, elem_classes=["panel-card"], elem_id="outputs_panel"):
            forward_out = gr.Image(
                label="Forward View",
                type="filepath",
                sources=[],
                height=350,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["output-card"],
            )
            orbit_out = gr.Image(
                label="Orbit View",
                type="filepath",
                sources=[],
                height=350,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
                elem_classes=["output-card"],
            )
            ply_out = gr.File(label="Gaussian PLY")
    # Uploads go through _persist_image_to_stable, which returns the normalised
    # preview and the saved path for image_state.
    image_preview.upload(
        fn=_persist_image_to_stable,
        inputs=[image_preview],
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    # Gallery clicks load the chosen example into the preview and image_state.
    demo_gallery_panorama.select(
        fn=_make_gallery_select_handler(panorama_paths),
        inputs=None,
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    demo_gallery_perspective.select(
        fn=_make_gallery_select_handler(perspective_paths),
        inputs=None,
        outputs=[image_preview, image_state],
        show_progress="hidden",
        queue=False,
    )
    # Inference reads only the stored path, not the preview array.
    run_button.click(
        fn=run_unisharp,
        inputs=[image_state],
        outputs=[forward_out, orbit_out, ply_out],
        show_progress="full",
    )
if __name__ == "__main__":
    # Enable the request queue (at most 8 pending requests), then start the server.
    queued_app = demo.queue(max_size=8)
    queued_app.launch()