Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| from functools import lru_cache | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| def _ensure_localhost_no_proxy() -> None: | |
| hosts = ["127.0.0.1", "localhost", "::1"] | |
| for key in ("NO_PROXY", "no_proxy"): | |
| current = os.environ.get(key, "") | |
| values = [value.strip() for value in current.split(",") if value.strip()] | |
| changed = False | |
| for host in hosts: | |
| if host not in values: | |
| values.append(host) | |
| changed = True | |
| if changed or not current: | |
| os.environ[key] = ",".join(values) | |
| _ensure_localhost_no_proxy() | |
| def _ensure_hf_cache_dirs() -> None: | |
| hf_home = os.environ.get("HF_HOME", "/tmp/huggingface") | |
| hub_cache = os.environ.get("HF_HUB_CACHE", os.path.join(hf_home, "hub")) | |
| assets_cache = os.environ.get("HF_ASSETS_CACHE", os.path.join(hf_home, "assets")) | |
| os.environ["HF_HOME"] = hf_home | |
| os.environ["HF_HUB_CACHE"] = hub_cache | |
| os.environ["HF_ASSETS_CACHE"] = assets_cache | |
| os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hub_cache) | |
| os.environ.setdefault("HF_HUB_DISABLE_XET", "1") | |
| os.makedirs(hf_home, exist_ok=True) | |
| os.makedirs(hub_cache, exist_ok=True) | |
| os.makedirs(assets_cache, exist_ok=True) | |
| _ensure_hf_cache_dirs() | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """Minimal stand-in for the ``spaces`` package when it is not installed.

        Only ``GPU`` is provided. It is a no-op decorator that works both bare
        (``@spaces.GPU``) and with arguments (``@spaces.GPU(duration=...)``);
        keyword arguments are accepted and ignored.
        """

        @staticmethod
        def GPU(fn=None, **_kwargs):
            # Must be a staticmethod: the fallback is used through an instance,
            # and a plain function would receive that instance as ``fn``. That
            # made bare ``@spaces.GPU`` fail with a positional-argument TypeError.
            if callable(fn):
                return fn

            def decorator(inner_fn):
                return inner_fn

            return decorator

    spaces = _SpacesFallback()
| from InfiniDepth.gs import GSPixelAlignPredictor, export_ply | |
| from InfiniDepth.utils.gs_utils import ( | |
| _build_sparse_uniform_gaussians, | |
| ) | |
| from InfiniDepth.utils.hf_demo_utils import ( | |
| DemoArtifacts, | |
| ensure_session_output_dir, | |
| export_point_cloud_assets, | |
| preview_depth_file, | |
| save_demo_artifacts, | |
| scan_example_cases, | |
| ) | |
| from InfiniDepth.utils.hf_gs_viewer import ( | |
| APP_TEMP_ROOT as GS_VIEWER_ROOT, | |
| build_embedded_viewer_html, | |
| build_viewer_error_html, | |
| ) | |
| from InfiniDepth.utils.inference_utils import ( | |
| apply_sky_mask_to_depth, | |
| build_camera_matrices, | |
| build_scaled_intrinsics_matrix, | |
| filter_gaussians_by_statistical_outlier, | |
| prepare_metric_depth_inputs, | |
| resolve_camera_intrinsics_for_inference, | |
| resolve_output_size_from_mode, | |
| run_optional_sampling_sky_mask, | |
| run_optional_sky_mask, | |
| unpack_gaussians_for_export, | |
| ) | |
| from InfiniDepth.utils.io_utils import depth_to_disparity | |
| from InfiniDepth.utils.model_utils import build_model | |
| from InfiniDepth.utils.sampling_utils import SAMPLING_METHODS | |
# --- Paths and static configuration -----------------------------------------
APP_ROOT = Path(__file__).resolve().parent
EXAMPLES_DIR = APP_ROOT / "example_data"
# Network input resolution as (height, width): ``_prepare_image_tensors``
# resizes to ``INPUT_SIZE[::-1]`` because cv2 expects (width, height).
INPUT_SIZE = (768, 1024)
# Passed to ensure_session_output_dir to name per-session output folders.
APP_NAME = "infinidepth-hf-demo"
GS_TASK_CHOICE = "3DGS"
TASK_CHOICES = ["Depth", GS_TASK_CHOICE]
# RGB-only model vs. the variant that additionally consumes an input depth file.
RGB_MODEL_TYPE = "InfiniDepth"
DEPTH_SENSOR_MODEL_TYPE = "InfiniDepth_DepthSensor"
MODEL_CHOICES = [RGB_MODEL_TYPE, DEPTH_SENSOR_MODEL_TYPE]
OUTPUT_MODE_CHOICES = ["upsample", "original", "specific"]
# Forwarded to model.inference_for_gs as sample_point_num / coord_deterministic_sampling.
GS_SAMPLE_POINT_NUM = 2000000
GS_COORD_DETERMINISTIC_SAMPLING = True
# Per-task GPU time budgets in seconds; presumably intended for
# ``spaces.GPU(duration=...)`` on ZeroGPU.
DEPTH_GPU_DURATION_SECONDS = 180
GS_GPU_DURATION_SECONDS = 240
# Bundled checkpoint locations. When a local file is missing, the matching
# HF_*_FILENAMES entry is fetched from HF_REPO_ID instead (_resolve_repo_asset).
LOCAL_DEPTH_MODEL_PATHS = {
    "InfiniDepth": APP_ROOT / "checkpoints/depth/infinidepth.ckpt",
    "InfiniDepth_DepthSensor": APP_ROOT / "checkpoints/depth/infinidepth_depthsensor.ckpt",
}
LOCAL_GS_MODEL_PATHS = {
    "InfiniDepth": APP_ROOT / "checkpoints/gs/infinidepth_gs.ckpt",
    "InfiniDepth_DepthSensor": APP_ROOT / "checkpoints/gs/infinidepth_depthsensor_gs.ckpt",
}
HF_REPO_ID = "ritianyu/InfiniDepth"
HF_DEPTH_FILENAMES = {
    "InfiniDepth": "infinidepth.ckpt",
    "InfiniDepth_DepthSensor": "infinidepth_depthsensor.ckpt",
}
HF_GS_FILENAMES = {
    "InfiniDepth": "infinidepth_gs.ckpt",
    "InfiniDepth_DepthSensor": "infinidepth_depthsensor_gs.ckpt",
}
# Note the local and Hub file names differ for MoGe-2 (model.pt vs moge2.pt).
LOCAL_MOGE2_PATH = APP_ROOT / "checkpoints/moge-2-vitl-normal/model.pt"
HF_MOGE2_FILENAME = "moge2.pt"
LOCAL_SKYSEG_PATH = APP_ROOT / "checkpoints/sky/skyseg.onnx"
HF_SKYSEG_FILENAME = "skyseg.onnx"
# Example cases are discovered on disk once, at import time.
EXAMPLE_CASES = scan_example_cases(EXAMPLES_DIR)
EXAMPLE_LOOKUP = {case.name: case for case in EXAMPLE_CASES}
DEFAULT_EXAMPLE_NAME = EXAMPLE_CASES[0].name if EXAMPLE_CASES else None
DEFAULT_EXAMPLE_INDEX = 0 if EXAMPLE_CASES else None
# (image, caption) pairs used as the value of the example gr.Gallery.
EXAMPLE_GALLERY_ITEMS = [(case.image_path, case.gallery_caption) for case in EXAMPLE_CASES]
# Tab ids used to switch the primary viewer between point cloud and GS output.
DEPTH_VIEW_TAB_ID = "pcd-viewer-tab"
GS_VIEW_TAB_ID = "gs-viewer-tab"
# Allow Gradio to serve files under the GS viewer's temp root (presumably the
# assets referenced by the embedded viewer HTML).
gr.set_static_paths(paths=[str(GS_VIEWER_ROOT)])
# Custom CSS passed to gr.Blocks. The selectors target the elem_id values given
# to the layout columns, example gallery, input images and viewers in the UI.
CSS = """
#top-workspace {
align-items: stretch;
}
#controls-column,
#inputs-column,
#outputs-column {
min-width: 0;
}
#example-gallery {
min-height: 280px;
}
#input-image {
min-height: 420px;
}
#input-depth-preview {
min-height: 240px;
}
#depth-model3d-viewer {
height: 700px;
}
#depth-model3d-viewer canvas,
#depth-model3d-viewer model-viewer,
#depth-model3d-viewer .wrap,
#depth-model3d-viewer .container {
height: 100% !important;
max-height: 100% !important;
}
#gs-viewer-html {
min-height: 748px;
padding-bottom: 0.75rem;
}
#gs-viewer-html iframe {
display: block;
width: 100%;
height: 700px !important;
min-height: 700px !important;
}
#depth-preview,
#depth-comparison,
#depth-color {
min-height: 260px;
}
"""
def _ensure_cuda() -> None:
    """Raise a user-facing ``gr.Error`` unless ``torch`` reports a CUDA device."""
    if torch.cuda.is_available():
        return
    message = (
        "No CUDA device is available for this request. On Hugging Face ZeroGPU, "
        "GPU access is only attached while the decorated inference call is running."
    )
    raise gr.Error(message)
| def _resolve_repo_asset(local_path: Path, filename: str) -> str: | |
| if local_path.exists(): | |
| return str(local_path) | |
| return hf_hub_download( | |
| repo_id=HF_REPO_ID, | |
| filename=filename, | |
| ) | |
def _resolve_depth_checkpoint(model_type: str) -> str:
    """Path to the depth checkpoint for ``model_type`` (bundled or downloaded)."""
    local_path = LOCAL_DEPTH_MODEL_PATHS[model_type]
    return _resolve_repo_asset(local_path, HF_DEPTH_FILENAMES[model_type])


def _resolve_gs_checkpoint(model_type: str) -> str:
    """Path to the Gaussian-splatting checkpoint for ``model_type``."""
    local_path = LOCAL_GS_MODEL_PATHS[model_type]
    return _resolve_repo_asset(local_path, HF_GS_FILENAMES[model_type])


def _resolve_skyseg_path() -> str:
    """Path to the ONNX sky-segmentation model."""
    local_path = LOCAL_SKYSEG_PATH
    return _resolve_repo_asset(local_path, HF_SKYSEG_FILENAME)


def _resolve_moge2_source() -> str:
    """Path to the MoGe-2 weights used for metric alignment and intrinsics."""
    local_path = LOCAL_MOGE2_PATH
    return _resolve_repo_asset(local_path, HF_MOGE2_FILENAME)
def _preload_repo_assets() -> tuple[str, ...]:
    """Resolve (and download if missing) every asset the demo can use.

    Returns the resolved paths in this order: depth checkpoints (in
    ``MODEL_CHOICES`` order), GS checkpoints (same order), MoGe-2 weights, and
    the sky-segmentation model.
    """
    resolved: list[str] = []
    resolved.extend(_resolve_depth_checkpoint(name) for name in MODEL_CHOICES)
    resolved.extend(_resolve_gs_checkpoint(name) for name in MODEL_CHOICES)
    resolved.append(_resolve_moge2_source())
    resolved.append(_resolve_skyseg_path())
    return tuple(resolved)
def _load_model(model_type: str):
    """Build the ``model_type`` depth model from its checkpoint; requires CUDA."""
    _ensure_cuda()
    return build_model(model_type=model_type, model_path=_resolve_depth_checkpoint(model_type))
def _load_gs_predictor(model_type: str, dino_feature_dim: int):
    """Create the pixel-aligned GS head on CUDA, load its weights, set eval mode.

    ``dino_feature_dim`` must match the last dimension of the DINO tokens the
    depth model produces for this run.
    """
    _ensure_cuda()
    cuda = torch.device("cuda")
    predictor = GSPixelAlignPredictor(dino_feature_dim=dino_feature_dim)
    predictor = predictor.to(cuda)
    checkpoint_path = _resolve_gs_checkpoint(model_type)
    predictor.load_from_infinidepth_gs_checkpoint(checkpoint_path)
    predictor.eval()
    return predictor
| def _to_optional_float(value: Optional[float]) -> Optional[float]: | |
| if value in (None, ""): | |
| return None | |
| return float(value) | |
| def _to_rgb_uint8(image: np.ndarray) -> np.ndarray: | |
| image = np.asarray(image) | |
| if image.ndim != 3 or image.shape[2] != 3: | |
| raise gr.Error("Input image must be an RGB image.") | |
| if image.dtype == np.uint8: | |
| return image | |
| if np.issubdtype(image.dtype, np.floating): | |
| image = np.clip(image, 0.0, 1.0 if image.max() <= 1.0 else 255.0) | |
| if image.max() <= 1.0: | |
| image = image * 255.0 | |
| return image.astype(np.uint8) | |
| return np.clip(image, 0, 255).astype(np.uint8) | |
def _prepare_image_tensors(image_rgb: np.ndarray) -> tuple[np.ndarray, torch.Tensor, tuple[int, int]]:
    """Normalize an input image and build the network input tensor.

    Returns ``(rgb_uint8, tensor, (orig_h, orig_w))``. ``tensor`` is the image
    resized (INTER_AREA) to ``INPUT_SIZE``, laid out as ``(1, 3, H, W)``
    float32, and scaled to ``[0, 1]``.
    """
    rgb = _to_rgb_uint8(image_rgb)
    height, width = rgb.shape[:2]
    target_h, target_w = INPUT_SIZE
    # cv2.resize takes its destination size as (width, height).
    resized = cv2.resize(rgb, (target_w, target_h), interpolation=cv2.INTER_AREA)
    tensor = torch.from_numpy(resized).permute(2, 0, 1)[None].float().div(255.0)
    return rgb, tensor, (height, width)
| def _format_depth_status( | |
| model_type: str, | |
| metric_depth_source: str, | |
| intrinsics_source: str, | |
| output_hw: tuple[int, int], | |
| depth_file: Optional[str], | |
| ) -> str: | |
| depth_label = Path(depth_file).name if depth_file else "None" | |
| return ( | |
| f"Task: `Depth`\n\n" | |
| f"Model: `{model_type}`\n\n" | |
| f"Input depth: `{depth_label}`\n\n" | |
| f"Metric alignment source: `{metric_depth_source}`\n\n" | |
| f"Camera intrinsics source: `{intrinsics_source}`\n\n" | |
| f"Output size: `{output_hw[0]} x {output_hw[1]}`" | |
| ) | |
| def _format_gs_status( | |
| model_type: str, | |
| metric_depth_source: str, | |
| intrinsics_source: str, | |
| depth_file: Optional[str], | |
| gaussian_count: int, | |
| ) -> str: | |
| depth_label = Path(depth_file).name if depth_file else "None" | |
| return ( | |
| f"Task: `GS`\n\n" | |
| f"Model: `{model_type}`\n\n" | |
| f"Input depth: `{depth_label}`\n\n" | |
| f"Metric alignment source: `{metric_depth_source}`\n\n" | |
| f"Camera intrinsics source: `{intrinsics_source}`\n\n" | |
| f"Exported gaussians: `{gaussian_count}`" | |
| ) | |
| def _model_availability_note(depth_path: Optional[str], model_type: str, *, auto_switched: bool = False) -> str: | |
| if depth_path: | |
| if auto_switched and model_type == DEPTH_SENSOR_MODEL_TYPE: | |
| return ( | |
| "Depth file loaded. Switched model to `InfiniDepth_DepthSensor`. " | |
| "You can still switch back to `InfiniDepth` for RGB-only inference." | |
| ) | |
| return ( | |
| "Depth file loaded. `InfiniDepth_DepthSensor` is available. " | |
| "You can also keep `InfiniDepth` for RGB-only inference." | |
| ) | |
| if auto_switched: | |
| return ( | |
| "No input depth loaded. Switched model back to `InfiniDepth`. " | |
| "Upload a depth file to enable `InfiniDepth_DepthSensor`." | |
| ) | |
| return "No input depth loaded. `InfiniDepth` will be used until you upload a depth file." | |
| def _compose_depth_info_message(base_message: str, note: str) -> str: | |
| return f"{base_message}\n\n{note}" if note else base_message | |
def _load_example_image(example_name: str) -> tuple[np.ndarray, Optional[str], Optional[np.ndarray], str, str]:
    """Load a bundled example case into the UI.

    Returns:
        ``(image_rgb, depth_path, depth_preview, model_type, info_markdown)``:
        the example image in RGB order, its paired depth path (or ``None``), a
        depth preview (or ``None``), the model to select
        (``InfiniDepth_DepthSensor`` when paired depth exists, otherwise
        ``InfiniDepth``), and a status message.

    Raises:
        gr.Error: if no example is selected, the name is not a known example,
            the image cannot be read, or the paired depth cannot be previewed.
    """
    if not example_name:
        raise gr.Error("Select an example case first.")
    case = EXAMPLE_LOOKUP.get(example_name)
    if case is None:
        # Previously surfaced to the user as a raw KeyError.
        raise gr.Error(f"Unknown example case: {example_name}")
    # str() in case image_path is a pathlib.Path; cv2.imread takes a string filename.
    image_bgr = cv2.imread(str(case.image_path), cv2.IMREAD_COLOR)
    if image_bgr is None:
        raise gr.Error(f"Failed to load example image: {case.image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    depth_path = case.depth_path
    if depth_path is None:
        detail = _compose_depth_info_message(
            f"Loaded example `{case.name}`.",
            _model_availability_note(None, RGB_MODEL_TYPE),
        )
        return image_rgb, None, None, RGB_MODEL_TYPE, detail

    # Same error surface as _update_depth_preview for a bad depth file.
    try:
        preview, depth_msg = preview_depth_file(depth_path)
    except Exception as exc:
        raise gr.Error(f"Failed to preview depth file: {exc}") from exc
    detail = _compose_depth_info_message(
        f"Loaded example `{case.name}` with paired depth. {depth_msg}",
        _model_availability_note(depth_path, DEPTH_SENSOR_MODEL_TYPE, auto_switched=True),
    )
    return image_rgb, depth_path, preview, DEPTH_SENSOR_MODEL_TYPE, detail
| def _selected_example_message(example_name: Optional[str]) -> str: | |
| if not example_name or example_name not in EXAMPLE_LOOKUP: | |
| return "Select an example thumbnail, then click `Load Example`." | |
| case = EXAMPLE_LOOKUP[example_name] | |
| mode_label = "RGB + depth" if case.has_depth else "RGB only" | |
| return f"Selected example: `{case.name}` ({mode_label})" | |
def _select_example(evt: gr.SelectData):
    """Record which gallery thumbnail was clicked; returns ``(name, message)``.

    Gallery indices may arrive as a sequence; only the first element is used.
    """
    if not EXAMPLE_CASES:
        return None, _selected_example_message(None)
    raw_index = evt.index
    if raw_index is None:
        return None, _selected_example_message(None)
    if isinstance(raw_index, (tuple, list)):
        raw_index = raw_index[0]
    chosen = EXAMPLE_CASES[int(raw_index)]
    return chosen.name, _selected_example_message(chosen.name)
def _primary_view_for_task(task_type: str):
    """Select the GS viewer tab for the 3DGS task, else the point-cloud tab."""
    if task_type == GS_TASK_CHOICE:
        return gr.update(selected=GS_VIEW_TAB_ID)
    return gr.update(selected=DEPTH_VIEW_TAB_ID)
def _reset_uploaded_image_state(
    _image: Optional[np.ndarray],
    depth_path: Optional[str],
) -> tuple[None, None, str, str]:
    """Clear depth input and fall back to the RGB model after the user edits the image.

    Returns new values for (depth file, depth preview, model type, info text).
    """
    if depth_path:
        note = (
            "Image updated. Cleared the previous depth file. Upload a new depth file to enable "
            "`InfiniDepth_DepthSensor`."
        )
    else:
        note = "Image updated. Upload a depth file to enable `InfiniDepth_DepthSensor`."
    return None, None, RGB_MODEL_TYPE, note
def _update_depth_preview(depth_path: Optional[str], model_type: str) -> tuple[Optional[np.ndarray], str, str]:
    """Refresh the depth preview and pick the model matching the depth input.

    The depth-sensor model is selected when a depth file is present, otherwise
    the RGB model. The info note mentions the switch when the selection changes.

    Raises:
        gr.Error: if the depth file cannot be previewed.
    """
    try:
        preview, depth_msg = preview_depth_file(depth_path)
    except Exception as exc:
        raise gr.Error(f"Failed to preview depth file: {exc}") from exc
    next_model = DEPTH_SENSOR_MODEL_TYPE if depth_path else RGB_MODEL_TYPE
    note = _model_availability_note(depth_path, next_model, auto_switched=model_type != next_model)
    return preview, next_model, _compose_depth_info_message(depth_msg, note)
def _settings_visibility(task_type: str, output_resolution_mode: str):
    """Visibility updates for the depth-only controls.

    Order: resolution mode, upsample ratio, specific height, specific width,
    flying-point filter.
    """
    depth_task = task_type == "Depth"
    show_upsample = depth_task and output_resolution_mode == "upsample"
    show_specific = depth_task and output_resolution_mode == "specific"
    return (
        gr.update(visible=depth_task),
        gr.update(visible=show_upsample),
        gr.update(visible=show_specific),
        gr.update(visible=show_specific),
        gr.update(visible=depth_task),
    )
| def _normalize_filtered_gaussians(filtered_result): | |
| if isinstance(filtered_result, tuple): | |
| return filtered_result[0] | |
| return filtered_result | |
@spaces.GPU(duration=DEPTH_GPU_DURATION_SECONDS)
def _run_depth_inference(
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    output_resolution_mode: str,
    upsample_ratio: int,
    specific_height: int,
    specific_width: int,
    enable_skyseg_model: bool,
    filter_point_cloud: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Run dense depth inference and export depth images plus a point cloud.

    Wrapped in ``spaces.GPU`` (with ``DEPTH_GPU_DURATION_SECONDS``) so that a
    GPU is requested for this call on Hugging Face ZeroGPU. Without it,
    ``_ensure_cuda`` fails there. When the ``spaces`` package is not installed,
    the fallback decorator returns the function unchanged.

    Returns:
        An 8-tuple matching the Run-button outputs: status Markdown, comparison
        image, colorized depth, grayscale depth, GLB point cloud, depth download
        files, and ``None`` for the two GS outputs.

    Raises:
        gr.Error: if no image is given, the depth-sensor model is selected
            without a depth file, or no CUDA device is available.
    """
    # Validate user input before touching the GPU so input mistakes are
    # reported even when CUDA is unavailable.
    if image is None:
        raise gr.Error("Upload an image or load an example before running inference.")
    if model_type == DEPTH_SENSOR_MODEL_TYPE and not depth_file:
        raise gr.Error("InfiniDepth_DepthSensor requires an input depth file.")
    _ensure_cuda()

    skyseg_path = _resolve_skyseg_path() if enable_skyseg_model else None
    image_rgb, image_tensor, (org_h, org_w) = _prepare_image_tensors(image)
    device = torch.device("cuda")
    image_tensor = image_tensor.to(device)
    model = _load_model(model_type)
    # Resolve once; used for both metric alignment and intrinsics estimation.
    moge2_source = _resolve_moge2_source()

    gt_depth, prompt_depth, gt_depth_mask, use_gt_depth, moge2_intrinsics = prepare_metric_depth_inputs(
        input_depth_path=depth_file,
        input_size=INPUT_SIZE,
        image=image_tensor,
        device=device,
        moge2_pretrained=moge2_source,
    )
    # The model is fed disparity, not depth.
    gt_disp = depth_to_disparity(gt_depth)
    prompt_disp = depth_to_disparity(prompt_depth)
    fx_org, fy_org, cx_org, cy_org, intrinsics_source = resolve_camera_intrinsics_for_inference(
        fx_org=_to_optional_float(fx_org),
        fy_org=_to_optional_float(fy_org),
        cx_org=_to_optional_float(cx_org),
        cy_org=_to_optional_float(cy_org),
        org_h=org_h,
        org_w=org_w,
        image=image_tensor,
        moge2_pretrained=moge2_source,
        moge2_intrinsics=moge2_intrinsics,
    )
    _, _, h, w = image_tensor.shape
    # Rescale the original-resolution intrinsics to the network input size.
    fx, fy, cx, cy, _ = build_scaled_intrinsics_matrix(
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        device=image_tensor.device,
    )
    sky_mask = run_optional_sky_mask(
        image=image_tensor,
        enable_skyseg_model=enable_skyseg_model,
        sky_model_ckpt_path=skyseg_path or str(LOCAL_SKYSEG_PATH),
    )
    h_out, w_out = resolve_output_size_from_mode(
        output_resolution_mode=output_resolution_mode,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        output_size=(int(specific_height), int(specific_width)),
        upsample_ratio=int(upsample_ratio),
    )
    # Query the continuous depth field on a regular h_out x w_out grid.
    query_2d_uniform_coord = SAMPLING_METHODS["2d_uniform"]((h_out, w_out)).unsqueeze(0).to(device)
    pred_2d_uniform_depth, _ = model.inference(
        image=image_tensor,
        query_coord=query_2d_uniform_coord,
        gt_depth=gt_disp,
        gt_depth_mask=gt_depth_mask,
        prompt_depth=prompt_disp,
        prompt_mask=prompt_disp > 0,
    )
    pred_depthmap = pred_2d_uniform_depth.permute(0, 2, 1).reshape(1, 1, h_out, w_out)
    pred_depthmap, pred_2d_uniform_depth = apply_sky_mask_to_depth(
        pred_depthmap=pred_depthmap,
        pred_2d_uniform_depth=pred_2d_uniform_depth,
        sky_mask=sky_mask,
        h_sample=h_out,
        w_sample=w_out,
        sky_depth_value=200.0,
    )

    session_hash = getattr(request, "session_hash", None)
    output_dir = ensure_session_output_dir(APP_NAME, session_hash)
    pred_depth_np = pred_depthmap.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)
    artifacts = save_demo_artifacts(image_rgb=image_rgb, pred_depth=pred_depth_np, output_dir=output_dir)
    ply_path, glb_path = export_point_cloud_assets(
        sampled_coord=query_2d_uniform_coord.squeeze(0).cpu(),
        sampled_depth=pred_2d_uniform_depth.squeeze(0).squeeze(-1).cpu(),
        rgb_image=image_tensor.squeeze(0).cpu(),
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        output_dir=output_dir,
        filter_flying_points=filter_point_cloud,
    )
    # Rebuild the artifacts record so the point-cloud files are included.
    artifacts = DemoArtifacts(
        comparison_path=artifacts.comparison_path,
        color_depth_path=artifacts.color_depth_path,
        gray_depth_path=artifacts.gray_depth_path,
        raw_depth_path=artifacts.raw_depth_path,
        ply_path=ply_path,
        glb_path=glb_path,
    )
    metric_depth_source = "user depth" if use_gt_depth and depth_file else "MoGe-2"
    status = _format_depth_status(
        model_type=model_type,
        metric_depth_source=metric_depth_source,
        intrinsics_source=intrinsics_source,
        output_hw=(h_out, w_out),
        depth_file=depth_file,
    )
    return (
        status,
        artifacts.comparison_path,
        artifacts.color_depth_path,
        artifacts.gray_depth_path,
        glb_path,
        artifacts.download_files(),
        None,
        None,
    )
@spaces.GPU(duration=GS_GPU_DURATION_SECONDS)
def _run_gs_inference(
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    enable_skyseg_model: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Predict 3D Gaussians for the image, export them to PLY and build the viewer.

    Wrapped in ``spaces.GPU`` (with ``GS_GPU_DURATION_SECONDS``) so that a GPU
    is requested for this call on Hugging Face ZeroGPU. When the ``spaces``
    package is not installed, the fallback decorator returns the function
    unchanged.

    Returns:
        An 8-tuple matching the Run-button outputs: status Markdown, ``None``
        for the five depth outputs, the GS viewer HTML, and the GS download list.

    Raises:
        gr.Error: if no image is given, the depth-sensor model is selected
            without a depth file, CUDA is unavailable, the model returns no 3D
            query outputs, or no gaussians survive filtering.
    """
    # Validate user input before touching the GPU.
    if image is None:
        raise gr.Error("Upload an image or load an example before running inference.")
    if model_type == DEPTH_SENSOR_MODEL_TYPE and not depth_file:
        raise gr.Error("InfiniDepth_DepthSensor requires an input depth file for GS inference.")
    _ensure_cuda()

    # Only the network tensor is needed here; the uint8 RGB copy is discarded.
    _, image_tensor, (org_h, org_w) = _prepare_image_tensors(image)
    device = torch.device("cuda")
    image_tensor = image_tensor.to(device)
    model = _load_model(model_type)
    # Resolve once; used for both metric alignment and intrinsics estimation.
    moge2_source = _resolve_moge2_source()

    gt_depth, prompt_depth, gt_depth_mask, use_gt_depth, moge2_intrinsics = prepare_metric_depth_inputs(
        input_depth_path=depth_file,
        input_size=INPUT_SIZE,
        image=image_tensor,
        device=device,
        moge2_pretrained=moge2_source,
    )
    # The model is fed disparity, not depth.
    gt_disp = depth_to_disparity(gt_depth)
    prompt_disp = depth_to_disparity(prompt_depth)
    fx_org, fy_org, cx_org, cy_org, intrinsics_source = resolve_camera_intrinsics_for_inference(
        fx_org=_to_optional_float(fx_org),
        fy_org=_to_optional_float(fy_org),
        cx_org=_to_optional_float(cx_org),
        cy_org=_to_optional_float(cy_org),
        org_h=org_h,
        org_w=org_w,
        image=image_tensor,
        moge2_pretrained=moge2_source,
        moge2_intrinsics=moge2_intrinsics,
    )
    b, _, h, w = image_tensor.shape
    _, _, _, _, intrinsics, extrinsics = build_camera_matrices(
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        batch=b,
        device=device,
    )
    skyseg_path = _resolve_skyseg_path() if enable_skyseg_model else str(LOCAL_SKYSEG_PATH)
    sky_mask = run_optional_sampling_sky_mask(
        image=image_tensor,
        enable_skyseg_model=enable_skyseg_model,
        sky_model_ckpt_path=skyseg_path,
        dilate_px=0,
    )
    depthmap, dino_tokens, query_3d_uniform_coord, pred_depth_3d = model.inference_for_gs(
        image=image_tensor,
        intrinsics=intrinsics,
        gt_depth=gt_disp,
        gt_depth_mask=gt_depth_mask,
        prompt_depth=prompt_disp,
        prompt_mask=prompt_disp > 0,
        sky_mask=sky_mask,
        sample_point_num=GS_SAMPLE_POINT_NUM,
        coord_deterministic_sampling=GS_COORD_DETERMINISTIC_SAMPLING,
    )
    if query_3d_uniform_coord is None or pred_depth_3d is None:
        raise gr.Error("GS inference did not return 3D-uniform query outputs.")

    # The GS head is sized to the DINO token width produced by this model.
    gs_predictor = _load_gs_predictor(model_type, int(dino_tokens.shape[-1]))
    dense_gaussians = gs_predictor(
        image=image_tensor,
        depthmap=depthmap,
        dino_tokens=dino_tokens,
        intrinsics=intrinsics,
        extrinsics=extrinsics,
    )
    pixel_gaussians = _build_sparse_uniform_gaussians(
        dense_gaussians=dense_gaussians,
        query_3d_uniform_coord=query_3d_uniform_coord,
        pred_depth_3d=pred_depth_3d,
        intrinsics=intrinsics,
        extrinsics=extrinsics,
        h=h,
        w=w,
    )
    pixel_gaussians = _normalize_filtered_gaussians(filter_gaussians_by_statistical_outlier(pixel_gaussians))
    gaussian_count = int(pixel_gaussians.means.shape[1])
    if gaussian_count == 0:
        raise gr.Error("No valid gaussians remained after filtering.")
    means, harmonics, opacities, scales, rotations = unpack_gaussians_for_export(pixel_gaussians)

    session_hash = getattr(request, "session_hash", None)
    output_dir = ensure_session_output_dir(APP_NAME, session_hash)
    ply_path = output_dir / "gaussians.ply"
    export_ply(
        means=means,
        harmonics=harmonics,
        opacities=opacities,
        path=ply_path,
        scales=scales,
        rotations=rotations,
        focal_length_px=(fx_org, fy_org),
        principal_point_px=(cx_org, cy_org),
        image_shape=(org_h, org_w),
        extrinsic_matrix=extrinsics[0],
    )
    # Viewer failures are non-fatal: the PLY is still offered for download.
    try:
        gs_viewer_html = build_embedded_viewer_html(ply_path)
    except Exception as exc:
        print(f"[Warning] Failed to build embedded GS viewer: {exc}")
        gs_viewer_html = build_viewer_error_html(str(exc), ply_path)

    metric_depth_source = "user depth" if use_gt_depth and depth_file else "MoGe-2"
    status = _format_gs_status(
        model_type=model_type,
        metric_depth_source=metric_depth_source,
        intrinsics_source=intrinsics_source,
        depth_file=depth_file,
        gaussian_count=gaussian_count,
    )
    download_files = [str(ply_path)]
    return (
        status,
        None,
        None,
        None,
        None,
        None,
        gs_viewer_html,
        download_files,
    )
def _run_inference(
    task_type: str,
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    output_resolution_mode: str,
    upsample_ratio: int,
    specific_height: int,
    specific_width: int,
    enable_skyseg_model: bool,
    filter_point_cloud: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Dispatch the Run button to the GS pipeline (3DGS task) or the depth pipeline.

    The depth-only settings (output resolution, upsampling, flying-point
    filter) are forwarded only to the depth pipeline.
    """
    common = dict(
        image=image,
        depth_file=depth_file,
        model_type=model_type,
        enable_skyseg_model=enable_skyseg_model,
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        request=request,
    )
    if task_type != GS_TASK_CHOICE:
        return _run_depth_inference(
            output_resolution_mode=output_resolution_mode,
            upsample_ratio=upsample_ratio,
            specific_height=specific_height,
            specific_width=specific_width,
            filter_point_cloud=filter_point_cloud,
            **common,
        )
    return _run_gs_inference(**common)
| def _clear_outputs(): | |
| return "", None, None, None, None, None, "", None | |
# Gradio UI. The layout has a controls column (task, model, examples,
# settings), an inputs column (image, optional depth file, Run button) and an
# outputs column (primary 3D viewers plus secondary depth/download tabs).
# The event wiring follows the layout.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# InfiniDepth Demo")
    gr.Markdown(
        "Switch between depth inference and GS inference. `InfiniDepth` works with RGB-only inputs, while `InfiniDepth_DepthSensor` is enabled only after you upload a depth file or load an example with paired depth."
    )
    # Gallery selection that has been clicked but not yet loaded via "Load Example".
    selected_example_name = gr.State(DEFAULT_EXAMPLE_NAME)
    with gr.Row(elem_id="top-workspace"):
        # Controls: task/model choice, example picker and settings.
        with gr.Column(scale=4, min_width=320, elem_id="controls-column"):
            task_type = gr.Radio(label="Inference Task", choices=TASK_CHOICES, value="Depth")
            model_type = gr.Radio(label="Model Type", choices=MODEL_CHOICES, value=RGB_MODEL_TYPE)
            gr.Markdown("### Example Data")
            example_gallery = gr.Gallery(
                value=EXAMPLE_GALLERY_ITEMS,
                label="Example Data",
                show_label=False,
                columns=2,
                height=280,
                object_fit="cover",
                allow_preview=False,
                selected_index=DEFAULT_EXAMPLE_INDEX,
                elem_id="example-gallery",
            )
            example_selection = gr.Markdown(_selected_example_message(DEFAULT_EXAMPLE_NAME))
            load_example_btn = gr.Button("Load Example")
            with gr.Accordion("Depth Settings", open=True):
                output_resolution_mode = gr.Dropdown(
                    label="Output Resolution Mode",
                    choices=OUTPUT_MODE_CHOICES,
                    value="upsample",
                )
                upsample_ratio = gr.Slider(label="Upsample Ratio", minimum=1, maximum=4, step=1, value=1)
                # Hidden until the "specific" output mode is chosen (_settings_visibility).
                specific_height = gr.Number(label="Specific Height", value=INPUT_SIZE[0], precision=0, visible=False)
                specific_width = gr.Number(label="Specific Width", value=INPUT_SIZE[1], precision=0, visible=False)
                # Used by both tasks, so it stays visible for 3DGS.
                enable_skyseg_model = gr.Checkbox(label="Apply Sky Mask", value=False)
                filter_point_cloud = gr.Checkbox(label="Filter Flying Points", value=True)
            # Blank fields become None (_to_optional_float) and are then filled in
            # by resolve_camera_intrinsics_for_inference.
            with gr.Accordion("Optional Camera Intrinsics", open=False):
                fx_org = gr.Textbox(label="fx", value="", placeholder="auto")
                fy_org = gr.Textbox(label="fy", value="", placeholder="auto")
                cx_org = gr.Textbox(label="cx", value="", placeholder="auto")
                cy_org = gr.Textbox(label="cy", value="", placeholder="auto")
        # Inputs: RGB image, optional depth file with preview, Run button.
        with gr.Column(scale=5, min_width=360, elem_id="inputs-column"):
            input_image = gr.Image(
                label="Input Image",
                image_mode="RGB",
                type="numpy",
                sources=["upload", "clipboard", "webcam"],
                height=420,
                elem_id="input-image",
            )
            input_depth_file = gr.File(
                label="Optional Depth File",
                type="filepath",
                file_types=[".png", ".npy", ".npz", ".h5", ".hdf5", ".exr"],
            )
            input_depth_preview = gr.Image(
                label="Input Depth Preview",
                type="numpy",
                height=240,
                elem_id="input-depth-preview",
            )
            depth_info = gr.Markdown("No input depth loaded. `InfiniDepth` will be used until you upload a depth file.")
            submit_btn = gr.Button("Run Inference", variant="primary")
        # Outputs: status, primary 3D viewers, and secondary analysis/downloads.
        with gr.Column(scale=8, min_width=640, elem_id="outputs-column"):
            status_output = gr.Markdown()
            # Tabs switched programmatically by _primary_view_for_task.
            with gr.Tabs(selected=DEPTH_VIEW_TAB_ID, elem_id="primary-view-tabs") as primary_view_tabs:
                with gr.Tab("PCD Viewer", id=DEPTH_VIEW_TAB_ID, render_children=True):
                    depth_model_3d = gr.Model3D(
                        label="Point Cloud Viewer",
                        display_mode="solid",
                        clear_color=[1.0, 1.0, 1.0, 1.0],
                        height=700,
                        elem_id="depth-model3d-viewer",
                    )
                with gr.Tab("GS Viewer", id=GS_VIEW_TAB_ID, render_children=True):
                    gs_viewer_html = gr.HTML(elem_id="gs-viewer-html")
            with gr.Tabs(elem_id="secondary-output-tabs"):
                with gr.Tab("Depth Analysis", render_children=True):
                    depth_comparison = gr.Image(
                        label="RGB vs Depth",
                        type="filepath",
                        height=280,
                        elem_id="depth-comparison",
                    )
                    with gr.Row():
                        color_depth = gr.Image(
                            label="Colorized Depth",
                            type="filepath",
                            height=260,
                            elem_id="depth-color",
                        )
                        gray_depth = gr.Image(
                            label="Grayscale Depth",
                            type="filepath",
                            height=260,
                            elem_id="depth-preview",
                        )
                with gr.Tab("Downloads", render_children=True):
                    with gr.Row():
                        depth_download_files = gr.File(label="Depth Files", type="filepath")
                        gs_download_files = gr.File(label="GS Files", type="filepath")
    # --- Event wiring ---------------------------------------------------------
    # Show or hide the depth-only settings when the task or output mode changes.
    task_type.change(
        fn=_settings_visibility,
        inputs=[task_type, output_resolution_mode],
        outputs=[output_resolution_mode, upsample_ratio, specific_height, specific_width, filter_point_cloud],
    )
    task_type.change(
        fn=_primary_view_for_task,
        inputs=[task_type],
        outputs=[primary_view_tabs],
    )
    output_resolution_mode.change(
        fn=_settings_visibility,
        inputs=[task_type, output_resolution_mode],
        outputs=[output_resolution_mode, upsample_ratio, specific_height, specific_width, filter_point_cloud],
    )
    # Clicking a thumbnail only records the selection; "Load Example" loads it.
    example_gallery.select(
        fn=_select_example,
        outputs=[selected_example_name, example_selection],
    )
    # A user-edited image invalidates any previously attached depth file.
    input_image.input(
        fn=_reset_uploaded_image_state,
        inputs=[input_image, input_depth_file],
        outputs=[input_depth_file, input_depth_preview, model_type, depth_info],
    )
    input_depth_file.change(
        fn=_update_depth_preview,
        inputs=[input_depth_file, model_type],
        outputs=[input_depth_preview, model_type, depth_info],
    )
    load_example_btn.click(
        fn=_load_example_image,
        inputs=[selected_example_name],
        outputs=[input_image, input_depth_file, input_depth_preview, model_type, depth_info],
    )
    # Run: switch to the task's viewer tab, clear stale outputs, then infer.
    # The output lists of _clear_outputs and _run_inference share one order.
    submit_btn.click(
        fn=_primary_view_for_task,
        inputs=[task_type],
        outputs=[primary_view_tabs],
    ).then(
        fn=_clear_outputs,
        outputs=[
            status_output,
            depth_comparison,
            color_depth,
            gray_depth,
            depth_model_3d,
            depth_download_files,
            gs_viewer_html,
            gs_download_files,
        ],
    ).then(
        fn=_run_inference,
        inputs=[
            task_type,
            input_image,
            input_depth_file,
            model_type,
            output_resolution_mode,
            upsample_ratio,
            specific_height,
            specific_width,
            enable_skyseg_model,
            filter_point_cloud,
            fx_org,
            fy_org,
            cx_org,
            cy_org,
        ],
        outputs=[
            status_output,
            depth_comparison,
            color_depth,
            gray_depth,
            depth_model_3d,
            depth_download_files,
            gs_viewer_html,
            gs_download_files,
        ],
    )
if __name__ == "__main__":
    # Resolve (and download if missing) all checkpoints before serving.
    _preload_repo_assets()
    queued_demo = demo.queue()
    queued_demo.launch()