from __future__ import absolute_import, division, print_function
import os
import sys
import cv2
import yaml
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
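# Optional Hugging Face `spaces` GPU decorator; fall back to a no-op decorator
# when the package is unavailable (e.g. local CPU runs).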
try:
    import spaces
    gpu_decorator = spaces.GPU
except Exception:
    gpu_decorator = lambda f: f
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(PROJECT_ROOT)
from networks.models import make # noqa: E402
WEIGHTS_REPO = "Insta360-Research/DAP-weights"
WEIGHTS_FILE = "model.pth"
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
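# Module-level handles: load_model() sets `device`, and the loaded network is
# assigned to `model` once at startup (see below).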
model = None
device = "cpu"
import matplotlib
def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    """
    depth_u8: uint8, 0~255
    return: RGB uint8
    """
    disp = depth_u8.astype(np.float32) / 255.0
    colored = matplotlib.colormaps[cmap](disp)[..., :3]
    colored = (colored * 255).astype(np.uint8)
    return np.ascontiguousarray(colored)
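# Build the network from the YAML config, download the checkpoint from the
# Hugging Face Hub, and load the matching weights (wrapping the model in
# DataParallel when the checkpoint keys carry a "module." prefix).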
def load_model(config_path: str):
    import torch
    import torch.nn as nn

    global device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
    model_path = hf_hub_download(
        repo_id=WEIGHTS_REPO,
        filename=WEIGHTS_FILE
    )
    print(f"✅ Weights downloaded to: {model_path}")

    state = torch.load(model_path, map_location=device)
    m = make(config["model"])
    if any(k.startswith("module") for k in state.keys()):
        m = nn.DataParallel(m)
    m = m.to(device)
    m_state = m.state_dict()
    m.load_state_dict(
        {k: v for k, v in state.items() if k in m_state},
        strict=False
    )
    m.eval()
    print("✅ Model loaded.")
    return m
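# Load the model once at import time so every Gradio request reuses it.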
model = load_model(CONFIG_PATH)
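# Static colorbar legends shown next to the depth maps; OpenCV loads BGR,
# so convert to RGB for Gradio display.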
COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")
colorbar_100m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_color.png"))
colorbar_100m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_gray.png"))
colorbar_10m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_color.png"))
colorbar_10m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_gray.png"))
if colorbar_100m_color is not None:
    colorbar_100m_color = cv2.cvtColor(colorbar_100m_color, cv2.COLOR_BGR2RGB)
if colorbar_100m_gray is not None:
    colorbar_100m_gray = cv2.cvtColor(colorbar_100m_gray, cv2.COLOR_BGR2RGB)
if colorbar_10m_color is not None:
    colorbar_10m_color = cv2.cvtColor(colorbar_10m_color, cv2.COLOR_BGR2RGB)
if colorbar_10m_gray is not None:
    colorbar_10m_gray = cv2.cvtColor(colorbar_10m_gray, cv2.COLOR_BGR2RGB)
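# Run the network on an RGB image (HWC uint8) and return the raw normalized
# depth map. If the model also returns `pred_mask`, pixels flagged by it are
# pushed to 1 (the far end of the normalized range) before returning.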
@gpu_decorator
def infer_raw(img_rgb: np.ndarray):
    if img_rgb is None:
        return None
    import torch

    img = img_rgb.astype(np.float32) / 255.0
    tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)

    with torch.inference_mode():
        outputs = model(tensor)
        if isinstance(outputs, dict) and "pred_depth" in outputs:
            if "pred_mask" in outputs:
                # In-place edits of inference tensors must happen while still
                # inside inference_mode.
                mask = 1 - outputs["pred_mask"]
                mask = mask > 0.5
                outputs["pred_depth"][~mask] = 1
            pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
        else:
            pred = outputs[0].cpu().squeeze().numpy()

    return pred.astype(np.float32)
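# 100 m visualization: treat the normalized prediction as the full 0-100 m
# range, clip to [0, 1], and render gray + Spectral color maps; the unclipped
# depth is also saved to a .npy file for download.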
def visualize_100m(pred: np.ndarray):
    if pred is None:
        return None, None, None, None, None
    pred_clip = np.clip(pred, 0.0, 1.0)
    depth_gray = (pred_clip * 255).astype(np.uint8)
    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
    npy_path = "/tmp/depth_100m.npy"
    np.save(npy_path, pred)
    return depth_color, depth_gray, npy_path, colorbar_100m_color, colorbar_100m_gray
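# 10 m visualization: clip to the near 10% of the range (0-0.1) and rescale so
# indoor scenes get more contrast; the raw depth is still saved unclipped.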
def visualize_10m(pred: np.ndarray):
    if pred is None:
        return None, None, None, None, None
    pred_clip = np.clip(pred, 0.0, 0.1)
    depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
    npy_path = "/tmp/depth_10m.npy"
    np.save(npy_path, pred)
    return depth_color, depth_gray, npy_path, colorbar_10m_color, colorbar_10m_gray
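# Default action for the "Run Inference" button: run inference and show the
# 100 m visualization; the raw prediction is returned first so it can be kept
# in gr.State for later re-visualization.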
@gpu_decorator
def infer_and_vis_100m(img_rgb: np.ndarray):
    pred = infer_raw(img_rgb)
    color, gray, npy, cbar_color, cbar_gray = visualize_100m(pred)
    return pred, color, gray, npy, cbar_color, cbar_gray
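# Example images bundled with the Space: regular captures plus Gemini-generated ones.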
example_paths = [
    "hfdemo/01.jpg",
    "hfdemo/02.jpg",
    "hfdemo/03.jpg",
    "hfdemo/04.jpg",
    "hfdemo/05.jpg",
    "hfdemo/06.jpg",
    "hfdemo/07.jpg",
    "hfdemo/08.jpg",
    "hfdemo/09.jpg",
    "hfdemo/10.jpg",
    "hfdemo/11.jpg",
]
example_gen_paths = [
    "hfdemo/generated_00.jpg",
    "hfdemo/generated_01.jpg",
    "hfdemo/generated_02.jpg",
    "hfdemo/generated_03.jpg",
    "hfdemo/generated_04.jpg",
    "hfdemo/generated_05.jpg",
    "hfdemo/generated_06.jpg",
    "hfdemo/generated_07.jpg",
]
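# -------- Gradio UI --------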
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🌀 DAP Depth Prediction Demo
        Here are our resources:
        - 💻 **Code**: [https://github.com/Insta360-Research-Team/DAP](https://github.com/Insta360-Research-Team/DAP)
        - 🌐 **Web Page**: [https://insta360-research-team.github.io/DAP_website/](https://insta360-research-team.github.io/DAP_website/)
        - 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DAP-weights](https://huggingface.co/Insta360-Research/DAP-weights)
        """
    )
    gr.Markdown("# Official Depth Prediction demo for **[DAP](https://insta360-research-team.github.io/DAP_website/)**")

    raw_depth = gr.State()

    with gr.Row():
        with gr.Column(scale=10):
            inp = gr.Image(
                type="numpy",
                label="Input Image",
                height=360
            )
            gr.Markdown("### Examples (click to load)")
            gr.Examples(examples=example_paths, inputs=inp)
            gr.Markdown("### Examples from Gemini (click to load)")
            gr.Examples(examples=example_gen_paths, inputs=inp)

            btn_infer = gr.Button("Run Inference", variant="primary")
            btn_100m = gr.Button("Visualize (100m)")
            btn_10m = gr.Button("Visualize (10m)")
            gr.Markdown(
                """
                <small>
                <b>Visualization range:</b><br>
                • <b>100m</b>: recommended for <b>outdoor</b> scenes<br>
                • <b>10m</b>: recommended for <b>indoor</b> scenes<br>
                (Only affects visualization, not the raw depth output)
                </small>
                """,
                elem_id="vis_hint",
            )

        with gr.Column(scale=11):
            # -------- Row 1: Color Depth --------
            with gr.Row():
                with gr.Column(scale=10):
                    out_color = gr.Image(
                        label="Depth (Color)",
                        height=260
                    )
                with gr.Column(scale=1, min_width=80):
                    colorbar_color = gr.Image(
                        label="Scale",
                        height=260,
                        show_label=False
                    )
            with gr.Row():
                with gr.Column(scale=10):
                    out_gray = gr.Image(
                        label="Depth (Gray)",
                        height=260
                    )
                with gr.Column(scale=1, min_width=80):
                    colorbar_gray = gr.Image(
                        label="Scale",
                        height=260,
                        show_label=False
                    )
            out_npy = gr.File(label="Depth (.npy)")

    btn_infer.click(
        fn=infer_and_vis_100m,
        inputs=inp,
        outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
    btn_100m.click(
        fn=visualize_100m,
        inputs=raw_depth,
        outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
    btn_10m.click(
        fn=visualize_10m,
        inputs=raw_depth,
        outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
    )
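# Queue with a single concurrent worker to avoid GPU contention; HOST/PORT can
# be overridden through environment variables.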
if __name__ == "__main__":
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    demo.queue(
        max_size=32,
        default_concurrency_limit=1,
    ).launch(
        server_name=host,
        server_port=port,
        ssr_mode=False,
        show_error=True,
    )