Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from __future__ import absolute_import, division, print_function
|
| 2 |
|
| 3 |
-
import os
|
|
|
|
| 4 |
import cv2
|
| 5 |
import yaml
|
| 6 |
import numpy as np
|
|
@@ -9,12 +10,13 @@ from huggingface_hub import hf_hub_download
|
|
| 9 |
|
| 10 |
# ================== 必须最早 import spaces ==================
|
| 11 |
try:
|
| 12 |
-
import spaces
|
| 13 |
gpu_decorator = spaces.GPU
|
| 14 |
except Exception:
|
| 15 |
gpu_decorator = lambda f: f
|
| 16 |
|
| 17 |
-
# ==================
|
|
|
|
| 18 |
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
sys.path.append(PROJECT_ROOT)
|
| 20 |
|
|
@@ -23,38 +25,23 @@ from networks.models import make # noqa: E402
|
|
| 23 |
# ================== HF 模型配置 ==================
|
| 24 |
WEIGHTS_REPO = "Insta360-Research/DAP-weights"
|
| 25 |
WEIGHTS_FILE = "model.pth"
|
| 26 |
-
CONFIG_PATH = "config
|
| 27 |
|
| 28 |
model = None
|
| 29 |
device = "cpu"
|
| 30 |
|
| 31 |
-
# ==================
|
| 32 |
import matplotlib
|
| 33 |
|
| 34 |
-
def
|
| 35 |
-
depth: np.ndarray,
|
| 36 |
-
cmap: str = "Spectral",
|
| 37 |
-
depth_truncation: float = 1.0
|
| 38 |
-
) -> np.ndarray:
|
| 39 |
"""
|
| 40 |
-
|
| 41 |
-
depth_truncation: 归一化后的截断阈值(0~1),超过的都视为最远
|
| 42 |
return: RGB uint8
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
dmax = float(np.max(depth))
|
| 49 |
-
denom = (dmax - dmin) + 1e-8
|
| 50 |
-
|
| 51 |
-
depth_normalized = (depth - dmin) / denom
|
| 52 |
-
depth_normalized = np.clip(depth_normalized, 0.0, float(depth_truncation))
|
| 53 |
-
depth_normalized = depth_normalized / (float(depth_truncation) + 1e-8)
|
| 54 |
-
|
| 55 |
-
colored = matplotlib.colormaps[cmap](depth_normalized)[..., :3]
|
| 56 |
-
colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
|
| 57 |
-
return colored
|
| 58 |
|
| 59 |
# ================== 模型加载 ==================
|
| 60 |
def load_model(config_path: str):
|
|
@@ -68,7 +55,10 @@ def load_model(config_path: str):
|
|
| 68 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 69 |
|
| 70 |
print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
|
| 71 |
-
model_path = hf_hub_download(
|
|
|
|
|
|
|
|
|
|
| 72 |
print(f"✅ Weights downloaded to: {model_path}")
|
| 73 |
|
| 74 |
state = torch.load(model_path, map_location=device)
|
|
@@ -79,20 +69,18 @@ def load_model(config_path: str):
|
|
| 79 |
|
| 80 |
m = m.to(device)
|
| 81 |
m_state = m.state_dict()
|
| 82 |
-
m.load_state_dict(
|
|
|
|
|
|
|
|
|
|
| 83 |
m.eval()
|
| 84 |
print("✅ Model loaded.")
|
| 85 |
-
return m
|
| 86 |
|
| 87 |
# ================== 启动时加载一次模型 ==================
|
| 88 |
-
model
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
infer_cfg = cfg.get("inference", {}) if isinstance(cfg, dict) else {}
|
| 92 |
-
INFER_H = int(infer_cfg.get("height", 512))
|
| 93 |
-
INFER_W = int(infer_cfg.get("width", 1024))
|
| 94 |
-
|
| 95 |
-
# ================== 推理函数(修复:resize + 强制 dict/mask 逻辑一致) ==================
|
| 96 |
@gpu_decorator
|
| 97 |
def infer_raw(img_rgb: np.ndarray):
|
| 98 |
if img_rgb is None:
|
|
@@ -100,48 +88,34 @@ def infer_raw(img_rgb: np.ndarray):
|
|
| 100 |
|
| 101 |
import torch
|
| 102 |
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
# 2) normalize
|
| 107 |
-
img = img_resized.astype(np.float32) / 255.0
|
| 108 |
tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
|
| 109 |
|
| 110 |
with torch.inference_mode():
|
| 111 |
outputs = model(tensor)
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
outputs["pred_mask"] = outputs["pred_mask"] > 0.5
|
| 123 |
-
outputs["pred_depth"][~outputs["pred_mask"]] = 1
|
| 124 |
-
|
| 125 |
-
pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
|
| 126 |
|
| 127 |
return pred.astype(np.float32)
|
| 128 |
|
| 129 |
-
# ================== 可视化:改为自适应映射(问题1) ==================
|
| 130 |
def visualize_100m(pred: np.ndarray):
|
| 131 |
if pred is None:
|
| 132 |
return None, None, None
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
dmin = float(np.min(pred))
|
| 139 |
-
dmax = float(np.max(pred))
|
| 140 |
-
gray = ((pred - dmin) / ((dmax - dmin) + 1e-8))
|
| 141 |
-
gray = np.clip(gray, 0.0, 1.0)
|
| 142 |
-
depth_gray = (gray * 255).astype(np.uint8)
|
| 143 |
-
|
| 144 |
-
npy_path = "/tmp/depth.npy"
|
| 145 |
np.save(npy_path, pred)
|
| 146 |
|
| 147 |
return depth_color, depth_gray, npy_path
|
|
@@ -150,24 +124,19 @@ def visualize_10m(pred: np.ndarray):
|
|
| 150 |
if pred is None:
|
| 151 |
return None, None, None
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
dmin = float(np.min(pred))
|
| 157 |
-
dmax = float(np.max(pred))
|
| 158 |
-
gray = ((pred - dmin) / ((dmax - dmin) + 1e-8))
|
| 159 |
-
gray = np.clip(gray, 0.0, 0.1) / 0.1
|
| 160 |
-
depth_gray = (gray * 255).astype(np.uint8)
|
| 161 |
|
| 162 |
-
npy_path = "/tmp/
|
| 163 |
np.save(npy_path, pred)
|
| 164 |
|
| 165 |
return depth_color, depth_gray, npy_path
|
| 166 |
|
| 167 |
@gpu_decorator
|
| 168 |
def infer_and_vis_100m(img_rgb: np.ndarray):
|
| 169 |
-
pred = infer_raw(img_rgb)
|
| 170 |
-
color, gray, npy = visualize_100m(pred)
|
| 171 |
return pred, color, gray, npy
|
| 172 |
|
| 173 |
# ================== Gradio UI ==================
|
|
@@ -191,9 +160,10 @@ example_gen_paths = [
|
|
| 191 |
with gr.Blocks() as demo:
|
| 192 |
gr.Markdown("# DAP Depth Prediction Demo")
|
| 193 |
|
| 194 |
-
raw_depth = gr.State() # 保存模型输出
|
| 195 |
|
| 196 |
with gr.Row():
|
|
|
|
| 197 |
with gr.Column(scale=1):
|
| 198 |
inp = gr.Image(type="numpy", label="Input Image", height=360)
|
| 199 |
|
|
@@ -208,28 +178,52 @@ with gr.Blocks() as demo:
|
|
| 208 |
btn_10m = gr.Button("Visualize (10m)")
|
| 209 |
|
| 210 |
gr.Markdown(
|
| 211 |
-
|
| 212 |
<small>
|
| 213 |
-
<b>
|
| 214 |
-
<b>
|
| 215 |
-
• <b>
|
| 216 |
-
|
| 217 |
</small>
|
| 218 |
-
"""
|
|
|
|
| 219 |
)
|
| 220 |
|
|
|
|
| 221 |
with gr.Column(scale=2):
|
| 222 |
out_color = gr.Image(label="Depth (Color)", height=260)
|
| 223 |
out_gray = gr.Image(label="Depth (Gray)", height=260)
|
| 224 |
out_npy = gr.File(label="Depth (.npy)")
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import absolute_import, division, print_function
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
import cv2
|
| 6 |
import yaml
|
| 7 |
import numpy as np
|
|
|
|
| 10 |
|
| 11 |
# ================== 必须最早 import spaces ==================
|
| 12 |
try:
    import spaces  # type: ignore

    gpu_decorator = spaces.GPU
except Exception:
    # Best-effort fallback: if `spaces` cannot be imported (or has no `GPU`
    # attribute), decorate with a function that hands the callable back as-is.
    def gpu_decorator(f):
        """Return *f* unchanged; stand-in used when ``spaces.GPU`` is unavailable."""
        return f
|
| 17 |
|
| 18 |
+
# ================== 工程路径:确保能 import networks ==================
|
| 19 |
+
# 适配:无论你从哪里启动 python app.py,都能找到项目根目录
|
| 20 |
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 21 |
sys.path.append(PROJECT_ROOT)
|
| 22 |
|
|
|
|
| 25 |
# ================== HF 模型配置 ==================
|
| 26 |
WEIGHTS_REPO = "Insta360-Research/DAP-weights"
|
| 27 |
WEIGHTS_FILE = "model.pth"
|
| 28 |
+
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
|
| 29 |
|
| 30 |
model = None
|
| 31 |
device = "cpu"
|
| 32 |
|
| 33 |
+
# ================== 固定颜色映射(颜色一致) ==================
|
| 34 |
import matplotlib
|
| 35 |
|
| 36 |
+
def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    """Map an 8-bit depth image to RGB through a matplotlib colormap.

    The input is scaled by a constant 1/255 rather than by its own min/max,
    so a given input value is always rendered with the same colour.

    Args:
        depth_u8: Depth values; the original author documents them as uint8
            in the range 0..255.
        cmap: Key looked up in ``matplotlib.colormaps``.

    Returns:
        A C-contiguous uint8 array holding the first three colour channels
        of the colormap output.
    """
    disp = depth_u8.astype(np.float32) / 255.0
    # The colormap output carries a trailing channel axis; keep only the
    # first three entries (RGB) and drop whatever follows.
    colored = matplotlib.colormaps[cmap](disp)[..., :3]
    colored = (colored * 255).astype(np.uint8)
    return np.ascontiguousarray(colored)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# ================== 模型加载 ==================
|
| 47 |
def load_model(config_path: str):
|
|
|
|
| 55 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 56 |
|
| 57 |
print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
|
| 58 |
+
model_path = hf_hub_download(
|
| 59 |
+
repo_id=WEIGHTS_REPO,
|
| 60 |
+
filename=WEIGHTS_FILE
|
| 61 |
+
)
|
| 62 |
print(f"✅ Weights downloaded to: {model_path}")
|
| 63 |
|
| 64 |
state = torch.load(model_path, map_location=device)
|
|
|
|
| 69 |
|
| 70 |
m = m.to(device)
|
| 71 |
m_state = m.state_dict()
|
| 72 |
+
m.load_state_dict(
|
| 73 |
+
{k: v for k, v in state.items() if k in m_state},
|
| 74 |
+
strict=False
|
| 75 |
+
)
|
| 76 |
m.eval()
|
| 77 |
print("✅ Model loaded.")
|
| 78 |
+
return m
|
| 79 |
|
| 80 |
# ================== 启动时加载一次模型 ==================
|
| 81 |
+
model = load_model(CONFIG_PATH)
|
| 82 |
|
| 83 |
+
# ================== 推理函数 ==================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
@gpu_decorator
|
| 85 |
def infer_raw(img_rgb: np.ndarray):
|
| 86 |
if img_rgb is None:
|
|
|
|
| 88 |
|
| 89 |
import torch
|
| 90 |
|
| 91 |
+
# 保持你原逻辑:不 resize,直接喂入
|
| 92 |
+
img = img_rgb.astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
|
|
| 93 |
tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
|
| 94 |
|
| 95 |
with torch.inference_mode():
|
| 96 |
outputs = model(tensor)
|
| 97 |
|
| 98 |
+
if isinstance(outputs, dict) and "pred_depth" in outputs:
|
| 99 |
+
if "pred_mask" in outputs:
|
| 100 |
+
mask = 1 - outputs["pred_mask"]
|
| 101 |
+
mask = mask > 0.5
|
| 102 |
+
outputs["pred_depth"][~mask] = 1
|
| 103 |
+
pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
|
| 104 |
+
else:
|
| 105 |
+
# 保持你原逻辑的 fallback
|
| 106 |
+
pred = outputs[0].cpu().squeeze().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
return pred.astype(np.float32)
|
| 109 |
|
|
|
|
| 110 |
def visualize_100m(pred: np.ndarray):
    """Build display images from a depth array and save the raw values to disk.

    Args:
        pred: Depth array, or ``None`` when there is nothing to visualise.
            Presumably normalised so that 1.0 corresponds to 100 m -- TODO
            confirm against the model's output convention.

    Returns:
        A tuple ``(depth_color, depth_gray, npy_path)``: the RGB image from
        ``colorize_depth_fixed``, the uint8 grayscale image, and the path of
        the saved ``.npy`` file. Returns ``(None, None, None)`` when ``pred``
        is ``None``.
    """
    if pred is None:
        return None, None, None

    # Only the displayed images are clipped to [0, 1]; the array written to
    # disk below is the unclipped `pred`.
    pred_clip = np.clip(pred, 0.0, 1.0)
    depth_gray = (pred_clip * 255).astype(np.uint8)
    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")

    # NOTE(review): every call writes to the same fixed path, so overlapping
    # requests could overwrite each other's file -- confirm this is acceptable.
    npy_path = "/tmp/depth_100m.npy"
    np.save(npy_path, pred)

    return depth_color, depth_gray, npy_path
|
|
|
|
| 124 |
if pred is None:
|
| 125 |
return None, None, None
|
| 126 |
|
| 127 |
+
pred_clip = np.clip(pred, 0.0, 0.1)
|
| 128 |
+
depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
|
| 129 |
+
depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
npy_path = "/tmp/depth_10m.npy"
|
| 132 |
np.save(npy_path, pred)
|
| 133 |
|
| 134 |
return depth_color, depth_gray, npy_path
|
| 135 |
|
| 136 |
@gpu_decorator
def infer_and_vis_100m(img_rgb: np.ndarray):
    """Run inference once and produce the default (100 m) visualisation.

    Args:
        img_rgb: Input image array passed straight to ``infer_raw``.

    Returns:
        A tuple ``(pred, color, gray, npy)``: the value returned by
        ``infer_raw`` followed by the three outputs of ``visualize_100m``.
    """
    pred = infer_raw(img_rgb)  # run the model once (GPU)
    color, gray, npy = visualize_100m(pred)  # default 100 m display (CPU)
    return pred, color, gray, npy
|
| 141 |
|
| 142 |
# ================== Gradio UI ==================
|
|
|
|
| 160 |
with gr.Blocks() as demo:
|
| 161 |
gr.Markdown("# DAP Depth Prediction Demo")
|
| 162 |
|
| 163 |
+
raw_depth = gr.State() # 🔑 保存模型输出
|
| 164 |
|
| 165 |
with gr.Row():
|
| 166 |
+
# ========== Left ==========
|
| 167 |
with gr.Column(scale=1):
|
| 168 |
inp = gr.Image(type="numpy", label="Input Image", height=360)
|
| 169 |
|
|
|
|
| 178 |
btn_10m = gr.Button("Visualize (10m)")
|
| 179 |
|
| 180 |
gr.Markdown(
|
| 181 |
+
"""
|
| 182 |
<small>
|
| 183 |
+
<b>Visualization range:</b><br>
|
| 184 |
+
• <b>100m</b>: recommended for <b>outdoor</b> scenes<br>
|
| 185 |
+
• <b>10m</b>: recommended for <b>indoor</b> scenes<br>
|
| 186 |
+
(Only affects visualization, not the raw depth output)
|
| 187 |
</small>
|
| 188 |
+
""",
|
| 189 |
+
elem_id="vis_hint",
|
| 190 |
)
|
| 191 |
|
| 192 |
+
# ========== Right ==========
|
| 193 |
with gr.Column(scale=2):
|
| 194 |
out_color = gr.Image(label="Depth (Color)", height=260)
|
| 195 |
out_gray = gr.Image(label="Depth (Gray)", height=260)
|
| 196 |
out_npy = gr.File(label="Depth (.npy)")
|
| 197 |
|
| 198 |
+
# 1️⃣ 跑模型
|
| 199 |
+
btn_infer.click(
|
| 200 |
+
fn=infer_and_vis_100m,
|
| 201 |
+
inputs=inp,
|
| 202 |
+
outputs=[raw_depth, out_color, out_gray, out_npy],
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# 2️⃣ 100m
|
| 206 |
+
btn_100m.click(
|
| 207 |
+
fn=visualize_100m,
|
| 208 |
+
inputs=raw_depth,
|
| 209 |
+
outputs=[out_color, out_gray, out_npy],
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# 3️⃣ 10m
|
| 213 |
+
btn_10m.click(
|
| 214 |
+
fn=visualize_10m,
|
| 215 |
+
inputs=raw_depth,
|
| 216 |
+
outputs=[out_color, out_gray, out_npy],
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
    # For embedding in a web page: host and port are read from the HOST and
    # PORT environment variables, with defaults 0.0.0.0 and 7860.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))

    demo.launch(
        server_name=host,
        server_port=port,
        ssr_mode=False,
        show_error=True,
    )
|