Shengxiao0709 committed
Commit ac1ce99 · verified · 1 Parent(s): a4fce5b

Update app.py

Files changed (1)
  1. app.py +16 -45
app.py CHANGED
@@ -3,37 +3,19 @@ from gradio_bbox_annotator import BBoxAnnotator
 from PIL import Image
 import numpy as np
 
-# your existing inference code
 from inference import load_model, get_embedding, run
 import torch
 import os
 import spaces
 
-# @spaces.GPU
-# def warmup():
-#     import torch
-#     if torch.cuda.is_available():
-#         torch.zeros(1).to("cuda")
 
-
-# # ---- load the model only once ----
-# model, device = load_model("medsam_vit_b.pth")
-# def predict(value):
-#     # value: (image_path, [(xmin, ymin, xmax, ymax, label), ...])
-#     return value  # echo back directly
-
-# def make_example(path):
-#     return [path, []]
-
-# --------- global state ---------
 MODEL = None
 DEVICE = torch.device("cpu")
 CUDA_READY = False
 
 def load_model_cpu(checkpoint_path: str):
     global MODEL, DEVICE
-    # require that inference.load_model does not call .to("cuda") internally
-    MODEL, _ = load_model(checkpoint_path)  # or just return model directly
+    MODEL, _ = load_model(checkpoint_path)
     MODEL = MODEL.to("cpu")
     MODEL.eval()
     DEVICE = torch.device("cpu")
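Note: the checkpoint is deliberately loaded on CPU at import time and only moved to CUDA later in prepare_cuda(). On a ZeroGPU Space the inference entry point is usually also wrapped with the spaces.GPU decorator so a GPU is attached only for the duration of the call; a minimal sketch of that wrapper, assuming the handler keeps the name segment (not part of this commit):

import spaces

@spaces.GPU
def segment(annot_value):
    # A GPU exists only inside this call; move the CPU-loaded model over now.
    prepare_cuda()
    ...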
@@ -44,22 +26,17 @@ load_model_cpu("medsam_vit_b.pth")
 def prepare_cuda():
     global MODEL, DEVICE, CUDA_READY
     if torch.cuda.is_available() and not CUDA_READY:
-        print("🎯 CUDA is available. Moving model to GPU...")
+        print("CUDA is available. Moving model to GPU...")
         MODEL.to("cuda")
         DEVICE = torch.device("cuda")
         CUDA_READY = True
         _ = torch.zeros(1, device=DEVICE)
-        print("Model moved to CUDA.")
+        print("Model moved to CUDA.")
     else:
-        print("CUDA not available or already initialized.")
+        print("CUDA not available or already initialized.")
 
 def parse_first_bbox(bboxes):
-    """
-    Take the first box from the annotator's bboxes and return (xmin, ymin, xmax, ymax).
-    Compatible with two formats:
-    - dict: {"x":..,"y":..,"width":..,"height":..}
-    - list: [xmin, ymin, xmax, ymax, ...]
-    """
+
     if not bboxes:
         return None
     b = bboxes[0]
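The removed docstring documents that the annotator can hand back the first box either as a dict with x/y/width/height keys or as a flat [xmin, ymin, xmax, ymax, ...] list. The rest of parse_first_bbox sits outside this hunk; a minimal sketch of a body covering both formats, based only on that docstring (an assumption, not the committed code):

def parse_first_bbox(bboxes):
    if not bboxes:
        return None
    b = bboxes[0]
    if isinstance(b, dict):
        # {"x": .., "y": .., "width": .., "height": ..} -> corner coordinates
        return b["x"], b["y"], b["x"] + b["width"], b["y"] + b["height"]
    # [xmin, ymin, xmax, ymax, (label)] -> already corner coordinates
    return b[0], b[1], b[2], b[3]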
@@ -73,52 +50,46 @@ def parse_first_bbox(bboxes):
 
 def segment(annot_value):
     prepare_cuda()
-    """
-    annot_value looks like [image_path, bboxes]
-    - image_path: a string
-    - bboxes: list of boxes
-    """
+
     if annot_value is None or len(annot_value) < 1:
-        return None, "Please upload an image above and drag out a rectangle first."
+        return None,
 
     img_path = annot_value[0]
     bboxes = annot_value[1] if len(annot_value) > 1 else []
 
     if not bboxes:
-        return None, "No rectangle detected; hold the left mouse button and drag out a box in the annotation area."
+        return None,
 
-    # read the image
     img = Image.open(img_path).convert("RGB")
     img_np = np.array(img)
     H, W, _ = img_np.shape
 
-    # take the first box
+
     box = parse_first_bbox(bboxes)
     if box is None:
         return None, "Failed to parse the rectangle; please draw it again."
 
     xmin, ymin, xmax, ymax = box
     xmin, ymin, xmax, ymax = map(int, [xmin, ymin, xmax, ymax])
-    # normalize to 1024 and run inference
     box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
     box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
 
     embedding = get_embedding(MODEL, img_np, DEVICE)
-    mask = run(MODEL, embedding, box_1024, H, W)  # (H, W) 0/1
+    mask = run(MODEL, embedding, box_1024, H, W)
+
 
-    # black-and-white mask (white = foreground)
     mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
     bbox_text = f"xmin={int(xmin)}, ymin={int(ymin)}, xmax={int(xmax)}, ymax={int(ymax)}"
 
     return Image.fromarray(mask_rgb), bbox_text
 
-# --- build a usable example value (so the canvas has an image to drag on right away) ---
+
 example = ("003_img.png", [(50, 60, 120, 150, "cell")])
 
 demo = gr.Interface(
-    fn=segment,   # ← call your inference function
+    fn=segment,
     inputs=BBoxAnnotator(
-        value=example,   # default example; the component has its own upload button for swapping the image
+        value=example,
         categories=["cell", "nucleus"],
         label="upload"
     ),
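The box is rescaled from pixel coordinates into the 1024×1024 input space the SAM-style model expects before run() is called. A quick worked example, assuming a 512×512 image and the example box used above:

import numpy as np

W, H = 512, 512                                       # assumed image size
box_np = np.array([[50, 60, 120, 150]], dtype=float)  # xmin, ymin, xmax, ymax
box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
print(box_1024)                                       # [[100. 120. 240. 300.]]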
@@ -134,9 +105,9 @@ if __name__ == "__main__":
     demo.queue().launch(
         server_name="0.0.0.0",
         server_port=7860,
-        share=False,    # no public link needed; HF maps the port automatically
+        share=False,
         show_error=True,
-        ssr_mode=False  # disable SSR (a common culprit for the crash)
+        ssr_mode=False
     )
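The outputs side of gr.Interface and the __main__ guard are untouched by this commit and therefore not shown in the hunks above. Given that segment returns a PIL image and a coordinate string, the wiring presumably resembles the sketch below (an assumption, not visible in the diff):

demo = gr.Interface(
    fn=segment,
    inputs=BBoxAnnotator(value=example, categories=["cell", "nucleus"], label="upload"),
    outputs=[gr.Image(label="mask"), gr.Textbox(label="bbox")],  # assumed: matches segment's two return values
)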