"""Gradio demo for InfiniDepth depth estimation and 3D Gaussian Splatting (GS).

The module configures proxy / Hugging Face cache environment variables,
resolves model checkpoints (local file first, Hugging Face Hub otherwise),
defines the two GPU inference entry points and wires them to a Gradio UI.
"""

from __future__ import annotations

import math
import os
from functools import lru_cache
from pathlib import Path
from typing import Optional, Union


def _ensure_localhost_no_proxy() -> None:
    """Add the loopback hosts to ``NO_PROXY`` and ``no_proxy`` if missing.

    Existing entries are kept (whitespace-trimmed, empty entries dropped);
    a variable is only rewritten when at least one loopback host was absent.
    """
    hosts = ("127.0.0.1", "localhost", "::1")
    for key in ("NO_PROXY", "no_proxy"):
        current = os.environ.get(key, "")
        values = [value.strip() for value in current.split(",") if value.strip()]
        missing = [host for host in hosts if host not in values]
        if missing:
            os.environ[key] = ",".join(values + missing)


# Runs before the third-party imports below so the variables are already set
# when those libraries are imported.
_ensure_localhost_no_proxy()


def _ensure_hf_cache_dirs() -> None:
    """Pin the Hugging Face cache locations in the environment and create them.

    ``HF_HOME`` defaults to ``/tmp/huggingface``; the hub and assets caches
    default to sub-directories of it. Values already present in the
    environment win over the defaults.
    """
    hf_home = os.environ.get("HF_HOME", "/tmp/huggingface")
    hub_cache = os.environ.get("HF_HUB_CACHE", os.path.join(hf_home, "hub"))
    assets_cache = os.environ.get("HF_ASSETS_CACHE", os.path.join(hf_home, "assets"))

    os.environ["HF_HOME"] = hf_home
    os.environ["HF_HUB_CACHE"] = hub_cache
    os.environ["HF_ASSETS_CACHE"] = assets_cache
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hub_cache)
    os.environ.setdefault("HF_HUB_DISABLE_XET", "1")

    for directory in (hf_home, hub_cache, assets_cache):
        os.makedirs(directory, exist_ok=True)


# Must also run before ``huggingface_hub`` is imported (see ordering below).
_ensure_hf_cache_dirs()

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """No-op stand-in for the ``spaces`` package when it is not installed."""

        @staticmethod
        def GPU(fn=None, **_kwargs):
            """Return the function unchanged.

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``.
            """
            if callable(fn):
                return fn

            def decorator(inner_fn):
                return inner_fn

            return decorator

    spaces = _SpacesFallback()

from InfiniDepth.gs import GSPixelAlignPredictor, export_ply
from InfiniDepth.utils.gs_utils import (
    _build_sparse_uniform_gaussians,
)
from InfiniDepth.utils.hf_demo_utils import (
    DemoArtifacts,
    ensure_session_output_dir,
    export_point_cloud_assets,
    preview_depth_file,
    save_demo_artifacts,
    scan_example_cases,
)
from InfiniDepth.utils.hf_gs_viewer import (
    APP_TEMP_ROOT as GS_VIEWER_ROOT,
    build_embedded_viewer_html,
    build_viewer_error_html,
)
from InfiniDepth.utils.inference_utils import (
    apply_sky_mask_to_depth,
    build_camera_matrices,
    build_scaled_intrinsics_matrix,
    filter_gaussians_by_statistical_outlier,
    prepare_metric_depth_inputs,
    resolve_camera_intrinsics_for_inference,
    resolve_output_size_from_mode,
    run_optional_sampling_sky_mask,
    run_optional_sky_mask,
    unpack_gaussians_for_export,
)
from InfiniDepth.utils.io_utils import depth_to_disparity
from InfiniDepth.utils.model_utils import build_model
from InfiniDepth.utils.sampling_utils import SAMPLING_METHODS

APP_ROOT = Path(__file__).resolve().parent
EXAMPLES_DIR = APP_ROOT / "example_data"
INPUT_SIZE = (768, 1024)  # (height, width) fed to the model
APP_NAME = "infinidepth-hf-demo"

DEPTH_TASK_CHOICE = "Depth"
GS_TASK_CHOICE = "3DGS"
TASK_CHOICES = [DEPTH_TASK_CHOICE, GS_TASK_CHOICE]

RGB_MODEL_TYPE = "InfiniDepth"
DEPTH_SENSOR_MODEL_TYPE = "InfiniDepth_DepthSensor"
MODEL_CHOICES = [RGB_MODEL_TYPE, DEPTH_SENSOR_MODEL_TYPE]

OUTPUT_MODE_CHOICES = ["upsample", "original", "specific"]

GS_SAMPLE_POINT_NUM = 2000000
GS_COORD_DETERMINISTIC_SAMPLING = True

DEPTH_GPU_DURATION_SECONDS = 180
GS_GPU_DURATION_SECONDS = 240

LOCAL_DEPTH_MODEL_PATHS = {
    "InfiniDepth": APP_ROOT / "checkpoints/depth/infinidepth.ckpt",
    "InfiniDepth_DepthSensor": APP_ROOT / "checkpoints/depth/infinidepth_depthsensor.ckpt",
}
LOCAL_GS_MODEL_PATHS = {
    "InfiniDepth": APP_ROOT / "checkpoints/gs/infinidepth_gs.ckpt",
    "InfiniDepth_DepthSensor": APP_ROOT / "checkpoints/gs/infinidepth_depthsensor_gs.ckpt",
}

HF_REPO_ID = "ritianyu/InfiniDepth"
HF_DEPTH_FILENAMES = {
    "InfiniDepth": "infinidepth.ckpt",
    "InfiniDepth_DepthSensor": "infinidepth_depthsensor.ckpt",
}
HF_GS_FILENAMES = {
    "InfiniDepth": "infinidepth_gs.ckpt",
    "InfiniDepth_DepthSensor": "infinidepth_depthsensor_gs.ckpt",
}

LOCAL_MOGE2_PATH = APP_ROOT / "checkpoints/moge-2-vitl-normal/model.pt"
HF_MOGE2_FILENAME = "moge2.pt"
LOCAL_SKYSEG_PATH = APP_ROOT / "checkpoints/sky/skyseg.onnx"
HF_SKYSEG_FILENAME = "skyseg.onnx"

EXAMPLE_CASES = scan_example_cases(EXAMPLES_DIR)
EXAMPLE_LOOKUP = {case.name: case for case in EXAMPLE_CASES}
DEFAULT_EXAMPLE_NAME = EXAMPLE_CASES[0].name if EXAMPLE_CASES else None
DEFAULT_EXAMPLE_INDEX = 0 if EXAMPLE_CASES else None
EXAMPLE_GALLERY_ITEMS = [(case.image_path, case.gallery_caption) for case in EXAMPLE_CASES]

DEPTH_VIEW_TAB_ID = "pcd-viewer-tab"
GS_VIEW_TAB_ID = "gs-viewer-tab"

gr.set_static_paths(paths=[str(GS_VIEWER_ROOT)])

CSS = """
#top-workspace {
    align-items: stretch;
}

#controls-column,
#inputs-column,
#outputs-column {
    min-width: 0;
}

#example-gallery {
    min-height: 280px;
}

#input-image {
    min-height: 420px;
}

#input-depth-preview {
    min-height: 240px;
}

#depth-model3d-viewer {
    height: 700px;
}

#depth-model3d-viewer canvas,
#depth-model3d-viewer model-viewer,
#depth-model3d-viewer .wrap,
#depth-model3d-viewer .container {
    height: 100% !important;
    max-height: 100% !important;
}

#gs-viewer-html {
    min-height: 748px;
    padding-bottom: 0.75rem;
}

#gs-viewer-html iframe {
    display: block;
    width: 100%;
    height: 700px !important;
    min-height: 700px !important;
}

#depth-preview,
#depth-comparison,
#depth-color {
    min-height: 260px;
}
"""


def _ensure_cuda() -> None:
    """Raise ``gr.Error`` when no CUDA device is visible to this process."""
    if not torch.cuda.is_available():
        raise gr.Error(
            "No CUDA device is available for this request. On Hugging Face ZeroGPU, "
            "GPU access is only attached while the decorated inference call is running."
        )


def _resolve_repo_asset(local_path: Path, filename: str) -> str:
    """Return ``local_path`` if it exists, else download ``filename`` from the Hub repo."""
    if local_path.exists():
        return str(local_path)
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
    )


@lru_cache(maxsize=2)
def _resolve_depth_checkpoint(model_type: str) -> str:
    """Return the (cached) path of the depth checkpoint for ``model_type``."""
    return _resolve_repo_asset(LOCAL_DEPTH_MODEL_PATHS[model_type], HF_DEPTH_FILENAMES[model_type])


@lru_cache(maxsize=2)
def _resolve_gs_checkpoint(model_type: str) -> str:
    """Return the (cached) path of the GS checkpoint for ``model_type``."""
    return _resolve_repo_asset(LOCAL_GS_MODEL_PATHS[model_type], HF_GS_FILENAMES[model_type])


@lru_cache(maxsize=1)
def _resolve_skyseg_path() -> str:
    """Return the (cached) path of the sky-segmentation ONNX model."""
    return _resolve_repo_asset(LOCAL_SKYSEG_PATH, HF_SKYSEG_FILENAME)


@lru_cache(maxsize=1)
def _resolve_moge2_source() -> str:
    """Return the (cached) path of the MoGe-2 checkpoint."""
    return _resolve_repo_asset(LOCAL_MOGE2_PATH, HF_MOGE2_FILENAME)


@lru_cache(maxsize=1)
def _preload_repo_assets() -> tuple[str, ...]:
    """Resolve every checkpoint once (downloading if needed) and return their paths."""
    depth_paths = tuple(_resolve_depth_checkpoint(model_type) for model_type in MODEL_CHOICES)
    gs_paths = tuple(_resolve_gs_checkpoint(model_type) for model_type in MODEL_CHOICES)
    return depth_paths + gs_paths + (_resolve_moge2_source(), _resolve_skyseg_path())


@lru_cache(maxsize=2)
def _load_model(model_type: str):
    """Build the depth model for ``model_type``; cached per model type.

    Raises:
        gr.Error: if CUDA is not available (the failure is not cached).
    """
    _ensure_cuda()
    model_path = _resolve_depth_checkpoint(model_type)
    return build_model(model_type=model_type, model_path=model_path)


@lru_cache(maxsize=4)
def _load_gs_predictor(model_type: str, dino_feature_dim: int):
    """Create the GS predictor on CUDA, load its checkpoint and switch to eval mode.

    Cached per ``(model_type, dino_feature_dim)`` pair.
    """
    _ensure_cuda()
    predictor = GSPixelAlignPredictor(dino_feature_dim=dino_feature_dim).to(torch.device("cuda"))
    predictor.load_from_infinidepth_gs_checkpoint(_resolve_gs_checkpoint(model_type))
    predictor.eval()
    return predictor


def _to_optional_float(
    value: Optional[Union[float, str]],
    label: str = "Camera intrinsic",
) -> Optional[float]:
    """Convert a UI field value to ``float``, treating blank input as unset.

    Args:
        value: ``None``, a number, or the raw text of a textbox.
        label: Field name used in the error message.

    Returns:
        ``None`` for ``None`` / empty / whitespace-only input, else the float.

    Raises:
        gr.Error: if the value is not a finite number.
    """
    if value is None:
        return None
    if isinstance(value, str):
        value = value.strip()
        if not value:
            return None
    try:
        number = float(value)
    except (TypeError, ValueError) as exc:
        raise gr.Error(f"{label} must be a number, got `{value}`.") from exc
    if not math.isfinite(number):
        raise gr.Error(f"{label} must be a finite number, got `{value}`.")
    return number


def _to_rgb_uint8(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as an ``H x W x 3`` ``uint8`` array.

    ``uint8`` input is returned unchanged. Floating-point input whose maximum
    is at most 1.0 is clipped to [0, 1] and scaled by 255; any other input is
    clipped to [0, 255] and cast.

    Raises:
        gr.Error: if the array is not 3-channel or has no pixels.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise gr.Error("Input image must be an RGB image.")
    if image.size == 0:
        # ``max()`` and the later resize both fail on zero-size arrays.
        raise gr.Error("Input image is empty.")
    if image.dtype == np.uint8:
        return image
    if np.issubdtype(image.dtype, np.floating):
        if image.max() <= 1.0:
            return (np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)
        return np.clip(image, 0.0, 255.0).astype(np.uint8)
    return np.clip(image, 0, 255).astype(np.uint8)


def _prepare_image_tensors(image_rgb: np.ndarray) -> tuple[np.ndarray, torch.Tensor, tuple[int, int]]:
    """Normalise the input image and build the model input tensor.

    Returns:
        ``(image_uint8, tensor, (org_h, org_w))`` where ``tensor`` is the image
        resized to ``INPUT_SIZE``, laid out as ``1 x 3 x H x W`` float in [0, 1].
    """
    image_rgb = _to_rgb_uint8(image_rgb)
    org_h, org_w = image_rgb.shape[:2]
    # cv2.resize expects (width, height), hence the reversed INPUT_SIZE.
    resized = cv2.resize(image_rgb, INPUT_SIZE[::-1], interpolation=cv2.INTER_AREA)
    image = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    return image_rgb, image, (org_h, org_w)


def _metric_depth_source_label(use_gt_depth: bool, depth_file: Optional[str]) -> str:
    """Return the status label describing where metric alignment came from."""
    return "user depth" if use_gt_depth and depth_file else "MoGe-2"


def _format_depth_status(
    model_type: str,
    metric_depth_source: str,
    intrinsics_source: str,
    output_hw: tuple[int, int],
    depth_file: Optional[str],
) -> str:
    """Build the Markdown status text shown after a depth run."""
    depth_label = Path(depth_file).name if depth_file else "None"
    return (
        f"Task: `Depth`\n\n"
        f"Model: `{model_type}`\n\n"
        f"Input depth: `{depth_label}`\n\n"
        f"Metric alignment source: `{metric_depth_source}`\n\n"
        f"Camera intrinsics source: `{intrinsics_source}`\n\n"
        f"Output size: `{output_hw[0]} x {output_hw[1]}`"
    )


def _format_gs_status(
    model_type: str,
    metric_depth_source: str,
    intrinsics_source: str,
    depth_file: Optional[str],
    gaussian_count: int,
) -> str:
    """Build the Markdown status text shown after a GS run."""
    depth_label = Path(depth_file).name if depth_file else "None"
    return (
        f"Task: `GS`\n\n"
        f"Model: `{model_type}`\n\n"
        f"Input depth: `{depth_label}`\n\n"
        f"Metric alignment source: `{metric_depth_source}`\n\n"
        f"Camera intrinsics source: `{intrinsics_source}`\n\n"
        f"Exported gaussians: `{gaussian_count}`"
    )


def _model_availability_note(depth_path: Optional[str], model_type: str, *, auto_switched: bool = False) -> str:
    """Return the hint explaining which model types are usable right now.

    Args:
        depth_path: Currently loaded depth file, if any.
        model_type: Model type that is (now) selected.
        auto_switched: Whether the selection was just changed automatically.
    """
    if depth_path:
        if auto_switched and model_type == DEPTH_SENSOR_MODEL_TYPE:
            return (
                "Depth file loaded. Switched model to `InfiniDepth_DepthSensor`. "
                "You can still switch back to `InfiniDepth` for RGB-only inference."
            )
        return (
            "Depth file loaded. `InfiniDepth_DepthSensor` is available. "
            "You can also keep `InfiniDepth` for RGB-only inference."
        )
    if auto_switched:
        return (
            "No input depth loaded. Switched model back to `InfiniDepth`. "
            "Upload a depth file to enable `InfiniDepth_DepthSensor`."
        )
    return "No input depth loaded. `InfiniDepth` will be used until you upload a depth file."


def _compose_depth_info_message(base_message: str, note: str) -> str:
    """Join ``base_message`` and ``note`` with a blank line; omit an empty note."""
    return f"{base_message}\n\n{note}" if note else base_message


def _load_example_image(example_name: str) -> tuple[np.ndarray, Optional[str], Optional[np.ndarray], str, str]:
    """Load an example case into the input widgets.

    Returns:
        ``(image_rgb, depth_path, depth_preview, model_type, info_message)``.
        The depth-sensor model is selected when the case ships a depth file.

    Raises:
        gr.Error: if no example is selected, the name is unknown, or the image
            cannot be read.
    """
    if not example_name:
        raise gr.Error("Select an example case first.")
    case = EXAMPLE_LOOKUP.get(example_name)
    if case is None:
        # Stale or unknown name: report it instead of leaking a KeyError.
        raise gr.Error(f"Unknown example case: {example_name}")

    image_bgr = cv2.imread(case.image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:
        raise gr.Error(f"Failed to load example image: {case.image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    depth_path = case.depth_path
    if depth_path is None:
        detail = _compose_depth_info_message(
            f"Loaded example `{case.name}`.",
            _model_availability_note(None, RGB_MODEL_TYPE),
        )
        return image_rgb, depth_path, None, RGB_MODEL_TYPE, detail

    preview, depth_msg = preview_depth_file(depth_path)
    model_type = DEPTH_SENSOR_MODEL_TYPE
    detail = _compose_depth_info_message(
        f"Loaded example `{case.name}` with paired depth. {depth_msg}",
        _model_availability_note(depth_path, model_type, auto_switched=True),
    )
    return image_rgb, depth_path, preview, model_type, detail


def _selected_example_message(example_name: Optional[str]) -> str:
    """Return the Markdown line describing the currently selected example."""
    if not example_name or example_name not in EXAMPLE_LOOKUP:
        return "Select an example thumbnail, then click `Load Example`."
    case = EXAMPLE_LOOKUP[example_name]
    mode_label = "RGB + depth" if case.has_depth else "RGB only"
    return f"Selected example: `{case.name}` ({mode_label})"


def _select_example(evt: gr.SelectData):
    """Handle a gallery click: return ``(example_name, selection_message)``.

    Returns ``(None, <prompt message>)`` when there are no examples or the
    event carries no usable index.
    """
    if not EXAMPLE_CASES or evt.index is None:
        return None, _selected_example_message(None)
    index = evt.index
    if isinstance(index, (tuple, list)):
        index = index[0]
    index = int(index)
    if not 0 <= index < len(EXAMPLE_CASES):
        # Out-of-range index: treat as "nothing selected".
        return None, _selected_example_message(None)
    case = EXAMPLE_CASES[index]
    return case.name, _selected_example_message(case.name)


def _primary_view_for_task(task_type: str):
    """Return the tab update that selects the viewer matching ``task_type``."""
    selected_tab = GS_VIEW_TAB_ID if task_type == GS_TASK_CHOICE else DEPTH_VIEW_TAB_ID
    return gr.update(selected=selected_tab)


def _reset_uploaded_image_state(
    _image: Optional[np.ndarray],
    depth_path: Optional[str],
) -> tuple[None, None, str, str]:
    """Clear depth-related state after the user supplies a new image.

    Returns:
        ``(depth_file, depth_preview, model_type, info_message)`` with the
        depth file and preview cleared and the RGB model selected.
    """
    if depth_path:
        note = (
            "Image updated. Cleared the previous depth file. Upload a new depth file to enable "
            "`InfiniDepth_DepthSensor`."
        )
    else:
        note = "Image updated. Upload a depth file to enable `InfiniDepth_DepthSensor`."
    return None, None, RGB_MODEL_TYPE, note


def _update_depth_preview(depth_path: Optional[str], model_type: str) -> tuple[Optional[np.ndarray], str, str]:
    """Refresh the depth preview and pick the model matching the depth state.

    Returns:
        ``(preview, next_model_type, info_message)``.

    Raises:
        gr.Error: if the depth file cannot be previewed.
    """
    try:
        preview, depth_msg = preview_depth_file(depth_path)
    except Exception as exc:
        raise gr.Error(f"Failed to preview depth file: {exc}") from exc

    next_model = DEPTH_SENSOR_MODEL_TYPE if depth_path else RGB_MODEL_TYPE
    note = _model_availability_note(depth_path, next_model, auto_switched=(model_type != next_model))
    return preview, next_model, _compose_depth_info_message(depth_msg, note)


def _settings_visibility(task_type: str, output_resolution_mode: str):
    """Return visibility updates for the depth-only settings.

    Order matches the event outputs: resolution mode, upsample ratio,
    specific height, specific width, point-cloud filter.
    """
    is_depth = task_type == DEPTH_TASK_CHOICE
    show_specific = is_depth and output_resolution_mode == "specific"
    return (
        gr.update(visible=is_depth),
        gr.update(visible=is_depth and output_resolution_mode == "upsample"),
        gr.update(visible=show_specific),
        gr.update(visible=show_specific),
        gr.update(visible=is_depth),
    )


def _normalize_filtered_gaussians(filtered_result):
    """Return the gaussians from a filter result that may be a tuple."""
    if isinstance(filtered_result, tuple):
        return filtered_result[0]
    return filtered_result


def _resolve_requested_output_size(
    output_resolution_mode: str,
    specific_height: Optional[float],
    specific_width: Optional[float],
) -> tuple[int, int]:
    """Validate the user-specified output size and return it as ``(height, width)``.

    In ``"specific"`` mode both values are required and must be positive.
    In the other modes a cleared field falls back to ``INPUT_SIZE`` so that an
    empty, hidden widget cannot break the run.

    Raises:
        gr.Error: in ``"specific"`` mode when a value is missing or not positive.
    """
    if output_resolution_mode == "specific":
        if specific_height is None or specific_width is None:
            raise gr.Error("Specific height and width are required in `specific` output mode.")
        height, width = int(specific_height), int(specific_width)
        if height <= 0 or width <= 0:
            raise gr.Error("Specific height and width must be positive integers.")
        return height, width

    height = INPUT_SIZE[0] if specific_height is None else int(specific_height)
    width = INPUT_SIZE[1] if specific_width is None else int(specific_width)
    return height, width


@spaces.GPU(duration=DEPTH_GPU_DURATION_SECONDS)
@torch.no_grad()
def _run_depth_inference(
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    output_resolution_mode: str,
    upsample_ratio: int,
    specific_height: int,
    specific_width: int,
    enable_skyseg_model: bool,
    filter_point_cloud: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Run depth inference and export the depth maps and point cloud.

    Returns:
        An 8-tuple matching the UI outputs: status text, comparison image,
        colorized depth, grayscale depth, GLB path for the 3D viewer, depth
        download files, and two ``None`` values for the GS outputs.

    Raises:
        gr.Error: on missing CUDA, missing image, missing depth file for the
            depth-sensor model, or invalid output size / intrinsics input.
    """
    _ensure_cuda()
    if image is None:
        raise gr.Error("Upload an image or load an example before running inference.")
    if model_type == DEPTH_SENSOR_MODEL_TYPE and not depth_file:
        raise gr.Error("InfiniDepth_DepthSensor requires an input depth file.")
    requested_size = _resolve_requested_output_size(output_resolution_mode, specific_height, specific_width)

    skyseg_path = _resolve_skyseg_path() if enable_skyseg_model else None
    image_rgb, image_tensor, (org_h, org_w) = _prepare_image_tensors(image)
    device = torch.device("cuda")
    image_tensor = image_tensor.to(device)

    model = _load_model(model_type)
    gt_depth, prompt_depth, gt_depth_mask, use_gt_depth, moge2_intrinsics = prepare_metric_depth_inputs(
        input_depth_path=depth_file,
        input_size=INPUT_SIZE,
        image=image_tensor,
        device=device,
        moge2_pretrained=_resolve_moge2_source(),
    )
    gt_disp = depth_to_disparity(gt_depth)
    prompt_disp = depth_to_disparity(prompt_depth)

    fx_org, fy_org, cx_org, cy_org, intrinsics_source = resolve_camera_intrinsics_for_inference(
        fx_org=_to_optional_float(fx_org, "fx"),
        fy_org=_to_optional_float(fy_org, "fy"),
        cx_org=_to_optional_float(cx_org, "cx"),
        cy_org=_to_optional_float(cy_org, "cy"),
        org_h=org_h,
        org_w=org_w,
        image=image_tensor,
        moge2_pretrained=_resolve_moge2_source(),
        moge2_intrinsics=moge2_intrinsics,
    )

    _, _, h, w = image_tensor.shape
    fx, fy, cx, cy, _ = build_scaled_intrinsics_matrix(
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        device=image_tensor.device,
    )

    sky_mask = run_optional_sky_mask(
        image=image_tensor,
        enable_skyseg_model=enable_skyseg_model,
        sky_model_ckpt_path=skyseg_path or str(LOCAL_SKYSEG_PATH),
    )

    h_out, w_out = resolve_output_size_from_mode(
        output_resolution_mode=output_resolution_mode,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        output_size=requested_size,
        upsample_ratio=int(upsample_ratio),
    )

    query_2d_uniform_coord = SAMPLING_METHODS["2d_uniform"]((h_out, w_out)).unsqueeze(0).to(device)
    pred_2d_uniform_depth, _ = model.inference(
        image=image_tensor,
        query_coord=query_2d_uniform_coord,
        gt_depth=gt_disp,
        gt_depth_mask=gt_depth_mask,
        prompt_depth=prompt_disp,
        prompt_mask=prompt_disp > 0,
    )
    pred_depthmap = pred_2d_uniform_depth.permute(0, 2, 1).reshape(1, 1, h_out, w_out)

    pred_depthmap, pred_2d_uniform_depth = apply_sky_mask_to_depth(
        pred_depthmap=pred_depthmap,
        pred_2d_uniform_depth=pred_2d_uniform_depth,
        sky_mask=sky_mask,
        h_sample=h_out,
        w_sample=w_out,
        sky_depth_value=200.0,
    )

    session_hash = getattr(request, "session_hash", None)
    output_dir = ensure_session_output_dir(APP_NAME, session_hash)
    pred_depth_np = pred_depthmap.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)
    artifacts = save_demo_artifacts(image_rgb=image_rgb, pred_depth=pred_depth_np, output_dir=output_dir)

    ply_path, glb_path = export_point_cloud_assets(
        sampled_coord=query_2d_uniform_coord.squeeze(0).cpu(),
        sampled_depth=pred_2d_uniform_depth.squeeze(0).squeeze(-1).cpu(),
        rgb_image=image_tensor.squeeze(0).cpu(),
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        output_dir=output_dir,
        filter_flying_points=filter_point_cloud,
    )
    # Rebuild the artifacts record so it also carries the point-cloud paths.
    artifacts = DemoArtifacts(
        comparison_path=artifacts.comparison_path,
        color_depth_path=artifacts.color_depth_path,
        gray_depth_path=artifacts.gray_depth_path,
        raw_depth_path=artifacts.raw_depth_path,
        ply_path=ply_path,
        glb_path=glb_path,
    )

    status = _format_depth_status(
        model_type=model_type,
        metric_depth_source=_metric_depth_source_label(use_gt_depth, depth_file),
        intrinsics_source=intrinsics_source,
        output_hw=(h_out, w_out),
        depth_file=depth_file,
    )
    return (
        status,
        artifacts.comparison_path,
        artifacts.color_depth_path,
        artifacts.gray_depth_path,
        glb_path,
        artifacts.download_files(),
        None,
        None,
    )


@spaces.GPU(duration=GS_GPU_DURATION_SECONDS)
@torch.no_grad()
def _run_gs_inference(
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    enable_skyseg_model: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Run GS inference, export the gaussians as PLY and build the viewer HTML.

    Returns:
        An 8-tuple matching the UI outputs: status text, five ``None`` values
        for the depth outputs, the GS viewer HTML and the GS download files.

    Raises:
        gr.Error: on missing CUDA, missing image, missing depth file for the
            depth-sensor model, invalid intrinsics input, missing 3D query
            outputs, or when no gaussians survive filtering.
    """
    _ensure_cuda()
    if image is None:
        raise gr.Error("Upload an image or load an example before running inference.")
    if model_type == DEPTH_SENSOR_MODEL_TYPE and not depth_file:
        raise gr.Error("InfiniDepth_DepthSensor requires an input depth file for GS inference.")

    # The uint8 RGB copy is not needed for GS export.
    _, image_tensor, (org_h, org_w) = _prepare_image_tensors(image)
    device = torch.device("cuda")
    image_tensor = image_tensor.to(device)

    model = _load_model(model_type)
    gt_depth, prompt_depth, gt_depth_mask, use_gt_depth, moge2_intrinsics = prepare_metric_depth_inputs(
        input_depth_path=depth_file,
        input_size=INPUT_SIZE,
        image=image_tensor,
        device=device,
        moge2_pretrained=_resolve_moge2_source(),
    )
    gt_disp = depth_to_disparity(gt_depth)
    prompt_disp = depth_to_disparity(prompt_depth)

    fx_org, fy_org, cx_org, cy_org, intrinsics_source = resolve_camera_intrinsics_for_inference(
        fx_org=_to_optional_float(fx_org, "fx"),
        fy_org=_to_optional_float(fy_org, "fy"),
        cx_org=_to_optional_float(cx_org, "cx"),
        cy_org=_to_optional_float(cy_org, "cy"),
        org_h=org_h,
        org_w=org_w,
        image=image_tensor,
        moge2_pretrained=_resolve_moge2_source(),
        moge2_intrinsics=moge2_intrinsics,
    )

    b, _, h, w = image_tensor.shape
    _, _, _, _, intrinsics, extrinsics = build_camera_matrices(
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        org_h=org_h,
        org_w=org_w,
        h=h,
        w=w,
        batch=b,
        device=device,
    )

    skyseg_path = _resolve_skyseg_path() if enable_skyseg_model else str(LOCAL_SKYSEG_PATH)
    sky_mask = run_optional_sampling_sky_mask(
        image=image_tensor,
        enable_skyseg_model=enable_skyseg_model,
        sky_model_ckpt_path=skyseg_path,
        dilate_px=0,
    )

    depthmap, dino_tokens, query_3d_uniform_coord, pred_depth_3d = model.inference_for_gs(
        image=image_tensor,
        intrinsics=intrinsics,
        gt_depth=gt_disp,
        gt_depth_mask=gt_depth_mask,
        prompt_depth=prompt_disp,
        prompt_mask=prompt_disp > 0,
        sky_mask=sky_mask,
        sample_point_num=GS_SAMPLE_POINT_NUM,
        coord_deterministic_sampling=GS_COORD_DETERMINISTIC_SAMPLING,
    )
    if query_3d_uniform_coord is None or pred_depth_3d is None:
        raise gr.Error("GS inference did not return 3D-uniform query outputs.")

    gs_predictor = _load_gs_predictor(model_type, int(dino_tokens.shape[-1]))
    dense_gaussians = gs_predictor(
        image=image_tensor,
        depthmap=depthmap,
        dino_tokens=dino_tokens,
        intrinsics=intrinsics,
        extrinsics=extrinsics,
    )
    pixel_gaussians = _build_sparse_uniform_gaussians(
        dense_gaussians=dense_gaussians,
        query_3d_uniform_coord=query_3d_uniform_coord,
        pred_depth_3d=pred_depth_3d,
        intrinsics=intrinsics,
        extrinsics=extrinsics,
        h=h,
        w=w,
    )
    pixel_gaussians = _normalize_filtered_gaussians(filter_gaussians_by_statistical_outlier(pixel_gaussians))

    gaussian_count = int(pixel_gaussians.means.shape[1])
    if gaussian_count == 0:
        raise gr.Error("No valid gaussians remained after filtering.")

    means, harmonics, opacities, scales, rotations = unpack_gaussians_for_export(pixel_gaussians)

    session_hash = getattr(request, "session_hash", None)
    output_dir = ensure_session_output_dir(APP_NAME, session_hash)
    ply_path = output_dir / "gaussians.ply"
    export_ply(
        means=means,
        harmonics=harmonics,
        opacities=opacities,
        path=ply_path,
        scales=scales,
        rotations=rotations,
        focal_length_px=(fx_org, fy_org),
        principal_point_px=(cx_org, cy_org),
        image_shape=(org_h, org_w),
        extrinsic_matrix=extrinsics[0],
    )

    # Best effort: a viewer failure must not discard the exported PLY, so an
    # error page is shown instead and the file stays downloadable.
    try:
        gs_viewer_html = build_embedded_viewer_html(ply_path)
    except Exception as exc:
        print(f"[Warning] Failed to build embedded GS viewer: {exc}")
        gs_viewer_html = build_viewer_error_html(str(exc), ply_path)

    status = _format_gs_status(
        model_type=model_type,
        metric_depth_source=_metric_depth_source_label(use_gt_depth, depth_file),
        intrinsics_source=intrinsics_source,
        depth_file=depth_file,
        gaussian_count=gaussian_count,
    )
    download_files = [str(ply_path)]
    return (
        status,
        None,
        None,
        None,
        None,
        None,
        gs_viewer_html,
        download_files,
    )


def _run_inference(
    task_type: str,
    image: np.ndarray,
    depth_file: Optional[str],
    model_type: str,
    output_resolution_mode: str,
    upsample_ratio: int,
    specific_height: int,
    specific_width: int,
    enable_skyseg_model: bool,
    filter_point_cloud: bool,
    fx_org: Optional[float],
    fy_org: Optional[float],
    cx_org: Optional[float],
    cy_org: Optional[float],
    request: gr.Request,
):
    """Dispatch to GS inference for the GS task and to depth inference otherwise."""
    if task_type == GS_TASK_CHOICE:
        return _run_gs_inference(
            image=image,
            depth_file=depth_file,
            model_type=model_type,
            enable_skyseg_model=enable_skyseg_model,
            fx_org=fx_org,
            fy_org=fy_org,
            cx_org=cx_org,
            cy_org=cy_org,
            request=request,
        )
    return _run_depth_inference(
        image=image,
        depth_file=depth_file,
        model_type=model_type,
        output_resolution_mode=output_resolution_mode,
        upsample_ratio=upsample_ratio,
        specific_height=specific_height,
        specific_width=specific_width,
        enable_skyseg_model=enable_skyseg_model,
        filter_point_cloud=filter_point_cloud,
        fx_org=fx_org,
        fy_org=fy_org,
        cx_org=cx_org,
        cy_org=cy_org,
        request=request,
    )


def _clear_outputs():
    """Return blank values for the eight output widgets, in wiring order."""
    return "", None, None, None, None, None, "", None


with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# InfiniDepth Demo")
    gr.Markdown(
        "Switch between depth inference and GS inference. `InfiniDepth` works with RGB-only inputs, while `InfiniDepth_DepthSensor` is enabled only after you upload a depth file or load an example with paired depth."
    )
    selected_example_name = gr.State(DEFAULT_EXAMPLE_NAME)

    with gr.Row(elem_id="top-workspace"):
        # --- Controls -------------------------------------------------------
        with gr.Column(scale=4, min_width=320, elem_id="controls-column"):
            task_type = gr.Radio(label="Inference Task", choices=TASK_CHOICES, value=DEPTH_TASK_CHOICE)
            model_type = gr.Radio(label="Model Type", choices=MODEL_CHOICES, value=RGB_MODEL_TYPE)

            gr.Markdown("### Example Data")
            example_gallery = gr.Gallery(
                value=EXAMPLE_GALLERY_ITEMS,
                label="Example Data",
                show_label=False,
                columns=2,
                height=280,
                object_fit="cover",
                allow_preview=False,
                selected_index=DEFAULT_EXAMPLE_INDEX,
                elem_id="example-gallery",
            )
            example_selection = gr.Markdown(_selected_example_message(DEFAULT_EXAMPLE_NAME))
            load_example_btn = gr.Button("Load Example")

            with gr.Accordion("Depth Settings", open=True):
                output_resolution_mode = gr.Dropdown(
                    label="Output Resolution Mode",
                    choices=OUTPUT_MODE_CHOICES,
                    value="upsample",
                )
                upsample_ratio = gr.Slider(label="Upsample Ratio", minimum=1, maximum=4, step=1, value=1)
                specific_height = gr.Number(label="Specific Height", value=INPUT_SIZE[0], precision=0, visible=False)
                specific_width = gr.Number(label="Specific Width", value=INPUT_SIZE[1], precision=0, visible=False)
                enable_skyseg_model = gr.Checkbox(label="Apply Sky Mask", value=False)
                filter_point_cloud = gr.Checkbox(label="Filter Flying Points", value=True)

            with gr.Accordion("Optional Camera Intrinsics", open=False):
                fx_org = gr.Textbox(label="fx", value="", placeholder="auto")
                fy_org = gr.Textbox(label="fy", value="", placeholder="auto")
                cx_org = gr.Textbox(label="cx", value="", placeholder="auto")
                cy_org = gr.Textbox(label="cy", value="", placeholder="auto")

        # --- Inputs ---------------------------------------------------------
        with gr.Column(scale=5, min_width=360, elem_id="inputs-column"):
            input_image = gr.Image(
                label="Input Image",
                image_mode="RGB",
                type="numpy",
                sources=["upload", "clipboard", "webcam"],
                height=420,
                elem_id="input-image",
            )
            input_depth_file = gr.File(
                label="Optional Depth File",
                type="filepath",
                file_types=[".png", ".npy", ".npz", ".h5", ".hdf5", ".exr"],
            )
            input_depth_preview = gr.Image(
                label="Input Depth Preview",
                type="numpy",
                height=240,
                elem_id="input-depth-preview",
            )
            depth_info = gr.Markdown("No input depth loaded. `InfiniDepth` will be used until you upload a depth file.")
            submit_btn = gr.Button("Run Inference", variant="primary")

        # --- Outputs --------------------------------------------------------
        with gr.Column(scale=8, min_width=640, elem_id="outputs-column"):
            status_output = gr.Markdown()
            with gr.Tabs(selected=DEPTH_VIEW_TAB_ID, elem_id="primary-view-tabs") as primary_view_tabs:
                with gr.Tab("PCD Viewer", id=DEPTH_VIEW_TAB_ID, render_children=True):
                    depth_model_3d = gr.Model3D(
                        label="Point Cloud Viewer",
                        display_mode="solid",
                        clear_color=[1.0, 1.0, 1.0, 1.0],
                        height=700,
                        elem_id="depth-model3d-viewer",
                    )
                with gr.Tab("GS Viewer", id=GS_VIEW_TAB_ID, render_children=True):
                    gs_viewer_html = gr.HTML(elem_id="gs-viewer-html")

            with gr.Tabs(elem_id="secondary-output-tabs"):
                with gr.Tab("Depth Analysis", render_children=True):
                    depth_comparison = gr.Image(
                        label="RGB vs Depth",
                        type="filepath",
                        height=280,
                        elem_id="depth-comparison",
                    )
                    with gr.Row():
                        color_depth = gr.Image(
                            label="Colorized Depth",
                            type="filepath",
                            height=260,
                            elem_id="depth-color",
                        )
                        gray_depth = gr.Image(
                            label="Grayscale Depth",
                            type="filepath",
                            height=260,
                            elem_id="depth-preview",
                        )
                with gr.Tab("Downloads", render_children=True):
                    with gr.Row():
                        depth_download_files = gr.File(label="Depth Files", type="filepath")
                        gs_download_files = gr.File(label="GS Files", type="filepath")

    # --- Event wiring -------------------------------------------------------
    task_type.change(
        fn=_settings_visibility,
        inputs=[task_type, output_resolution_mode],
        outputs=[output_resolution_mode, upsample_ratio, specific_height, specific_width, filter_point_cloud],
    )
    task_type.change(
        fn=_primary_view_for_task,
        inputs=[task_type],
        outputs=[primary_view_tabs],
    )
    output_resolution_mode.change(
        fn=_settings_visibility,
        inputs=[task_type, output_resolution_mode],
        outputs=[output_resolution_mode, upsample_ratio, specific_height, specific_width, filter_point_cloud],
    )
    example_gallery.select(
        fn=_select_example,
        outputs=[selected_example_name, example_selection],
    )
    input_image.input(
        fn=_reset_uploaded_image_state,
        inputs=[input_image, input_depth_file],
        outputs=[input_depth_file, input_depth_preview, model_type, depth_info],
    )
    input_depth_file.change(
        fn=_update_depth_preview,
        inputs=[input_depth_file, model_type],
        outputs=[input_depth_preview, model_type, depth_info],
    )
    load_example_btn.click(
        fn=_load_example_image,
        inputs=[selected_example_name],
        outputs=[input_image, input_depth_file, input_depth_preview, model_type, depth_info],
    )

    # Switch to the matching viewer, blank the old results, then run inference.
    submit_btn.click(
        fn=_primary_view_for_task,
        inputs=[task_type],
        outputs=[primary_view_tabs],
    ).then(
        fn=_clear_outputs,
        outputs=[
            status_output,
            depth_comparison,
            color_depth,
            gray_depth,
            depth_model_3d,
            depth_download_files,
            gs_viewer_html,
            gs_download_files,
        ],
    ).then(
        fn=_run_inference,
        inputs=[
            task_type,
            input_image,
            input_depth_file,
            model_type,
            output_resolution_mode,
            upsample_ratio,
            specific_height,
            specific_width,
            enable_skyseg_model,
            filter_point_cloud,
            fx_org,
            fy_org,
            cx_org,
            cy_org,
        ],
        outputs=[
            status_output,
            depth_comparison,
            color_depth,
            gray_depth,
            depth_model_3d,
            depth_download_files,
            gs_viewer_html,
            gs_download_files,
        ],
    )


if __name__ == "__main__":
    _preload_repo_assets()
    demo.queue().launch()