Insta360-Research commited on
Commit
daec925
·
verified ·
1 Parent(s): bd33a1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -36
app.py CHANGED
@@ -3,33 +3,47 @@ from __future__ import absolute_import, division, print_function
3
  import os, sys
4
  import cv2
5
  import yaml
6
- import torch
7
  import numpy as np
8
- import torch.nn as nn
9
  import gradio as gr
10
  from huggingface_hub import hf_hub_download
11
 
 
 
 
 
 
 
12
  # ========== 让 Space 能 import 你的工程 ==========
13
- PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) # app.py 在仓库根目录
14
  sys.path.append(PROJECT_ROOT)
15
 
16
- from networks.models import make
17
-
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # ====== HF 权重仓库配置(你已经上传成功)======
21
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
22
  WEIGHTS_FILE = "model.pth"
 
 
 
 
 
 
23
 
24
- # ========== 可视化 ==========
25
  def colorize_depth(depth, colormap=cv2.COLORMAP_JET):
26
  depth = depth.astype(np.float32)
27
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
28
  depth_u8 = (depth_norm * 255).astype(np.uint8)
29
  return cv2.applyColorMap(depth_u8, colormap) # BGR
30
 
31
- # ========== 加载模型(只加载一次) ==========
32
  def load_model(config_path: str):
 
 
 
 
 
 
 
33
  with open(config_path, "r") as f:
34
  config = yaml.load(f, Loader=yaml.FullLoader)
35
 
@@ -39,46 +53,42 @@ def load_model(config_path: str):
39
 
40
  state = torch.load(model_path, map_location=device)
41
 
42
- model = make(config["model"])
43
  if any(k.startswith("module") for k in state.keys()):
44
- model = nn.DataParallel(model)
45
-
46
- model = model.to(device)
47
 
48
- model_state = model.state_dict()
49
- model.load_state_dict({k: v for k, v in state.items() if k in model_state}, strict=False)
50
- model.eval()
 
51
  print("✅ Model loaded.")
52
- return model
53
 
54
- # 这里改成你仓库里的 config 路径
55
- CONFIG_PATH = "config/infer.yaml"
56
  model = load_model(CONFIG_PATH)
57
 
58
- # ========== 单张图推理 ==========
59
- @torch.no_grad()
60
  def predict(img_rgb: np.ndarray):
61
- """
62
- img_rgb: H x W x 3 (RGB), uint8
63
- return: depth_color_rgb, depth_gray
64
- """
65
  if img_rgb is None:
66
  return None, None
67
 
 
 
68
  img = img_rgb.astype(np.float32) / 255.0
69
  tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
70
 
71
- outputs = model(tensor)
 
72
 
73
- if isinstance(outputs, dict) and "pred_depth" in outputs:
74
- # 你原来的 mask 逻辑
75
- if "pred_mask" in outputs:
76
- outputs["pred_mask"] = 1 - outputs["pred_mask"]
77
- outputs["pred_mask"] = (outputs["pred_mask"] > 0.5)
78
- outputs["pred_depth"][~outputs["pred_mask"]] = 1
79
- pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
80
- else:
81
- pred = outputs[0].detach().cpu().squeeze().numpy()
82
 
83
  pred_clip = np.clip(pred, 0.001, 1.0)
84
  depth_gray = (pred_clip * 255).astype(np.uint8)
@@ -88,6 +98,7 @@ def predict(img_rgb: np.ndarray):
88
 
89
  return depth_color_rgb, depth_gray
90
 
 
91
  demo = gr.Interface(
92
  fn=predict,
93
  inputs=gr.Image(type="numpy", label="Input Image"),
@@ -96,7 +107,7 @@ demo = gr.Interface(
96
  gr.Image(type="numpy", label="Depth (Gray)"),
97
  ],
98
  title="DAP Depth Prediction Demo",
99
- description="Upload an image and get depth prediction."
100
  )
101
 
102
  demo.launch(
 
3
  import os, sys
4
  import cv2
5
  import yaml
 
6
  import numpy as np
 
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # ✅ 必须最早 import spaces(在 torch / 任何 CUDA 初始化之前)
11
+ try:
12
+ import spaces # noqa: F401
13
+ except Exception:
14
+ spaces = None # 不影响本地跑
15
+
16
  # ========== 让 Space 能 import 你的工程 ==========
17
+ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
18
  sys.path.append(PROJECT_ROOT)
19
 
20
+ from networks.models import make # noqa: E402
 
 
21
 
22
+ # ====== HF 权重仓库配置 ======
23
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
24
  WEIGHTS_FILE = "model.pth"
25
+ CONFIG_PATH = "config/infer.yaml"
26
+
27
+ # 先定义全局占位
28
+ model = None
29
+ device = "cpu"
30
+
31
 
 
32
  def colorize_depth(depth, colormap=cv2.COLORMAP_JET):
33
  depth = depth.astype(np.float32)
34
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
35
  depth_u8 = (depth_norm * 255).astype(np.uint8)
36
  return cv2.applyColorMap(depth_u8, colormap) # BGR
37
 
38
+
39
  def load_model(config_path: str):
40
+ # ✅ torch 放到这里 import,避免在 spaces import 之前触发 CUDA
41
+ import torch
42
+ import torch.nn as nn
43
+
44
+ global device
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+
47
  with open(config_path, "r") as f:
48
  config = yaml.load(f, Loader=yaml.FullLoader)
49
 
 
53
 
54
  state = torch.load(model_path, map_location=device)
55
 
56
+ m = make(config["model"])
57
  if any(k.startswith("module") for k in state.keys()):
58
+ m = nn.DataParallel(m)
 
 
59
 
60
+ m = m.to(device)
61
+ m_state = m.state_dict()
62
+ m.load_state_dict({k: v for k, v in state.items() if k in m_state}, strict=False)
63
+ m.eval()
64
  print("✅ Model loaded.")
65
+ return m
66
 
67
+
68
+ # 启动时加载一次模型
69
  model = load_model(CONFIG_PATH)
70
 
71
+
 
72
  def predict(img_rgb: np.ndarray):
 
 
 
 
73
  if img_rgb is None:
74
  return None, None
75
 
76
+ import torch # 这里用到 torch,再 import 一次没关系
77
+
78
  img = img_rgb.astype(np.float32) / 255.0
79
  tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
80
 
81
+ with torch.no_grad():
82
+ outputs = model(tensor)
83
 
84
+ if isinstance(outputs, dict) and "pred_depth" in outputs:
85
+ if "pred_mask" in outputs:
86
+ outputs["pred_mask"] = 1 - outputs["pred_mask"]
87
+ outputs["pred_mask"] = (outputs["pred_mask"] > 0.5)
88
+ outputs["pred_depth"][~outputs["pred_mask"]] = 1
89
+ pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
90
+ else:
91
+ pred = outputs[0].detach().cpu().squeeze().numpy()
 
92
 
93
  pred_clip = np.clip(pred, 0.001, 1.0)
94
  depth_gray = (pred_clip * 255).astype(np.uint8)
 
98
 
99
  return depth_color_rgb, depth_gray
100
 
101
+
102
  demo = gr.Interface(
103
  fn=predict,
104
  inputs=gr.Image(type="numpy", label="Input Image"),
 
107
  gr.Image(type="numpy", label="Depth (Gray)"),
108
  ],
109
  title="DAP Depth Prediction Demo",
110
+ description="Upload an image and get depth prediction.",
111
  )
112
 
113
  demo.launch(