import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
from PIL import Image
import numpy as np
from inference import load_model, get_embedding, run
import torch
import os
import spaces
# Module-level model state shared across requests.
MODEL = None  # segmentation network; populated by load_model_cpu() below
DEVICE = torch.device("cpu")  # device the model currently lives on
CUDA_READY = False  # set True once prepare_cuda() has moved MODEL to the GPU
def load_model_cpu(checkpoint_path: str):
    """Load the checkpoint onto the CPU and publish it via the module globals.

    Sets MODEL to the eval-mode network and DEVICE to the CPU device so the
    app can serve requests before any GPU is attached.
    """
    global MODEL, DEVICE
    net, _ = load_model(checkpoint_path)
    net = net.to("cpu")
    net.eval()
    MODEL = net
    DEVICE = torch.device("cpu")
# Eagerly load the MedSAM ViT-B checkpoint on CPU at import time.
load_model_cpu("medsam_vit_b.pth")
@spaces.GPU
def prepare_cuda():
    """Move the global model to the GPU exactly once, if CUDA is available.

    No-op (with a log line) when CUDA is absent or the move already happened.
    """
    global MODEL, DEVICE, CUDA_READY
    if not torch.cuda.is_available() or CUDA_READY:
        print("CUDA not available or already initialized.")
        return
    print("CUDA is available. Moving model to GPU...")
    MODEL.to("cuda")
    DEVICE = torch.device("cuda")
    CUDA_READY = True
    # Touch the device once so CUDA context creation happens here, not
    # inside the first user request.
    _ = torch.zeros(1, device=DEVICE)
    print("Model moved to CUDA.")
def parse_first_bbox(bboxes):
    """Extract the first bounding box as (xmin, ymin, xmax, ymax) floats.

    Supports two input shapes produced by the annotator component:
      * a dict with "x", "y", "width", "height" keys (x/y = top-left corner),
      * a list/tuple whose first four entries are xmin, ymin, xmax, ymax
        (extra entries such as a category label are ignored).

    Returns None when the list is empty or the first entry cannot be parsed;
    the caller treats None as "ask the user to redraw".
    """
    if not bboxes:
        return None
    b = bboxes[0]
    if isinstance(b, dict):
        try:
            x, y = float(b["x"]), float(b["y"])
            w, h = float(b["width"]), float(b["height"])
        except (KeyError, TypeError, ValueError):
            # Malformed dict (missing keys / non-numeric values): signal a
            # parse failure instead of crashing the request handler.
            return None
        return x, y, x + w, y + h
    if isinstance(b, (list, tuple)) and len(b) >= 4:
        try:
            return float(b[0]), float(b[1]), float(b[2]), float(b[3])
        except (TypeError, ValueError):
            return None
    return None
def segment(annot_value):
    """Segment the region inside the first user-drawn box with MedSAM.

    Parameters
    ----------
    annot_value : tuple | None
        Value from BBoxAnnotator: (image_path, list_of_bboxes).

    Returns
    -------
    (PIL.Image.Image | None, str)
        The binary mask rendered as an RGB image plus a textual description
        of the box, or (None, message) on failure. Every path returns a
        2-tuple so both Gradio output components receive a value.
    """
    prepare_cuda()
    # BUGFIX: the interface declares two outputs, so each return path must
    # yield two values; the original `return None,` produced a 1-tuple,
    # which Gradio rejects when unpacking into the output components.
    if annot_value is None or len(annot_value) < 1:
        return None, "未提供图片。"
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    if not bboxes:
        return None, "未绘制矩形框。"
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)
    H, W, _ = img_np.shape
    box = parse_first_bbox(bboxes)
    if box is None:
        return None, "解析矩形框失败,请重画。"
    xmin, ymin, xmax, ymax = map(int, box)
    # MedSAM expects the prompt box normalized to a 1024x1024 canvas.
    box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
    embedding = get_embedding(MODEL, img_np, DEVICE)
    mask = run(MODEL, embedding, box_1024, H, W)
    # Render the mask as a 3-channel uint8 image for display.
    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    bbox_text = f"xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}"
    return Image.fromarray(mask_rgb), bbox_text
# Pre-seeded annotator value: (image_path, boxes). Each box tuple is
# presumably (xmin, ymin, xmax, ymax, label) — matches parse_first_bbox's
# list/tuple branch; confirm against the BBoxAnnotator component docs.
example = ("003_img.png", [(50, 60, 120, 150, "cell")])
# Build the UI pieces first, then wire them into a single Interface.
annotator = BBoxAnnotator(
    value=example,
    categories=["cell", "nucleus"],
    label="upload",
)
mask_output = gr.Image(type="pil", label="Mask result")
text_output = gr.Textbox(label="location")

demo = gr.Interface(
    fn=segment,
    inputs=annotator,
    outputs=[mask_output, text_output],
    examples=[[example]],
    cache_examples=False,
)
if __name__ == "__main__":
    # Bind to all interfaces so the container-hosted app is reachable.
    launch_opts = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False,
    )
    demo.queue().launch(**launch_opts)