import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
from PIL import Image
import numpy as np
from inference import load_model, get_embedding, run
import torch
import os
import spaces

# Model/device globals. The model is loaded on CPU at import time and moved
# to CUDA lazily inside a @spaces.GPU call (the ZeroGPU pattern: GPUs are
# only attached while a decorated function is running).
MODEL = None
DEVICE = torch.device("cpu")
CUDA_READY = False


def load_model_cpu(checkpoint_path: str) -> None:
    """Load the MedSAM checkpoint onto the CPU and put it in eval mode.

    Sets the module-level MODEL and DEVICE globals as a side effect.
    """
    global MODEL, DEVICE
    MODEL, _ = load_model(checkpoint_path)
    MODEL = MODEL.to("cpu")
    MODEL.eval()
    DEVICE = torch.device("cpu")


load_model_cpu("medsam_vit_b.pth")


@spaces.GPU
def prepare_cuda() -> None:
    """Move the model to CUDA once, if a GPU is available.

    Idempotent: guarded by CUDA_READY so repeated calls after the first
    successful move are no-ops. Decorated with @spaces.GPU so that, on
    HF Spaces ZeroGPU, a GPU is attached for the duration of the call.
    """
    global MODEL, DEVICE, CUDA_READY
    if torch.cuda.is_available() and not CUDA_READY:
        print("CUDA is available. Moving model to GPU...")
        MODEL.to("cuda")
        DEVICE = torch.device("cuda")
        CUDA_READY = True
        # Touch the device once to force CUDA context initialization now
        # rather than during the first user request.
        _ = torch.zeros(1, device=DEVICE)
        print("Model moved to CUDA.")
    else:
        print("CUDA not available or already initialized.")


def parse_first_bbox(bboxes):
    """Return (xmin, ymin, xmax, ymax) floats for the first box, or None.

    Accepts either dict-style boxes ({"x", "y", "width", "height"}, i.e.
    corner + size) or sequence-style boxes ((xmin, ymin, xmax, ymax, ...),
    extra trailing entries such as a label are ignored). Returns None when
    the list is empty or the first entry has an unrecognized shape.
    """
    if not bboxes:
        return None
    b = bboxes[0]
    if isinstance(b, dict):
        x, y = float(b["x"]), float(b["y"])
        w, h = float(b["width"]), float(b["height"])
        # Convert (x, y, w, h) to corner coordinates.
        return x, y, x + w, y + h
    if isinstance(b, (list, tuple)) and len(b) >= 4:
        return float(b[0]), float(b[1]), float(b[2]), float(b[3])
    return None


def segment(annot_value):
    """Run MedSAM on the first user-drawn box.

    Parameters
    ----------
    annot_value : the BBoxAnnotator value, a (image_path, bboxes) pair.

    Returns
    -------
    (PIL.Image | None, str | None)
        The binary mask rendered as an RGB image, and a textual description
        of the prompt box. Two values are always returned because the
        Interface declares two output components.
    """
    prepare_cuda()
    # BUGFIX: the original returned `None,` (a 1-tuple) on the early-exit
    # paths, which cannot be unpacked into the two declared outputs.
    if annot_value is None or len(annot_value) < 1:
        return None, None
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    if not bboxes:
        return None, None
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)
    H, W, _ = img_np.shape
    box = parse_first_bbox(bboxes)
    if box is None:
        return None, "解析矩形框失败,请重画。"
    xmin, ymin, xmax, ymax = map(int, box)
    # MedSAM expects the prompt box in the 1024x1024 resized coordinate
    # frame, so rescale from original pixel coordinates.
    box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
    embedding = get_embedding(MODEL, img_np, DEVICE)
    mask = run(MODEL, embedding, box_1024, H, W)
    # Binary mask -> 3-channel uint8 image for display.
    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    bbox_text = f"xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}"
    return Image.fromarray(mask_rgb), bbox_text
# Pre-seeded annotator example: an image path plus one labeled box given as
# (xmin, ymin, xmax, ymax, label).
example = ("003_img.png", [(50, 60, 120, 150, "cell")])

# Build the UI components up front so the Interface wiring stays readable.
annotator = BBoxAnnotator(
    value=example,
    categories=["cell", "nucleus"],
    label="upload",
)
mask_output = gr.Image(type="pil", label="Mask result")
location_output = gr.Textbox(label="location")

demo = gr.Interface(
    fn=segment,
    inputs=annotator,
    outputs=[mask_output, location_output],
    examples=[[example]],
    cache_examples=False,
)

if __name__ == "__main__":
    # Queueing is required for @spaces.GPU functions; bind on all interfaces
    # for containerized deployment.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False,
    )