# Hugging Face Spaces status: Running on Zero
| from __future__ import absolute_import, division, print_function | |
| import os | |
| import sys | |
| import cv2 | |
| import yaml | |
| import numpy as np | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# Use the Hugging Face `spaces.GPU` decorator when the `spaces` package can be
# imported; otherwise fall back to a decorator that returns the function as-is.
try:
    import spaces

    gpu_decorator = spaces.GPU
except Exception:

    def gpu_decorator(f):
        """Return *f* unchanged (fallback when `spaces` is unavailable)."""
        return f
# Directory containing this file. It is appended to sys.path so that the
# project-local `networks` package imported below can be resolved.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(PROJECT_ROOT)

from networks.models import make  # noqa: E402

# Hugging Face Hub repository and file name of the pretrained checkpoint.
WEIGHTS_REPO = "Insta360-Research/DAP-weights"
WEIGHTS_FILE = "model.pth"

# YAML configuration passed to load_model().
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")

# Module-level state: `model` is assigned once load_model() has run and
# `device` is overwritten inside load_model().
model = None
device = "cpu"
| import matplotlib | |
def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    """Map an 8-bit depth image through a matplotlib colormap.

    Args:
        depth_u8: uint8 array with values in 0~255.
        cmap: Name of a colormap registered in ``matplotlib.colormaps``.

    Returns:
        Contiguous RGB uint8 array (the colormap's alpha channel is dropped).
    """
    # Colormaps are indexed with floats in [0, 1], so normalise first.
    normalized = depth_u8.astype(np.float32) / 255.0
    rgba = matplotlib.colormaps[cmap](normalized)
    rgb_u8 = (rgba[..., :3] * 255).astype(np.uint8)
    return np.ascontiguousarray(rgb_u8)
def load_model(config_path: str):
    """Build the depth model and load its pretrained weights.

    Reads the YAML config at *config_path*, downloads ``WEIGHTS_FILE`` from
    the ``WEIGHTS_REPO`` Hugging Face repository, builds the network with
    ``make(config["model"])``, loads the checkpoint entries whose keys exist
    in the model, and switches the model to eval mode.

    Side effect: sets the module-level ``device`` to ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.

    Args:
        config_path: Path to a YAML file with a top-level ``"model"`` entry.

    Returns:
        The model on ``device`` in eval mode. It is wrapped in
        ``nn.DataParallel`` when any checkpoint key starts with ``"module"``.
    """
    import torch
    import torch.nn as nn

    global device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build more than plain data types. That
        # is acceptable for the repository's own config file; do not point
        # this at untrusted YAML.
        config = yaml.load(f, Loader=yaml.FullLoader)

    print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
    model_path = hf_hub_download(
        repo_id=WEIGHTS_REPO,
        filename=WEIGHTS_FILE,
    )
    print(f"✅ Weights downloaded to: {model_path}")

    # NOTE(review): torch.load unpickles the downloaded file. Consider
    # weights_only=True once the checkpoint is confirmed to be a plain
    # state dict.
    state = torch.load(model_path, map_location=device)

    m = make(config["model"])
    # Checkpoints saved from a DataParallel model carry a "module" key prefix;
    # wrap the fresh model the same way so the keys line up.
    if any(k.startswith("module") for k in state):
        m = nn.DataParallel(m)
    m = m.to(device)

    m_state = m.state_dict()
    matched = {k: v for k, v in state.items() if k in m_state}
    skipped = [k for k in state if k not in m_state]
    result = m.load_state_dict(matched, strict=False)

    # strict=False plus the key filter would otherwise hide a checkpoint/model
    # mismatch and leave layers at their initial values without any signal.
    if result.missing_keys or skipped:
        print(
            f"Warning: {len(result.missing_keys)} model parameter(s) not found "
            f"in checkpoint; {len(skipped)} checkpoint entry(ies) not used."
        )

    m.eval()
    print("✅ Model loaded.")
    return m
# Load the network once at import time; infer_raw() reads this global.
model = load_model(CONFIG_PATH)

COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")


def _read_colorbar_rgb(filename):
    """Read a colorbar image from COLORBAR_DIR and return it as RGB.

    Returns None when ``cv2.imread`` returns None for the file.
    """
    bgr = cv2.imread(os.path.join(COLORBAR_DIR, filename))
    if bgr is None:
        return None
    # cv2.imread yields BGR channel order; convert to RGB for display.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


colorbar_100m_color = _read_colorbar_rgb("colorbar_100m_color.png")
colorbar_100m_gray = _read_colorbar_rgb("colorbar_100m_gray.png")
colorbar_10m_color = _read_colorbar_rgb("colorbar_10m_color.png")
colorbar_10m_gray = _read_colorbar_rgb("colorbar_10m_gray.png")
@gpu_decorator
def _predict_depth(img_rgb: np.ndarray) -> np.ndarray:
    """Run the global ``model`` on one image and return the prediction.

    Wrapped with ``gpu_decorator`` so that, when the ``spaces`` package is
    available, the forward pass runs under ``spaces.GPU``.
    """
    import torch

    # Scale to [0, 1] and reorder to channels-first with a batch axis.
    # assumes img_rgb is a 3-D HxWxC array with 0-255 values — TODO confirm
    img = img_rgb.astype(np.float32) / 255.0
    tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)

    # Post-processing stays inside inference_mode because it updates the
    # model output in place.
    with torch.inference_mode():
        outputs = model(tensor)
        if isinstance(outputs, dict) and "pred_depth" in outputs:
            depth = outputs["pred_depth"]
            if "pred_mask" in outputs:
                # Pixels where (1 - pred_mask) <= 0.5 are overwritten with 1.
                keep = (1 - outputs["pred_mask"]) > 0.5
                depth[~keep] = 1
            pred = depth[0].cpu().squeeze().numpy()
        else:
            pred = outputs[0].cpu().squeeze().numpy()
    return pred.astype(np.float32)


def infer_raw(img_rgb: np.ndarray):
    """Return the raw float32 depth prediction for *img_rgb*.

    Args:
        img_rgb: Input image as a NumPy array, or None when no image is set.

    Returns:
        float32 NumPy array with the squeezed model prediction, or None when
        *img_rgb* is None.
    """
    # Handle the empty case here so no GPU is requested for it.
    if img_rgb is None:
        return None
    return _predict_depth(img_rgb)
def visualize_100m(pred: np.ndarray):
    """Render a raw prediction with the full [0, 1] range mapped to 0-255.

    Args:
        pred: Raw prediction array, or None.

    Returns:
        Tuple ``(color_image, gray_image, npy_path, colorbar_color,
        colorbar_gray)``; five Nones when *pred* is None. The unclipped
        prediction is saved to ``npy_path``.
    """
    if pred is None:
        return (None,) * 5

    gray_u8 = (np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)
    color_rgb = colorize_depth_fixed(gray_u8, cmap="Spectral")

    # The saved file holds the raw values, not the clipped visualisation.
    out_path = "/tmp/depth_100m.npy"
    np.save(out_path, pred)
    return color_rgb, gray_u8, out_path, colorbar_100m_color, colorbar_100m_gray
def visualize_10m(pred: np.ndarray):
    """Render a raw prediction with the [0, 0.1] range mapped to 0-255.

    Args:
        pred: Raw prediction array, or None.

    Returns:
        Tuple ``(color_image, gray_image, npy_path, colorbar_color,
        colorbar_gray)``; five Nones when *pred* is None. The unclipped
        prediction is saved to ``npy_path``.
    """
    if pred is None:
        return (None,) * 5

    # Values above 0.1 saturate; the factor 10 stretches [0, 0.1] to [0, 1].
    gray_u8 = (np.clip(pred, 0.0, 0.1) * 10 * 255).astype(np.uint8)
    color_rgb = colorize_depth_fixed(gray_u8, cmap="Spectral")

    # The saved file holds the raw values, not the clipped visualisation.
    out_path = "/tmp/depth_10m.npy"
    np.save(out_path, pred)
    return color_rgb, gray_u8, out_path, colorbar_10m_color, colorbar_10m_gray
def infer_and_vis_100m(img_rgb: np.ndarray):
    """Run inference and build the 100m visualisation in one step.

    Returns:
        Tuple ``(pred, color, gray, npy_path, colorbar_color, colorbar_gray)``:
        the raw prediction from infer_raw() followed by the five values
        returned by visualize_100m().
    """
    depth = infer_raw(img_rgb)
    color_img, gray_img, npy_file, bar_color, bar_gray = visualize_100m(depth)
    return depth, color_img, gray_img, npy_file, bar_color, bar_gray
# Example inputs: hfdemo/01.jpg through hfdemo/11.jpg.
example_paths = [f"hfdemo/{index:02d}.jpg" for index in range(1, 12)]
# Generated example inputs: hfdemo/generated_00.jpg through generated_07.jpg.
example_gen_paths = [f"hfdemo/generated_{index:02d}.jpg" for index in range(8)]
# Gradio UI: input image, example galleries and action buttons on the left;
# colour/gray depth renderings with their colorbars and the .npy download on
# the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🌐 DAP Depth Prediction Demo
        Here are our resources:
        - 💻 **Code**: [https://github.com/Insta360-Research-Team/DAP](https://github.com/Insta360-Research-Team/DAP)
        - 🌐 **Web Page**: [https://insta360-research-team.github.io/DAP_website/](https://insta360-research-team.github.io/DAP_website/)
        - 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DAP-weights](https://huggingface.co/Insta360-Research/DAP-weights)
        """
    )
    gr.Markdown("# Official Depth Prediction demo for **[DAP](https://insta360-research-team.github.io/DAP_website/)**")

    # Holds the raw prediction between clicks so the two "Visualize" buttons
    # can re-render without running the model again.
    raw_depth = gr.State()

    with gr.Row():
        with gr.Column(scale=10):
            inp = gr.Image(
                type="numpy",
                label="Input Image",
                height=360,
            )
            gr.Markdown("### Examples (click to load)")
            gr.Examples(examples=example_paths, inputs=inp)
            gr.Markdown("### Examples from Gemini (click to load)")
            gr.Examples(examples=example_gen_paths, inputs=inp)

            btn_infer = gr.Button("Run Inference", variant="primary")
            btn_100m = gr.Button("Visualize (100m)")
            btn_10m = gr.Button("Visualize (10m)")
            gr.Markdown(
                """
                <small>
                <b>Visualization range:</b><br>
                • <b>100m</b>: recommended for <b>outdoor</b> scenes<br>
                • <b>10m</b>: recommended for <b>indoor</b> scenes<br>
                (Only affects visualization, not the raw depth output)
                </small>
                """,
                elem_id="vis_hint",
            )

        with gr.Column(scale=11):
            # -------- Row 1: Color Depth --------
            with gr.Row():
                with gr.Column(scale=10):
                    out_color = gr.Image(
                        label="Depth (Color)",
                        height=260,
                    )
                with gr.Column(scale=1, min_width=80):
                    colorbar_color = gr.Image(
                        label="Scale",
                        height=260,
                        show_label=False,
                    )
            # -------- Row 2: Gray Depth --------
            with gr.Row():
                with gr.Column(scale=10):
                    out_gray = gr.Image(
                        label="Depth (Gray)",
                        height=260,
                    )
                with gr.Column(scale=1, min_width=80):
                    colorbar_gray = gr.Image(
                        label="Scale",
                        height=260,
                        show_label=False,
                    )
            out_npy = gr.File(label="Depth (.npy)")

    # "Run Inference" runs the model, stores the raw prediction in the State,
    # and shows the 100m rendering.
    btn_infer.click(
        fn=infer_and_vis_100m,
        inputs=inp,
        outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
    # The range buttons only re-render the stored prediction.
    btn_100m.click(
        fn=visualize_100m,
        inputs=raw_depth,
        outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
    btn_10m.click(
        fn=visualize_10m,
        inputs=raw_depth,
        outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
if __name__ == "__main__":
    # Bind address and port come from the HOST / PORT environment variables,
    # defaulting to 0.0.0.0:7860.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))

    # Queue requests (at most 32 waiting) with a default concurrency limit
    # of 1 per event, then start the server.
    queued_demo = demo.queue(max_size=32, default_concurrency_limit=1)
    queued_demo.launch(
        server_name=host,
        server_port=port,
        ssr_mode=False,
        show_error=True,
    )