Shengxiao0709 committed
Commit 1960bb2 · verified · 1 Parent(s): a84ae56

Upload 5 files

Files changed (6)
  1. .gitattributes +1 -0
  2. 003_img.png +3 -0
  3. app.py +96 -0
  4. inference.py +48 -0
  5. medsam_vit_b.pth +3 -0
  6. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 003_img.png filter=lfs diff=lfs merge=lfs -text
003_img.png ADDED

Git LFS Details

  • SHA256: 41515cf5d7405135db4656c2cc61b59ab341143bfbee952b44a9542944e8528f
  • Pointer size: 131 Bytes
  • Size of remote file: 302 kB
app.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ from gradio_bbox_annotator import BBoxAnnotator
+ from PIL import Image
+ import numpy as np
+
+ # Existing inference code (see inference.py)
+ from inference import load_model, get_embedding, run
+
+ # ---- Load the model only once ----
+ model, device = load_model("medsam_vit_b.pth")
+ def predict(value):
+     # value: (image_path, [(xmin, ymin, xmax, ymax, label), ...])
+     return value  # echo the value back unchanged
+
+ def make_example(path):
+     return [path, []]
+
+ def parse_first_bbox(bboxes):
+     """
+     Take the first box from the annotator's bboxes and return (xmin, ymin, xmax, ymax).
+     Two formats are supported:
+     - dict: {"x": .., "y": .., "width": .., "height": ..}
+     - list: [xmin, ymin, xmax, ymax, ...]
+     """
+     if not bboxes:
+         return None
+     b = bboxes[0]
+     if isinstance(b, dict):
+         x, y = float(b["x"]), float(b["y"])
+         w, h = float(b["width"]), float(b["height"])
+         return x, y, x + w, y + h
+     if isinstance(b, (list, tuple)) and len(b) >= 4:
+         return float(b[0]), float(b[1]), float(b[2]), float(b[3])
+     return None
+
+ def segment(annot_value):
+     """
+     annot_value has the form [image_path, bboxes]:
+     - image_path: string
+     - bboxes: list of boxes
+     """
+     if annot_value is None or len(annot_value) < 1:
+         return None, "Please upload an image above and drag out a rectangle first."
+
+     img_path = annot_value[0]
+     bboxes = annot_value[1] if len(annot_value) > 1 else []
+
+     if not bboxes:
+         return None, "No rectangle detected; hold the left mouse button and drag out a box in the annotation area."
+
+     # Read the image
+     img = Image.open(img_path).convert("RGB")
+     img_np = np.array(img)
+     H, W, _ = img_np.shape
+
+     # Take the first box
+     box = parse_first_bbox(bboxes)
+     if box is None:
+         return None, "Failed to parse the rectangle; please redraw it."
+
+     xmin, ymin, xmax, ymax = box
+
+     # Rescale to the 1024 frame and run inference
+     box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
+     box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
+
+     embedding = get_embedding(model, img_np, device)
+     mask = run(model, embedding, box_1024, H, W)  # (H, W) 0/1
+
+     # Black-and-white mask (white = foreground)
+     mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
+     bbox_text = f"xmin={int(xmin)}, ymin={int(ymin)}, xmax={int(xmax)}, ymax={int(ymax)}"
+
+     return Image.fromarray(mask_rgb), bbox_text
+
+ # --- Build a usable example value (so the canvas has an image ready to drag on) ---
+ example = ("003_img.png", [(50, 60, 120, 150, "cell")])
+
+ demo = gr.Interface(
+     fn=segment,  # <- the inference function above
+     inputs=BBoxAnnotator(
+         value=example,  # default example; the component has its own upload button for swapping images
+         categories=["cell", "nucleus"],
+         label="upload"
+     ),
+     outputs=[
+         gr.Image(type="pil", label="Mask result"),
+         gr.Textbox(label="location")
+     ],
+     examples=[[example]],
+     cache_examples=False
+ )
+
+ if __name__ == "__main__":
+     demo.launch(server_port=7860)
+
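The key step in segment() is mapping the drawn box from pixel coordinates into the fixed 1024x1024 frame that MedSAM expects. A minimal sketch of that arithmetic, assuming a hypothetical 512x512 image and reusing the box from the app's example value:

    import numpy as np

    # Hypothetical 512x512 image; box taken from app.py's example value
    W, H = 512, 512
    box_np = np.array([[50, 60, 120, 150]], dtype=float)

    # Same scaling as segment(): divide by [W, H, W, H], multiply by 1024
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
    print(box_1024)  # [[100. 120. 240. 300.]]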
inference.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import numpy as np
+ from skimage import transform
+ from segment_anything import sam_model_registry
+
+ MEDSAM_IMG_INPUT_SIZE = 1024
+
+ def load_model(checkpoint_path):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = sam_model_registry["vit_b"](checkpoint=checkpoint_path).to(device)
+     model.eval()
+     return model, device
+
+ @torch.no_grad()
+ def get_embedding(model, img_np, device):
+     img_1024 = transform.resize(  # bicubic resize to the model's 1024x1024 input
+         img_np, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
+     ).astype(np.uint8)
+
+     img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), 1e-8, None)  # min-max normalize to [0, 1]
+     img_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
+     return model.image_encoder(img_tensor)
+
+ @torch.no_grad()
+ def run(model, embedding, box_1024, H, W):
+     box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=embedding.device)
+     if len(box_torch.shape) == 2:
+         box_torch = box_torch[:, None, :]  # (B, 1, 4)
+
+     sparse_embeddings, dense_embeddings = model.prompt_encoder(
+         points=None,
+         boxes=box_torch,
+         masks=None,
+     )
+     low_res_logits, _ = model.mask_decoder(
+         image_embeddings=embedding,
+         image_pe=model.prompt_encoder.get_dense_pe(),
+         sparse_prompt_embeddings=sparse_embeddings,
+         dense_prompt_embeddings=dense_embeddings,
+         multimask_output=False,
+     )
+
+     low_res_pred = torch.sigmoid(low_res_logits)  # probabilities in [0, 1]
+     low_res_pred = torch.nn.functional.interpolate(
+         low_res_pred, size=(H, W), mode="bilinear", align_corners=False  # upsample to the original image size
+     )
+     low_res_pred = low_res_pred.squeeze().cpu().numpy()
+     return (low_res_pred > 0.5).astype(np.uint8)
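Taken together, the three functions form a load-once, embed-once, prompt-with-boxes pipeline: the expensive image encoder runs once per image, and run() can then be called cheaply for each new box. A minimal end-to-end sketch using the files in this commit (the box coordinates are made up for illustration):

    import numpy as np
    from PIL import Image
    from inference import load_model, get_embedding, run

    model, device = load_model("medsam_vit_b.pth")
    img_np = np.array(Image.open("003_img.png").convert("RGB"))
    H, W, _ = img_np.shape

    # Hypothetical box in pixel coordinates, rescaled to the 1024 frame
    box_1024 = np.array([[50, 60, 120, 150]], dtype=float) / np.array([W, H, W, H]) * 1024.0

    embedding = get_embedding(model, img_np, device)  # computed once per image
    mask = run(model, embedding, box_1024, H, W)      # (H, W) uint8, values 0/1
    Image.fromarray(mask * 255).save("mask.png")      # white = foreground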
medsam_vit_b.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34b34b78c1d18cb8c6bf84cf9c00e135d6d6c965699f3c0e31ef1bc9dcb5be74
+ size 375049145
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ # Packages required by app.py / inference.py (names inferred from the imports)
+ gradio
+ gradio_bbox_annotator
+ segment-anything
+ torch
+ numpy
+ Pillow
+ scikit-image
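With the dependencies installed (pip install -r requirements.txt), running python app.py starts the demo, which demo.launch(server_port=7860) serves on port 7860.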