Dingning commited on
Commit
cf8bfe4
·
verified ·
1 Parent(s): a801398

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -27,10 +27,10 @@ css = """
27
  }
28
  """
29
 
30
- # ====== 设备选择 ======
31
- DEVICE = "cuda" else "cpu"
32
 
33
- # ====== 模型加载 ======
34
  model = Bridge()
35
  filepath = hf_hub_download(repo_id=f"Dingning/BRIDGE", filename=f"bridge.pth", repo_type="model")
36
  state_dict = torch.load(filepath, map_location="cpu")
@@ -39,7 +39,7 @@ state_dict = torch.load(filepath, map_location="cpu")
39
  model.load_state_dict(state_dict)
40
  model = model.to(DEVICE).eval()
41
 
42
- # ====== 文本描述 ======
43
  title = "# Bridge Simplified Demo"
44
  description = """
45
  Official demo for Bridge using Gradio.
@@ -49,7 +49,7 @@ Official demo for Bridge using Gradio.
49
 
50
  cmap = matplotlib.colormaps.get_cmap("Spectral_r")
51
 
52
- # ====== 推理函数 ======
53
  @spaces.GPU
54
  def predict_depth(image: np.ndarray) -> np.ndarray:
55
  """Run depth inference on an RGB image (numpy)."""
@@ -59,24 +59,24 @@ def on_submit(image: np.ndarray):
59
  original_image = image.copy()
60
  depth = predict_depth(image)
61
 
62
- # 保存 16-bit 原始深度图
63
  raw_depth = Image.fromarray(depth.astype("uint16"))
64
  tmp_raw_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
65
  raw_depth.save(tmp_raw_depth.name)
66
 
67
- # 归一化 + 着色
68
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
69
  depth_uint8 = depth_norm.astype(np.uint8)
70
  colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)
71
 
72
- # 保存灰度图
73
  gray_depth = Image.fromarray(depth_uint8)
74
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
75
  gray_depth.save(tmp_gray_depth.name)
76
 
77
  return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
78
 
79
- # ====== Gradio 界面 ======
80
  with gr.Blocks(css=css) as demo:
81
  gr.Markdown(title)
82
  gr.Markdown(description)
@@ -99,7 +99,7 @@ with gr.Blocks(css=css) as demo:
99
  outputs=[depth_image_slider, gray_depth_file, raw_file]
100
  )
101
 
102
- # 加载示例图片
103
  if os.path.exists("assets/examples"):
104
  example_files = sorted(os.listdir("assets/examples"))
105
  example_files = [os.path.join("assets/examples", f) for f in example_files]
 
27
  }
28
  """
29
 
30
+ # ====== device ====== 
31
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
32
 
33
+ # ====== model load ======
34
  model = Bridge()
35
  filepath = hf_hub_download(repo_id=f"Dingning/BRIDGE", filename=f"bridge.pth", repo_type="model")
36
  state_dict = torch.load(filepath, map_location="cpu")
 
39
  model.load_state_dict(state_dict)
40
  model = model.to(DEVICE).eval()
41
 
42
+ # ====== description ======
43
  title = "# Bridge Simplified Demo"
44
  description = """
45
  Official demo for Bridge using Gradio.
 
49
 
50
  cmap = matplotlib.colormaps.get_cmap("Spectral_r")
51
 
52
+ # ====== inference ======
53
  @spaces.GPU
54
  def predict_depth(image: np.ndarray) -> np.ndarray:
55
  """Run depth inference on an RGB image (numpy)."""
 
59
  original_image = image.copy()
60
  depth = predict_depth(image)
61
 
62
+ # 16-bit depth map
63
  raw_depth = Image.fromarray(depth.astype("uint16"))
64
  tmp_raw_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
65
  raw_depth.save(tmp_raw_depth.name)
66
 
67
+ # normalization and colorize
68
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
69
  depth_uint8 = depth_norm.astype(np.uint8)
70
  colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)
71
 
72
+ # save depth map
73
  gray_depth = Image.fromarray(depth_uint8)
74
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
75
  gray_depth.save(tmp_gray_depth.name)
76
 
77
  return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
78
 
79
+ # ====== Gradio UI======
80
  with gr.Blocks(css=css) as demo:
81
  gr.Markdown(title)
82
  gr.Markdown(description)
 
99
  outputs=[depth_image_slider, gray_depth_file, raw_file]
100
  )
101
 
102
+ # examples
103
  if os.path.exists("assets/examples"):
104
  example_files = sorted(os.listdir("assets/examples"))
105
  example_files = [os.path.join("assets/examples", f) for f in example_files]