# Hugging Face Space (status: Sleeping) — MedSAM bounding-box segmentation demo.
| import gradio as gr | |
| from gradio_bbox_annotator import BBoxAnnotator | |
| from PIL import Image | |
| import numpy as np | |
| from inference import load_model, get_embedding, run | |
| import torch | |
| import os | |
| import spaces | |
# Module-level model state. The model is loaded on CPU at import time
# (see load_model_cpu below) and moved to the GPU lazily by prepare_cuda().
MODEL = None
DEVICE = torch.device("cpu")
CUDA_READY = False  # guards against moving the model to the GPU more than once
def load_model_cpu(checkpoint_path: str):
    """Load the MedSAM checkpoint and pin it to the CPU.

    Rebinds the module-level MODEL and DEVICE globals. The model is put
    into eval mode so inference runs with dropout/batch-norm frozen.
    """
    global MODEL, DEVICE
    model, _ = load_model(checkpoint_path)
    MODEL = model.to("cpu")
    MODEL.eval()
    DEVICE = torch.device("cpu")
| load_model_cpu("medsam_vit_b.pth") | |
def prepare_cuda():
    """Move the global model to the GPU the first time CUDA is usable.

    No-op (aside from a log line) when CUDA is unavailable or the move
    has already happened; updates MODEL, DEVICE and CUDA_READY globals.
    """
    global MODEL, DEVICE, CUDA_READY
    # Keep is_available() first so CUDA_READY is only read when CUDA exists.
    if not torch.cuda.is_available() or CUDA_READY:
        print("CUDA not available or already initialized.")
        return
    print("CUDA is available. Moving model to GPU...")
    MODEL.to("cuda")
    DEVICE = torch.device("cuda")
    CUDA_READY = True
    _ = torch.zeros(1, device=DEVICE)  # touch the device to force CUDA init
    print("Model moved to CUDA.")
def parse_first_bbox(bboxes):
    """Return the first bounding box as (xmin, ymin, xmax, ymax) floats.

    Accepts annotator output where each entry is either a dict with
    x/y/width/height keys or a sequence of at least four numbers already
    in corner format. Returns None when no usable box is present.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    if isinstance(first, dict):
        left, top = float(first["x"]), float(first["y"])
        right = left + float(first["width"])
        bottom = top + float(first["height"])
        return left, top, right, bottom
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        return tuple(float(coord) for coord in first[:4])
    return None
def segment(annot_value):
    """Segment the region of the first user-drawn bounding box with MedSAM.

    Args:
        annot_value: BBoxAnnotator value — an (image_path, bboxes) pair —
            or None when nothing has been annotated yet.

    Returns:
        A 2-tuple of (mask as PIL.Image or None, info/error text or None).
        Gradio maps one element per declared output component, so every
        path must return exactly two values (the original `return None,`
        produced a 1-tuple and mis-mapped the outputs).
    """
    prepare_cuda()
    if annot_value is None or len(annot_value) < 1:
        return None, None
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    if not bboxes:
        return None, None
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)
    H, W, _ = img_np.shape
    box = parse_first_bbox(bboxes)
    if box is None:
        return None, "解析矩形框失败,请重画。"
    xmin, ymin, xmax, ymax = map(int, box)
    box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
    # SAM-style prompt encoders take boxes normalized to a 1024x1024 grid.
    box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
    embedding = get_embedding(MODEL, img_np, DEVICE)
    mask = run(MODEL, embedding, box_1024, H, W)
    # Promote the (presumably 0/1 — TODO confirm against inference.run)
    # mask to a uint8 RGB image for display.
    mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
    bbox_text = f"xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}"
    return Image.fromarray(mask_rgb), bbox_text
# Pre-seeded example: image path plus one labelled box; the first four
# numbers are corner coordinates (xmin, ymin, xmax, ymax) as consumed by
# parse_first_bbox, followed by the category label.
example = ("003_img.png", [(50, 60, 120, 150, "cell")])

# UI wiring: a bounding-box annotator in, segmentation mask and the box
# coordinates text out — the two outputs match segment()'s 2-tuple return.
demo = gr.Interface(
    fn=segment,
    inputs=BBoxAnnotator(
        value=example,
        categories=["cell", "nucleus"],
        label="upload"
    ),
    outputs=[
        gr.Image(type="pil", label="Mask result"),
        gr.Textbox(label="location")
    ],
    examples=[[example]],
    cache_examples=False  # don't pre-run examples at build time
)
if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False  # NOTE(review): SSR disabled, presumably for Spaces compatibility — confirm
    )