Insta360-Research committed
Commit 07df94d · verified · 1 Parent(s): 72731f5

Update app.py

Files changed (1):
  1. app.py +87 -93
app.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
-import os, sys
+import os
+import sys
 import cv2
 import yaml
 import numpy as np
@@ -9,12 +10,13 @@ from huggingface_hub import hf_hub_download
 
 # ================== `spaces` must be imported as early as possible ==================
 try:
-    import spaces
+    import spaces  # type: ignore
    gpu_decorator = spaces.GPU
 except Exception:
    gpu_decorator = lambda f: f
 
-# ================== Project paths ==================
+# ================== Project paths: make sure `networks` is importable ==================
+# Works no matter which directory `python app.py` is launched from
 PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(PROJECT_ROOT)
 
@@ -23,38 +25,23 @@ from networks.models import make  # noqa: E402
 # ================== HF model config ==================
 WEIGHTS_REPO = "Insta360-Research/DAP-weights"
 WEIGHTS_FILE = "model.pth"
-CONFIG_PATH = "config/infer.yaml"
+CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
 
 model = None
 device = "cpu"
 
-# ================== Adaptive visualization (matches the test script) ==================
+# ================== Fixed color mapping (consistent colors across images) ==================
 import matplotlib
 
-def colorize_depth_adaptive(
-    depth: np.ndarray,
-    cmap: str = "Spectral",
-    depth_truncation: float = 1.0
-) -> np.ndarray:
+def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
     """
-    depth: float32 depth map (H, W)
-    depth_truncation: truncation threshold after normalization (0~1); values beyond it are treated as farthest
+    depth_u8: uint8, 0~255
     return: RGB uint8
     """
-    if depth is None:
-        return None
-
-    dmin = float(np.min(depth))
-    dmax = float(np.max(depth))
-    denom = (dmax - dmin) + 1e-8
-
-    depth_normalized = (depth - dmin) / denom
-    depth_normalized = np.clip(depth_normalized, 0.0, float(depth_truncation))
-    depth_normalized = depth_normalized / (float(depth_truncation) + 1e-8)
-
-    colored = matplotlib.colormaps[cmap](depth_normalized)[..., :3]
-    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
-    return colored
+    disp = depth_u8.astype(np.float32) / 255.0
+    colored = matplotlib.colormaps[cmap](disp)[..., :3]
+    colored = (colored * 255).astype(np.uint8)
+    return np.ascontiguousarray(colored)
 
 # ================== Model loading ==================
 def load_model(config_path: str):
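Note on the change above: `colorize_depth_adaptive` rescaled every image by its own min/max, so the same depth value could get different colors in different images; `colorize_depth_fixed` maps a fixed 0~255 input straight through the colormap. A minimal sketch of the new behavior (not part of the commit; assumes a matplotlib new enough to have the `matplotlib.colormaps` registry, 3.5+):

```python
import numpy as np
import matplotlib

def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    disp = depth_u8.astype(np.float32) / 255.0           # 0~255 -> 0.0~1.0
    colored = matplotlib.colormaps[cmap](disp)[..., :3]  # drop the alpha channel
    return np.ascontiguousarray((colored * 255).astype(np.uint8))

# The same uint8 value always maps to the same color, independent of image content:
ramp = np.tile(np.arange(256, dtype=np.uint8), (8, 1))   # synthetic 8x256 gradient
print(colorize_depth_fixed(ramp).shape)                  # (8, 256, 3)
```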
@@ -68,7 +55,10 @@ def load_model(config_path: str):
        config = yaml.load(f, Loader=yaml.FullLoader)
 
    print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
-    model_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILE)
+    model_path = hf_hub_download(
+        repo_id=WEIGHTS_REPO,
+        filename=WEIGHTS_FILE
+    )
    print(f"✅ Weights downloaded to: {model_path}")
 
    state = torch.load(model_path, map_location=device)
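For context, `hf_hub_download` pulls the file from the Hub into the local cache on the first call and returns its absolute path; warm restarts reuse the cached copy. A standalone sketch with the same repo and filename (the printed path assumes the default cache location):

```python
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="Insta360-Research/DAP-weights",
    filename="model.pth",
)
# e.g. ~/.cache/huggingface/hub/models--Insta360-Research--DAP-weights/snapshots/<revision>/model.pth
print(model_path)
```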
@@ -79,20 +69,18 @@ def load_model(config_path: str):
    m = m.to(device)
    m_state = m.state_dict()
-    m.load_state_dict({k: v for k, v in state.items() if k in m_state}, strict=False)
+    m.load_state_dict(
+        {k: v for k, v in state.items() if k in m_state},
+        strict=False
+    )
    m.eval()
    print("✅ Model loaded.")
-    return m, config
+    return m
 
 # ================== Load the model once at startup ==================
-model, cfg = load_model(CONFIG_PATH)
+model = load_model(CONFIG_PATH)
 
-# ====== Read the inference size from the config (matches the test script defaults) ======
-infer_cfg = cfg.get("inference", {}) if isinstance(cfg, dict) else {}
-INFER_H = int(infer_cfg.get("height", 512))
-INFER_W = int(infer_cfg.get("width", 1024))
-
-# ================== Inference (fix: resize + enforce consistent dict/mask logic) ==================
+# ================== Inference ==================
 @gpu_decorator
 def infer_raw(img_rgb: np.ndarray):
    if img_rgb is None:
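The filtered `load_state_dict` above keeps only checkpoint entries whose keys exist in the model, and `strict=False` tolerates whatever is left over. A toy illustration of the pattern (hypothetical module and checkpoint, not the DAP network):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for the real model

# Pretend checkpoint: two matching keys plus one stale key from another architecture
state = {
    "weight": torch.zeros(2, 4),
    "bias": torch.zeros(2),
    "old_head.weight": torch.ones(3),
}

m_state = net.state_dict()
result = net.load_state_dict({k: v for k, v in state.items() if k in m_state}, strict=False)
print(result)  # no missing or unexpected keys: the stale key was filtered out first
```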
@@ -100,48 +88,34 @@ def infer_raw(img_rgb: np.ndarray):
 
    import torch
 
-    # 1) resize to the fixed input size (matches the test script)
-    img_resized = cv2.resize(img_rgb, (INFER_W, INFER_H), interpolation=cv2.INTER_LINEAR)
-
-    # 2) normalize
-    img = img_resized.astype(np.float32) / 255.0
+    # Keep the original logic: feed the image in directly, without resizing
+    img = img_rgb.astype(np.float32) / 255.0
    tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
 
    with torch.inference_mode():
        outputs = model(tensor)
 
-    # 3) force the dict + pred_mask/pred_depth path, so a fallback can't cause inconsistencies
-    if not (isinstance(outputs, dict) and ("pred_depth" in outputs)):
-        raise RuntimeError(
-            f"Model output format unexpected. Expect dict with key 'pred_depth', got: {type(outputs)}"
-        )
-
-    # mask handling (matches the test script)
-    if "pred_mask" in outputs:
-        outputs["pred_mask"] = 1 - outputs["pred_mask"]
-        outputs["pred_mask"] = outputs["pred_mask"] > 0.5
-        outputs["pred_depth"][~outputs["pred_mask"]] = 1
-
-    pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
+    if isinstance(outputs, dict) and "pred_depth" in outputs:
+        if "pred_mask" in outputs:
+            mask = 1 - outputs["pred_mask"]
+            mask = mask > 0.5
+            outputs["pred_depth"][~mask] = 1
+        pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
+    else:
+        # Keep the original fallback logic
+        pred = outputs[0].cpu().squeeze().numpy()
 
    return pred.astype(np.float32)
 
-# ================== Visualization: switched to adaptive mapping (issue 1) ==================
 def visualize_100m(pred: np.ndarray):
    if pred is None:
        return None, None, None
 
-    # Matches the test script default: depth_truncation = 1.0
-    depth_color = colorize_depth_adaptive(pred, cmap="Spectral", depth_truncation=1.0)
-
-    # Build the gray image the same "adaptive" way so the two views line up
-    dmin = float(np.min(pred))
-    dmax = float(np.max(pred))
-    gray = ((pred - dmin) / ((dmax - dmin) + 1e-8))
-    gray = np.clip(gray, 0.0, 1.0)
-    depth_gray = (gray * 255).astype(np.uint8)
-
-    npy_path = "/tmp/depth.npy"
+    pred_clip = np.clip(pred, 0.0, 1.0)
+    depth_gray = (pred_clip * 255).astype(np.uint8)
+    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
+
+    npy_path = "/tmp/depth_100m.npy"
    np.save(npy_path, pred)
 
    return depth_color, depth_gray, npy_path
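The mask handling in `infer_raw` keeps a pixel only where `1 - pred_mask > 0.5` (i.e. `pred_mask < 0.5`) and forces every other pixel's depth to 1, the far end of the normalized range. A small numeric check of that logic with synthetic tensors:

```python
import torch

pred_depth = torch.tensor([[0.2, 0.4],
                           [0.6, 0.8]])
pred_mask = torch.tensor([[0.9, 0.1],
                          [0.1, 0.9]])  # high value -> pixel gets masked out

mask = (1 - pred_mask) > 0.5   # True where the prediction is kept
pred_depth[~mask] = 1          # masked pixels pushed to the far limit

print(pred_depth)  # tensor([[1.0000, 0.4000], [0.6000, 1.0000]])
```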
@@ -150,24 +124,19 @@ def visualize_10m(pred: np.ndarray):
    if pred is None:
        return None, None, None
 
-    # "Finer up close": effectively a smaller visualization truncation threshold (the test script's depth_truncation idea)
-    depth_color = colorize_depth_adaptive(pred, cmap="Spectral", depth_truncation=0.1)
-
-    dmin = float(np.min(pred))
-    dmax = float(np.max(pred))
-    gray = ((pred - dmin) / ((dmax - dmin) + 1e-8))
-    gray = np.clip(gray, 0.0, 0.1) / 0.1
-    depth_gray = (gray * 255).astype(np.uint8)
+    pred_clip = np.clip(pred, 0.0, 0.1)
+    depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
+    depth_color = colorize_depth_fixed(depth_gray, cmap="Spectral")
 
-    npy_path = "/tmp/depth.npy"
+    npy_path = "/tmp/depth_10m.npy"
    np.save(npy_path, pred)
 
    return depth_color, depth_gray, npy_path
 
 @gpu_decorator
 def infer_and_vis_100m(img_rgb: np.ndarray):
-    pred = infer_raw(img_rgb)                # run the model once (GPU)
-    color, gray, npy = visualize_100m(pred)  # default visualization (CPU)
+    pred = infer_raw(img_rgb)                # run the model once (GPU)
+    color, gray, npy = visualize_100m(pred)  # default 100m view (CPU)
    return pred, color, gray, npy
 
 # ================== Gradio UI ==================
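The two views differ only in how `pred` is quantized to uint8 before coloring: the 100m view clips to [0, 1] and scales by 255, while the 10m view clips to [0, 0.1] and scales by 10 * 255, spending the full palette on the nearest tenth of the range. For example:

```python
import numpy as np

pred = np.linspace(0.0, 1.0, 6).astype(np.float32)  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

gray_100m = (np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)
gray_10m = (np.clip(pred, 0.0, 0.1) * 10 * 255).astype(np.uint8)

print(gray_100m)  # [  0  51 102 153 204 255]
print(gray_10m)   # [  0 255 255 255 255 255] -- everything at or past 0.1 saturates
```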
@@ -191,9 +160,10 @@ example_gen_paths = [
 with gr.Blocks() as demo:
    gr.Markdown("# DAP Depth Prediction Demo")
 
-    raw_depth = gr.State()  # stores the model output
+    raw_depth = gr.State()  # 🔑 stores the model output
 
    with gr.Row():
+        # ========== Left ==========
        with gr.Column(scale=1):
            inp = gr.Image(type="numpy", label="Input Image", height=360)
 
@@ -208,28 +178,52 @@ with gr.Blocks() as demo:
            btn_10m = gr.Button("Visualize (10m)")
 
            gr.Markdown(
-                f"""
+                """
                <small>
-                <b>Inference resize:</b> {INFER_W}×{INFER_H}<br>
-                <b>Visualization:</b><br>
-                • <b>100m</b>: truncation=1.0 (default)<br>
-                • <b>10m</b>: truncation=0.1 (finer detail up close)<br>
+                <b>Visualization range:</b><br>
+                • <b>100m</b>: recommended for <b>outdoor</b> scenes<br>
+                • <b>10m</b>: recommended for <b>indoor</b> scenes<br>
+                (Only affects visualization, not the raw depth output)
                </small>
-                """
+                """,
+                elem_id="vis_hint",
            )
 
+        # ========== Right ==========
        with gr.Column(scale=2):
            out_color = gr.Image(label="Depth (Color)", height=260)
            out_gray = gr.Image(label="Depth (Gray)", height=260)
            out_npy = gr.File(label="Depth (.npy)")
 
-    btn_infer.click(fn=infer_and_vis_100m, inputs=inp, outputs=[raw_depth, out_color, out_gray, out_npy])
-    btn_100m.click(fn=visualize_100m, inputs=raw_depth, outputs=[out_color, out_gray, out_npy])
-    btn_10m.click(fn=visualize_10m, inputs=raw_depth, outputs=[out_color, out_gray, out_npy])
-
-demo.launch(
-    server_name="0.0.0.0",
-    server_port=7860,
-    ssr_mode=False,
-    show_error=True,
-)
+    # 1️⃣ Run the model
+    btn_infer.click(
+        fn=infer_and_vis_100m,
+        inputs=inp,
+        outputs=[raw_depth, out_color, out_gray, out_npy],
+    )
+
+    # 2️⃣ 100m
+    btn_100m.click(
+        fn=visualize_100m,
+        inputs=raw_depth,
+        outputs=[out_color, out_gray, out_npy],
+    )
+
+    # 3️⃣ 10m
+    btn_10m.click(
+        fn=visualize_10m,
+        inputs=raw_depth,
+        outputs=[out_color, out_gray, out_npy],
+    )
+
+if __name__ == "__main__":
+    # For embedding in a web page: control host/port via environment variables
+    host = os.environ.get("HOST", "0.0.0.0")
+    port = int(os.environ.get("PORT", "7860"))
+
+    demo.launch(
+        server_name=host,
+        server_port=port,
+        ssr_mode=False,
+        show_error=True,
+    )
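With the new `__main__` guard, importing `app` from another module no longer auto-launches the server, and the bind address and port come from `HOST`/`PORT` with Spaces-friendly defaults (0.0.0.0:7860). A quick sketch of the same convention, with hypothetical local values:

```python
import os

# e.g. run as: HOST=127.0.0.1 PORT=7861 python app.py
host = os.environ.get("HOST", "0.0.0.0")    # default: listen on all interfaces
port = int(os.environ.get("PORT", "7860"))  # default: the port Hugging Face Spaces expects
print(f"would launch on {host}:{port}")
```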