# Gradio demo app for DAP depth prediction.
from __future__ import absolute_import, division, print_function

import os
import sys

import cv2
import yaml
import numpy as np
import gradio as gr
import matplotlib
from huggingface_hub import hf_hub_download

try:
    # On Hugging Face Spaces with ZeroGPU, spaces.GPU schedules GPU time for the
    # decorated function; elsewhere we fall back to a no-op decorator.
    import spaces
    gpu_decorator = spaces.GPU
except Exception:
    def gpu_decorator(f):
        return f

PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(PROJECT_ROOT)

from networks.models import make  # noqa: E402

WEIGHTS_REPO = "Insta360-Research/DAP-weights"
WEIGHTS_FILE = "model.pth"
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")

model = None
device = "cpu"


def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    """
    depth_u8: uint8, 0~255
    return: RGB uint8
    """
    disp = depth_u8.astype(np.float32) / 255.0
    colored = matplotlib.colormaps[cmap](disp)[..., :3]
    colored = (colored * 255).astype(np.uint8)
    return np.ascontiguousarray(colored)


def load_model(config_path: str):
    # torch is imported lazily so the module can be imported before a GPU is attached.
    import torch
    import torch.nn as nn

    global device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
    model_path = hf_hub_download(
        repo_id=WEIGHTS_REPO,
        filename=WEIGHTS_FILE
    )
    print(f"✅ Weights downloaded to: {model_path}")

    state = torch.load(model_path, map_location=device)
    m = make(config["model"])
    # Checkpoints saved from multi-GPU training carry "module."-prefixed keys,
    # so wrap the model in DataParallel to match them.
    if any(k.startswith("module") for k in state.keys()):
        m = nn.DataParallel(m)
    m = m.to(device)

    # Load only the keys the model actually has; strict=False tolerates the rest.
    m_state = m.state_dict()
    m.load_state_dict(
        {k: v for k, v in state.items() if k in m_state},
        strict=False
    )
    m.eval()
    print("✅ Model loaded.")
    return m


model = load_model(CONFIG_PATH)

# Pre-rendered colorbar images shown next to the depth maps (converted to RGB for Gradio).
COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")
colorbar_100m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_color.png"))
colorbar_100m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_gray.png"))
colorbar_10m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_color.png"))
colorbar_10m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_gray.png"))

if colorbar_100m_color is not None:
    colorbar_100m_color = cv2.cvtColor(colorbar_100m_color, cv2.COLOR_BGR2RGB)
if colorbar_100m_gray is not None:
    colorbar_100m_gray = cv2.cvtColor(colorbar_100m_gray, cv2.COLOR_BGR2RGB)
if colorbar_10m_color is not None:
    colorbar_10m_color = cv2.cvtColor(colorbar_10m_color, cv2.COLOR_BGR2RGB)
if colorbar_10m_gray is not None:
    colorbar_10m_gray = cv2.cvtColor(colorbar_10m_gray, cv2.COLOR_BGR2RGB)


@gpu_decorator
def infer_raw(img_rgb: np.ndarray):
    if img_rgb is None:
        return None
    import torch

    # Normalize to [0, 1] and convert HWC uint8 -> 1xCxHxW float tensor.
    img = img_rgb.astype(np.float32) / 255.0
    tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)

    with torch.inference_mode():
        outputs = model(tensor)
        if isinstance(outputs, dict) and "pred_depth" in outputs:
            if "pred_mask" in outputs:
                # Where pred_mask >= 0.5, clamp the depth to the far end of the range (1.0).
                mask = 1 - outputs["pred_mask"]
                mask = mask > 0.5
                outputs["pred_depth"][~mask] = 1
            pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
        else:
            pred = outputs[0].cpu().squeeze().numpy()
    return pred.astype(np.float32)


def visualize_100m(pred: np.ndarray):
    if pred is None:
        return None, None, None, None, None
    # Map the full [0, 1] prediction onto the 100 m colorbar.
    pred_clip = np.clip(pred, 0.0, 1.0)
    depth_gray = (pred_clip * 255).astype(np.uint8)
    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
    npy_path = "/tmp/depth_100m.npy"
    np.save(npy_path, pred)
    return depth_color, depth_gray, npy_path, colorbar_100m_color, colorbar_100m_gray
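

# NOTE: the .npy files saved by the visualize_* functions contain the raw
# network output in [0, 1]. Judging from the 100 m / 10 m visualization ranges
# (the 100 m view spans the full [0, 1] range, the 10 m view only its first
# tenth), the prediction looks like depth normalized to a 100 m span. Below is
# a minimal illustrative sketch for recovering metric depth under that
# assumption; it is not used by the demo UI, and the scaling should be
# verified against the DAP repository before relying on it.
def normalized_depth_to_meters(pred: np.ndarray, max_range_m: float = 100.0) -> np.ndarray:
    """Illustrative helper (assumed scaling): map normalized depth to meters."""
    return np.clip(pred, 0.0, 1.0) * max_range_m

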
def visualize_10m(pred: np.ndarray):
    if pred is None:
        return None, None, None, None, None
    # Keep only the near tenth of the range (0-10 m) and stretch it over the 8-bit scale.
    pred_clip = np.clip(pred, 0.0, 0.1)
    depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
    npy_path = "/tmp/depth_10m.npy"
    np.save(npy_path, pred)
    return depth_color, depth_gray, npy_path, colorbar_10m_color, colorbar_10m_gray


@gpu_decorator
def infer_and_vis_100m(img_rgb: np.ndarray):
    # Run inference once and return both the raw prediction (kept in gr.State)
    # and the default 100 m visualization.
    pred = infer_raw(img_rgb)
    color, gray, npy, cbar_color, cbar_gray = visualize_100m(pred)
    return pred, color, gray, npy, cbar_color, cbar_gray


example_paths = [
    "hfdemo/01.jpg",
    "hfdemo/02.jpg",
    "hfdemo/03.jpg",
    "hfdemo/04.jpg",
    "hfdemo/05.jpg",
    "hfdemo/06.jpg",
    "hfdemo/07.jpg",
    "hfdemo/08.jpg",
    "hfdemo/09.jpg",
    "hfdemo/10.jpg",
    "hfdemo/11.jpg",
]

example_gen_paths = [
    "hfdemo/generated_00.jpg",
    "hfdemo/generated_01.jpg",
    "hfdemo/generated_02.jpg",
    "hfdemo/generated_03.jpg",
    "hfdemo/generated_04.jpg",
    "hfdemo/generated_05.jpg",
    "hfdemo/generated_06.jpg",
    "hfdemo/generated_07.jpg",
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🌀 DAP Depth Prediction Demo

Here are our resources:

- 💻 **Code**: [https://github.com/Insta360-Research-Team/DAP](https://github.com/Insta360-Research-Team/DAP)
- 🌐 **Web Page**: [https://insta360-research-team.github.io/DAP_website/](https://insta360-research-team.github.io/DAP_website/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DAP-weights](https://huggingface.co/Insta360-Research/DAP-weights)
"""
    )
    gr.Markdown("# Official Depth Prediction demo for **[DAP](https://insta360-research-team.github.io/DAP_website/)**")

    # Raw model output is kept in session state so the 100m / 10m buttons can
    # re-visualize it without re-running inference.
    raw_depth = gr.State()

    with gr.Row():
        with gr.Column(scale=10):
            inp = gr.Image(
                type="numpy",
                label="Input Image",
                height=360
            )
            gr.Markdown("### Examples (click to load)")
            gr.Examples(examples=example_paths, inputs=inp)
            gr.Markdown("### Examples from Gemini (click to load)")
            gr.Examples(examples=example_gen_paths, inputs=inp)

            btn_infer = gr.Button("Run Inference", variant="primary")
            btn_100m = gr.Button("Visualize (100m)")
            btn_10m = gr.Button("Visualize (10m)")

            gr.Markdown(
                """
Visualization range:

- 100m: recommended for outdoor scenes
- 10m: recommended for indoor scenes
- (Only affects visualization, not the raw depth output)
""", elem_id="vis_hint", ) with gr.Column(scale=11): # -------- Row 1: Color Depth -------- with gr.Row(): with gr.Column(scale=10): out_color = gr.Image( label="Depth (Color)", height=260 ) with gr.Column(scale=1, min_width=80): colorbar_color = gr.Image( label="Scale", height=260, show_label=False ) with gr.Row(): with gr.Column(scale=10): out_gray = gr.Image( label="Depth (Gray)", height=260 ) with gr.Column(scale=1, min_width=80): colorbar_gray = gr.Image( label="Scale", height=260, show_label=False ) out_npy = gr.File(label="Depth (.npy)") btn_infer.click( fn=infer_and_vis_100m, inputs=inp, outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray], ) btn_100m.click( fn=visualize_100m, inputs=raw_depth, outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray], ) btn_10m.click( fn=visualize_10m, inputs=raw_depth, outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray], ) if __name__ == "__main__": host = os.environ.get("HOST", "0.0.0.0") port = int(os.environ.get("PORT", "7860")) demo.queue( max_size=32, default_concurrency_limit=1, ).launch( server_name=host, server_port=port, ssr_mode=False, show_error=True, )