Insta360-Research commited on
Commit
300d330
·
verified ·
1 Parent(s): a910daf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -13
app.py CHANGED
@@ -80,6 +80,23 @@ def load_model(config_path: str):
80
  # ================== 启动时加载一次模型 ==================
81
  model = load_model(CONFIG_PATH)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # ================== 推理函数 ==================
84
  @gpu_decorator
85
  def infer_raw(img_rgb: np.ndarray):
@@ -109,7 +126,7 @@ def infer_raw(img_rgb: np.ndarray):
109
 
110
  def visualize_100m(pred: np.ndarray):
111
  if pred is None:
112
- return None, None, None
113
 
114
  pred_clip = np.clip(pred, 0.0, 1.0)
115
  depth_gray = (pred_clip * 255).astype(np.uint8)
@@ -118,11 +135,11 @@ def visualize_100m(pred: np.ndarray):
118
  npy_path = "/tmp/depth_100m.npy"
119
  np.save(npy_path, pred)
120
 
121
- return depth_color, depth_gray, npy_path
122
 
123
  def visualize_10m(pred: np.ndarray):
124
  if pred is None:
125
- return None, None, None
126
 
127
  pred_clip = np.clip(pred, 0.0, 0.1)
128
  depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
@@ -131,13 +148,13 @@ def visualize_10m(pred: np.ndarray):
131
  npy_path = "/tmp/depth_10m.npy"
132
  np.save(npy_path, pred)
133
 
134
- return depth_color, depth_gray, npy_path
135
 
136
  @gpu_decorator
137
  def infer_and_vis_100m(img_rgb: np.ndarray):
138
  pred = infer_raw(img_rgb) # 跑模型一次(GPU)
139
- color, gray, npy = visualize_100m(pred) # 默认100m显示(CPU)
140
- return pred, color, gray, npy
141
 
142
  # ================== Gradio UI ==================
143
  example_paths = [
@@ -210,29 +227,33 @@ with gr.Blocks() as demo:
210
 
211
  # ========== Right ==========
212
  with gr.Column(scale=2):
213
- out_color = gr.Image(label="Depth (Color)", height=260)
214
- out_gray = gr.Image(label="Depth (Gray)", height=260)
 
 
 
 
215
  out_npy = gr.File(label="Depth (.npy)")
216
 
217
  # 1️⃣ 跑模型
218
  btn_infer.click(
219
  fn=infer_and_vis_100m,
220
  inputs=inp,
221
- outputs=[raw_depth, out_color, out_gray, out_npy],
222
  )
223
 
224
  # 2️⃣ 100m
225
  btn_100m.click(
226
  fn=visualize_100m,
227
  inputs=raw_depth,
228
- outputs=[out_color, out_gray, out_npy],
229
  )
230
 
231
  # 3️⃣ 10m
232
  btn_10m.click(
233
  fn=visualize_10m,
234
  inputs=raw_depth,
235
- outputs=[out_color, out_gray, out_npy],
236
  )
237
 
238
 
@@ -248,5 +269,4 @@ if __name__ == "__main__":
248
  server_port=port,
249
  ssr_mode=False,
250
  show_error=True,
251
- )
252
-
 
80
  # ================== 启动时加载一次模型 ==================
81
  model = load_model(CONFIG_PATH)
82
 
83
+ # ================== 加载标度尺图片 ==================
84
+ COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")
85
+ colorbar_100m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_color.png"))
86
+ colorbar_100m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_gray.png"))
87
+ colorbar_10m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_color.png"))
88
+ colorbar_10m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_gray.png"))
89
+
90
+ # 转换为RGB(Gradio需要RGB格式)
91
+ if colorbar_100m_color is not None:
92
+ colorbar_100m_color = cv2.cvtColor(colorbar_100m_color, cv2.COLOR_BGR2RGB)
93
+ if colorbar_100m_gray is not None:
94
+ colorbar_100m_gray = cv2.cvtColor(colorbar_100m_gray, cv2.COLOR_BGR2RGB)
95
+ if colorbar_10m_color is not None:
96
+ colorbar_10m_color = cv2.cvtColor(colorbar_10m_color, cv2.COLOR_BGR2RGB)
97
+ if colorbar_10m_gray is not None:
98
+ colorbar_10m_gray = cv2.cvtColor(colorbar_10m_gray, cv2.COLOR_BGR2RGB)
99
+
100
  # ================== 推理函数 ==================
101
  @gpu_decorator
102
  def infer_raw(img_rgb: np.ndarray):
 
126
 
127
  def visualize_100m(pred: np.ndarray):
128
  if pred is None:
129
+ return None, None, None, None, None
130
 
131
  pred_clip = np.clip(pred, 0.0, 1.0)
132
  depth_gray = (pred_clip * 255).astype(np.uint8)
 
135
  npy_path = "/tmp/depth_100m.npy"
136
  np.save(npy_path, pred)
137
 
138
+ return depth_color, depth_gray, npy_path, colorbar_100m_color, colorbar_100m_gray
139
 
140
  def visualize_10m(pred: np.ndarray):
141
  if pred is None:
142
+ return None, None, None, None, None
143
 
144
  pred_clip = np.clip(pred, 0.0, 0.1)
145
  depth_gray = (pred_clip * 10 * 255).astype(np.uint8)
 
148
  npy_path = "/tmp/depth_10m.npy"
149
  np.save(npy_path, pred)
150
 
151
+ return depth_color, depth_gray, npy_path, colorbar_10m_color, colorbar_10m_gray
152
 
153
  @gpu_decorator
154
  def infer_and_vis_100m(img_rgb: np.ndarray):
155
  pred = infer_raw(img_rgb) # 跑模型一次(GPU)
156
+ color, gray, npy, cbar_color, cbar_gray = visualize_100m(pred) # 默认100m显示(CPU)
157
+ return pred, color, gray, npy, cbar_color, cbar_gray
158
 
159
  # ================== Gradio UI ==================
160
  example_paths = [
 
227
 
228
  # ========== Right ==========
229
  with gr.Column(scale=2):
230
+ with gr.Row():
231
+ out_color = gr.Image(label="Depth (Color)", height=260, scale=5)
232
+ colorbar_color = gr.Image(label="Scale", height=260, scale=1, show_label=False)
233
+ with gr.Row():
234
+ out_gray = gr.Image(label="Depth (Gray)", height=260, scale=5)
235
+ colorbar_gray = gr.Image(label="Scale", height=260, scale=1, show_label=False)
236
  out_npy = gr.File(label="Depth (.npy)")
237
 
238
  # 1️⃣ 跑模型
239
  btn_infer.click(
240
  fn=infer_and_vis_100m,
241
  inputs=inp,
242
+ outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
243
  )
244
 
245
  # 2️⃣ 100m
246
  btn_100m.click(
247
  fn=visualize_100m,
248
  inputs=raw_depth,
249
+ outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
250
  )
251
 
252
  # 3️⃣ 10m
253
  btn_10m.click(
254
  fn=visualize_10m,
255
  inputs=raw_depth,
256
+ outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
257
  )
258
 
259
 
 
269
  server_port=port,
270
  ssr_mode=False,
271
  show_error=True,
272
+ )