# segment / app.py
# Shengxiao0709's picture — Update app.py (ac1ce99 verified)
import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
from PIL import Image
import numpy as np
from inference import load_model, get_embedding, run
import torch
import os
import spaces
# Module-level state shared by load_model_cpu / prepare_cuda / segment.
MODEL = None  # segmentation model; populated by load_model_cpu at import time
DEVICE = torch.device("cpu")  # device the model currently lives on
CUDA_READY = False  # True once prepare_cuda has moved the model to the GPU
def load_model_cpu(checkpoint_path: str):
    """Load the segmentation checkpoint and cache it in the module globals.

    Leaves MODEL on the CPU in eval mode and resets DEVICE to cpu;
    prepare_cuda() may later migrate it to the GPU.
    """
    global MODEL, DEVICE
    model, _ = load_model(checkpoint_path)
    MODEL = model.to("cpu")
    MODEL.eval()
    DEVICE = torch.device("cpu")
# Load the checkpoint on CPU at import time; prepare_cuda() upgrades later.
load_model_cpu("medsam_vit_b.pth")
@spaces.GPU
def prepare_cuda():
    """One-time migration of the cached model to the GPU.

    Does nothing (beyond a log line) when CUDA is unavailable or the
    migration already happened.
    """
    global MODEL, DEVICE, CUDA_READY
    # Guard clause: skip unless CUDA exists and we have not moved yet.
    if not torch.cuda.is_available() or CUDA_READY:
        print("CUDA not available or already initialized.")
        return
    print("CUDA is available. Moving model to GPU...")
    MODEL.to("cuda")
    DEVICE = torch.device("cuda")
    CUDA_READY = True
    _ = torch.zeros(1, device=DEVICE)  # touch the device to force initialization
    print("Model moved to CUDA.")
def parse_first_bbox(bboxes):
    """Extract the first bounding box as (xmin, ymin, xmax, ymax) floats.

    Accepts either a dict with "x"/"y"/"width"/"height" keys or a
    sequence whose first four items are already corner coordinates.
    Returns None when the list is empty/None or the entry is unusable.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    if isinstance(first, dict):
        left = float(first["x"])
        top = float(first["y"])
        width = float(first["width"])
        height = float(first["height"])
        return left, top, left + width, top + height
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        xmin, ymin, xmax, ymax = (float(v) for v in first[:4])
        return xmin, ymin, xmax, ymax
    return None
def segment(annot_value):
    """Run box-prompted segmentation on the annotated image.

    Parameters
    ----------
    annot_value : tuple | None
        BBoxAnnotator value: (image_path, [bbox, ...]).

    Returns
    -------
    (PIL.Image.Image | None, str)
        The mask rendered as an RGB image plus a textual description of
        the box, or (None, message) when input is missing or unparsable.
    """
    prepare_cuda()
    # BUG FIX: the original early exits returned the 1-tuple `(None,)`,
    # which does not match the Interface's two declared outputs
    # (Image, Textbox). Always return a (mask, text) pair.
    if annot_value is None or len(annot_value) < 1:
        return None, "No image provided."
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    if not bboxes:
        return None, "No bounding box drawn."
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)
    H, W, _ = img_np.shape
    box = parse_first_bbox(bboxes)
    if box is None:
        return None, "解析矩形框失败,请重画。"
    xmin, ymin, xmax, ymax = map(int, box)
    box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
    # Normalize pixel coordinates to the 1024x1024 canvas the model expects
    # (presumably MedSAM's fixed input resolution — TODO confirm in inference.py).
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
    embedding = get_embedding(MODEL, img_np, DEVICE)
    mask = run(MODEL, embedding, box_1024, H, W)
    # Replicate the single-channel mask into 3 channels for RGB display.
    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    bbox_text = f"xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}"
    return Image.fromarray(mask_rgb), bbox_text
# Pre-filled annotator value: image path plus one labeled box
# (xmin, ymin, xmax, ymax, label).
example = ("003_img.png", [(50, 60, 120, 150, "cell")])

# Wire the annotator input to the segment() callback; the two outputs
# match segment()'s (mask image, location text) return pair.
demo = gr.Interface(
    fn=segment,
    inputs=BBoxAnnotator(
        value=example,
        categories=["cell", "nucleus"],
        label="upload"
    ),
    outputs=[
        gr.Image(type="pil", label="Mask result"),
        gr.Textbox(label="location")
    ],
    examples=[[example]],
    cache_examples=False
)
if __name__ == "__main__":
    # Serve through a request queue, listening on all interfaces
    # (7860 is the conventional Hugging Face Spaces port).
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False
    )