Insta360-Research commited on
Commit
77df1b9
·
verified ·
1 Parent(s): 116a4b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -42
app.py CHANGED
@@ -7,16 +7,21 @@ import numpy as np
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
 
10
- # 必须最早 import spaces(在 torch / 任何 CUDA 初始化之前)
11
- import spaces # 在 HF Spaces 一定存在
12
-
13
- import matplotlib # 用你的 colormap
14
-
 
 
 
 
15
  PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
16
  sys.path.append(PROJECT_ROOT)
17
 
18
  from networks.models import make # noqa: E402
19
 
 
20
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
21
  WEIGHTS_FILE = "model.pth"
22
  CONFIG_PATH = "config/infer.yaml"
@@ -24,21 +29,20 @@ CONFIG_PATH = "config/infer.yaml"
24
  model = None
25
  device = "cpu"
26
 
 
 
27
 
 
 
 
 
 
 
 
 
 
28
 
29
- def colorize_depth_matplotlib(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
30
- if mask is None:
31
- depth = np.where(depth > 0, depth, np.nan)
32
- else:
33
- depth = np.where((depth > 0) & mask, depth, np.nan)
34
-
35
- disp = depth / 255.0
36
-
37
- colored = np.nan_to_num(matplotlib.colormaps[cmap](disp)[..., :3], 0)
38
- colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
39
- return colored
40
-
41
-
42
  def load_model(config_path: str):
43
  import torch
44
  import torch.nn as nn
@@ -50,7 +54,10 @@ def load_model(config_path: str):
50
  config = yaml.load(f, Loader=yaml.FullLoader)
51
 
52
  print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
53
- model_path = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILE)
 
 
 
54
  print(f"✅ Weights downloaded to: {model_path}")
55
 
56
  state = torch.load(model_path, map_location=device)
@@ -61,55 +68,53 @@ def load_model(config_path: str):
61
 
62
  m = m.to(device)
63
  m_state = m.state_dict()
64
- m.load_state_dict({k: v for k, v in state.items() if k in m_state}, strict=False)
 
 
 
65
  m.eval()
66
  print("✅ Model loaded.")
67
  return m
68
 
69
-
70
  model = load_model(CONFIG_PATH)
71
 
72
-
73
- @spaces.GPU
74
  def predict(img_rgb: np.ndarray):
75
  if img_rgb is None:
76
  return None, None
77
 
78
  import torch
79
 
 
80
  img = img_rgb.astype(np.float32) / 255.0
81
- tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
 
 
82
 
83
  with torch.inference_mode():
84
  outputs = model(tensor)
85
 
86
  if isinstance(outputs, dict) and "pred_depth" in outputs:
87
  if "pred_mask" in outputs:
88
- pm = 1 - outputs["pred_mask"]
89
- pm = (pm > 0.5)
90
- outputs["pred_depth"][~pm] = 1
91
- pred = outputs["pred_depth"][0].detach().cpu().squeeze().numpy()
92
  else:
93
- pred = outputs[0].detach().cpu().squeeze().numpy()
94
 
95
- # 灰度图:如果你 pred 本来就在 0~1,就直接 *255;否则先归一化
96
  pred = pred.astype(np.float32)
97
- pred_clip = np.clip(pred, 1e-6, np.nanmax(pred) if np.isfinite(pred).any() else 1.0)
98
-
99
- # 让灰度输出稳定:用分位数做一次归一化
100
- lo = np.nanquantile(pred_clip, 0.001)
101
- hi = np.nanquantile(pred_clip, 0.99)
102
- pred_norm = (pred_clip - lo) / (hi - lo + 1e-6)
103
- pred_norm = np.clip(pred_norm, 0.0, 1.0)
104
 
105
- depth_gray = (pred_norm * 255).astype(np.uint8)
106
-
107
- # 彩色图:用你改进的可视化
108
- depth_color_rgb = colorize_depth_matplotlib(pred_norm, normalize=False, cmap="Spectral")
109
 
110
  return depth_color_rgb, depth_gray
111
 
112
-
113
  demo = gr.Interface(
114
  fn=predict,
115
  inputs=gr.Image(type="numpy", label="Input Image"),
 
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # ================== 必须最早 import spaces ==================
11
+ try:
12
+ import spaces
13
+ gpu_decorator = spaces.GPU
14
+ except Exception:
15
+ # 本地环境没有 spaces 时兜底
16
+ gpu_decorator = lambda f: f
17
+
18
+ # ================== 工程路径 ==================
19
  PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
20
  sys.path.append(PROJECT_ROOT)
21
 
22
  from networks.models import make # noqa: E402
23
 
24
+ # ================== HF 模型配置 ==================
25
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
26
  WEIGHTS_FILE = "model.pth"
27
  CONFIG_PATH = "config/infer.yaml"
 
29
  model = None
30
  device = "cpu"
31
 
32
+ # ================== 固定颜色映射(颜色一致) ==================
33
+ import matplotlib
34
 
35
+ def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
36
+ """
37
+ depth_u8: uint8, 0~255
38
+ return: RGB uint8, 颜色全局一致
39
+ """
40
+ disp = depth_u8.astype(np.float32) / 255.0 # 固定映射
41
+ colored = matplotlib.colormaps[cmap](disp)[..., :3] # RGB float
42
+ colored = (colored * 255).astype(np.uint8)
43
+ return np.ascontiguousarray(colored)
44
 
45
+ # ================== 模型加载 ==================
 
 
 
 
 
 
 
 
 
 
 
 
46
  def load_model(config_path: str):
47
  import torch
48
  import torch.nn as nn
 
54
  config = yaml.load(f, Loader=yaml.FullLoader)
55
 
56
  print(f"Downloading weights from HF: {WEIGHTS_REPO}/{WEIGHTS_FILE}")
57
+ model_path = hf_hub_download(
58
+ repo_id=WEIGHTS_REPO,
59
+ filename=WEIGHTS_FILE
60
+ )
61
  print(f"✅ Weights downloaded to: {model_path}")
62
 
63
  state = torch.load(model_path, map_location=device)
 
68
 
69
  m = m.to(device)
70
  m_state = m.state_dict()
71
+ m.load_state_dict(
72
+ {k: v for k, v in state.items() if k in m_state},
73
+ strict=False
74
+ )
75
  m.eval()
76
  print("✅ Model loaded.")
77
  return m
78
 
79
+ # ================== 启动时加载一次模型 ==================
80
  model = load_model(CONFIG_PATH)
81
 
82
+ # ================== 推理函数(ZeroGPU 必须) ==================
83
+ @gpu_decorator
84
  def predict(img_rgb: np.ndarray):
85
  if img_rgb is None:
86
  return None, None
87
 
88
  import torch
89
 
90
+ # 输入处理
91
  img = img_rgb.astype(np.float32) / 255.0
92
+ tensor = torch.from_numpy(
93
+ img.transpose(2, 0, 1)
94
+ ).unsqueeze(0).to(device)
95
 
96
  with torch.inference_mode():
97
  outputs = model(tensor)
98
 
99
  if isinstance(outputs, dict) and "pred_depth" in outputs:
100
  if "pred_mask" in outputs:
101
+ mask = 1 - outputs["pred_mask"]
102
+ mask = mask > 0.5
103
+ outputs["pred_depth"][~mask] = 1
104
+ pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
105
  else:
106
+ pred = outputs[0].cpu().squeeze().numpy()
107
 
108
+ # ================== 固定尺度(假设模型输出 0~1 ==================
109
  pred = pred.astype(np.float32)
110
+ pred_clip = np.clip(pred, 0.0, 1.0)
 
 
 
 
 
 
111
 
112
+ depth_gray = (pred_clip * 255).astype(np.uint8)
113
+ depth_color_rgb = colorize_depth_fixed(depth_gray, cmap="Spectral")
 
 
114
 
115
  return depth_color_rgb, depth_gray
116
 
117
+ # ================== Gradio UI ==================
118
  demo = gr.Interface(
119
  fn=predict,
120
  inputs=gr.Image(type="numpy", label="Input Image"),