File size: 2,896 Bytes
e8de836
 
 
 
 
 
4b238e9
 
8180bcc
4b238e9
4f5ff11
 
 
 
e8de836
a5f49f2
 
ac1ce99
a5f49f2
 
 
4f5ff11
a5f49f2
4f5ff11
 
 
 
 
ac1ce99
4f5ff11
 
 
051beda
ac1ce99
051beda
ac1ce99
e8de836
 
ac1ce99
e8de836
 
 
 
 
 
 
 
 
 
 
 
dd17bac
ac1ce99
e8de836
ac1ce99
e8de836
 
 
 
 
ac1ce99
e8de836
 
 
 
 
ac1ce99
e8de836
 
 
 
 
2421c22
e8de836
 
 
a5f49f2
ac1ce99
 
e8de836
 
 
 
 
 
ac1ce99
e8de836
 
 
ac1ce99
e8de836
ac1ce99
e8de836
 
 
 
 
 
 
 
 
 
 
 
dfc6d2f
e8de836
 
ac1ce99
e8de836
ac1ce99
e8de836
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
from PIL import Image
import numpy as np

from inference import load_model, get_embedding, run
import torch
import os
import spaces


# Module-wide inference state shared by the functions below.
MODEL = None  # replaced with the loaded model by load_model_cpu()
DEVICE: torch.device = torch.device("cpu")  # prepare_cuda() may switch this to CUDA
CUDA_READY: bool = False  # True once prepare_cuda() has moved MODEL to CUDA

def load_model_cpu(checkpoint_path: str) -> None:
    """Load a checkpoint into the module-level ``MODEL`` and keep it on the CPU.

    Sets ``DEVICE`` to CPU, puts the model in eval mode, and clears
    ``CUDA_READY`` so that a later ``prepare_cuda()`` call moves the freshly
    loaded model to the GPU instead of assuming it is already there.

    Args:
        checkpoint_path: Path handed unchanged to ``load_model``.
    """
    global MODEL, DEVICE, CUDA_READY
    # load_model returns a pair; only the first element (the model) is kept.
    model, _ = load_model(checkpoint_path)
    DEVICE = torch.device("cpu")
    MODEL = model.to(DEVICE)
    MODEL.eval()
    # The new model lives on the CPU, so any earlier GPU initialisation no
    # longer applies; without this reset prepare_cuda() would skip it forever.
    CUDA_READY = False

# Load the checkpoint once at import time; request handlers then reuse the
# resulting module-level MODEL.
load_model_cpu(checkpoint_path="medsam_vit_b.pth")

@spaces.GPU
def prepare_cuda():
    """Move the shared model to CUDA the first time a GPU is available.

    Updates the module-level ``DEVICE`` and ``CUDA_READY`` and prints a status
    line on every call. ``MODEL.to("cuda")`` is not rebound, so this relies on
    ``.to`` moving the model in place (true for ``torch.nn.Module``; assumed
    here, confirm against ``load_model``'s return type).

    NOTE(review): on ZeroGPU Spaces, ``@spaces.GPU`` functions may execute in
    a separate worker, in which case the global updates below might not be
    visible to the caller. Verify that inference really runs on the GPU.
    """
    global DEVICE, CUDA_READY
    if not torch.cuda.is_available() or CUDA_READY:
        print("CUDA not available or already initialized.")
        return
    print("CUDA is available. Moving model to GPU...")
    MODEL.to("cuda")
    DEVICE = torch.device("cuda")
    CUDA_READY = True
    # Tiny allocation on the new device, presumably a warm-up to trigger CUDA
    # context creation before the first real request.
    torch.zeros(1, device=DEVICE)
    print("Model moved to CUDA.")

def parse_first_bbox(bboxes):
    """Extract the first box as ``(xmin, ymin, xmax, ymax)`` floats.

    Two entry shapes are accepted:

    * a dict with ``"x"``, ``"y"``, ``"width"`` and ``"height"`` keys
      (one corner plus size);
    * a list/tuple whose first four items are two opposite corners
      ``(x1, y1, x2, y2)``; extra items (e.g. a category label) are ignored.

    The corners are returned ordered, so ``xmin <= xmax`` and
    ``ymin <= ymax`` even for reversed corners or negative sizes.

    Args:
        bboxes: Sequence of box entries; only the first one is used.

    Returns:
        A tuple of four floats, or ``None`` when ``bboxes`` is empty or its
        first entry is not in a recognised shape (wrong type, too few items,
        missing keys, non-numeric values).
    """
    if not bboxes:
        return None
    b = bboxes[0]
    try:
        if isinstance(b, dict):
            x1, y1 = float(b["x"]), float(b["y"])
            x2, y2 = x1 + float(b["width"]), y1 + float(b["height"])
        elif isinstance(b, (list, tuple)) and len(b) >= 4:
            x1, y1, x2, y2 = (float(v) for v in b[:4])
        else:
            return None
    except (KeyError, TypeError, ValueError):
        # Malformed entry: signal failure with None (the documented failure
        # value) instead of letting the exception escape.
        return None
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)

def segment(annot_value):
    """Segment the region inside the first drawn box of the annotated image.

    Args:
        annot_value: Value of the ``BBoxAnnotator`` input, read here as
            ``annot_value[0]`` = image path and ``annot_value[1]`` = list of
            box entries (only the first box is used, via ``parse_first_bbox``).

    Returns:
        ``(mask_image, bbox_text)``, one value per Gradio output component.
        ``mask_image`` is a PIL image built from ``mask * 255`` replicated to
        three channels, and ``bbox_text`` reports the pixel box actually used
        (after ordering and clamping). On invalid input ``mask_image`` is
        ``None`` and ``bbox_text`` explains the problem.
    """
    # Gradio expects exactly one value per output (two here), so every early
    # exit returns a 2-tuple; a bare ``return None,`` would be a 1-tuple.
    if not annot_value or not annot_value[0]:
        return None, "请先上传图片。"
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    if not bboxes:
        return None, "请画一个矩形框。"

    box = parse_first_bbox(bboxes)
    if box is None:
        return None, "解析矩形框失败,请重画。"

    # Use ``with`` so the underlying file handle is released promptly;
    # convert("RGB") guarantees a 3-channel array.
    with Image.open(img_path) as img:
        img_np = np.array(img.convert("RGB"))
    H, W, _ = img_np.shape

    # Order the corners and clamp them to the image so the normalised box
    # below stays within [0, 1024].
    x_a, y_a, x_b, y_b = (int(v) for v in box)
    xmin, xmax = sorted(min(max(v, 0), W) for v in (x_a, x_b))
    ymin, ymax = sorted(min(max(v, 0), H) for v in (y_a, y_b))
    if xmax <= xmin or ymax <= ymin:
        return None, "矩形框为空,请重画。"

    # Only touch the device once there is real work to do.
    # NOTE(review): prepare_cuda is decorated with @spaces.GPU but this
    # function is not; on ZeroGPU, confirm the calls below actually use the GPU.
    prepare_cuda()

    # Rescale pixel coordinates into a 1024-based frame (presumably the
    # model's input resolution; confirm against inference.run).
    box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0

    embedding = get_embedding(MODEL, img_np, DEVICE)
    mask = run(MODEL, embedding, box_1024, H, W)

    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    bbox_text = f"xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}"
    return Image.fromarray(mask_rgb), bbox_text


# Example annotation: (image path, [box, ...]). Box tuples are read by
# parse_first_bbox as (x1, y1, x2, y2, ...) with the label as an extra item.
# The image path is relative, so it resolves against the working directory.
example = ("003_img.png", [(50, 60, 120, 150, "cell")])

# Single-function UI: the annotator value is passed to segment(), and its
# 2-tuple result fills the two output components in order.
demo = gr.Interface(
    fn=segment,
    inputs=BBoxAnnotator(
        # Initial widget value (presumably shown when the page first loads).
        value=example,
        # Label choices offered for drawn boxes.
        categories=["cell", "nucleus"],
        label="upload"
    ),
    outputs=[
        # segment() returns a PIL image, hence type="pil".
        gr.Image(type="pil", label="Mask result"),
        # Human-readable pixel box used for the segmentation.
        gr.Textbox(label="location")
    ],
    examples=[[example]],
    # Do not run segment() ahead of time on the example.
    cache_examples=False
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860, no public share link, show
    # exceptions in the UI, and turn SSR mode off.
    _launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False,
    )
    # Enable the request queue, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch(**_launch_options)