Spaces: Running on Zero
| """ | |
| Surgical-DeSAM Gradio App for Hugging Face Spaces | |
| Supports both Image and Video segmentation with ZeroGPU | |
| """ | |
| import os | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import tempfile | |
| # Model imports | |
| from models.detr_seg import DETR, SAMModel | |
| from models.backbone import build_backbone | |
| from models.transformer import build_transformer | |
| from util.misc import NestedTensor | |
# Configuration
MODEL_REPO = os.environ.get("MODEL_REPO", "IFMedTech/surgical-desam-weights")  # HF repo holding the checkpoints
HF_TOKEN = os.environ.get("HF_TOKEN")  # access token; required when MODEL_REPO is private

# Instrument class names; index order must match the model's label
# indices (used by process_single_frame when drawing labels).
INSTRUMENT_CLASSES = (
    'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
    'monopolar_curved_scissors', 'ultrasound_probe', 'suction',
    'clip_applier', 'stapler'
)

# One RGB color per class, indexed modulo len(COLORS).
COLORS = [
    [0, 114, 189], [217, 83, 25], [237, 177, 32],
    [126, 47, 142], [119, 172, 48], [77, 190, 238],
    [162, 20, 47], [76, 76, 76]
]

# Global model variables -- populated lazily by load_models() on first request.
model = None
seg_model = None
device = None
def download_weights():
    """Fetch the model checkpoints from the (possibly private) HF repo.

    Returns:
        tuple[str, str]: local paths to the Surgical-DeSAM checkpoint
        and the SAM ViT-B checkpoint.
    """
    weights_dir = "weights"
    os.makedirs(weights_dir, exist_ok=True)

    # Both main checkpoints share repo, token and destination.
    common = dict(repo_id=MODEL_REPO, token=HF_TOKEN, local_dir=weights_dir)
    desam_path = hf_hub_download(filename="surgical_desam_1024.pth", **common)
    sam_path = hf_hub_download(filename="sam_vit_b_01ec64.pth", **common)

    # The Swin backbone weights live in their own directory, which the
    # model-building code locates via Args.backbone_dir.
    swin_dir = "swin_backbone"
    os.makedirs(swin_dir, exist_ok=True)
    hf_hub_download(
        repo_id=MODEL_REPO,
        filename="swin_base_patch4_window7_224_22kto1k.pth",
        token=HF_TOKEN,
        local_dir=swin_dir,
    )
    return desam_path, sam_path
class Args:
    """Mock args for model building.

    Stands in for the argparse namespace that ``build_backbone`` and
    ``build_transformer`` expect; values mirror the training configuration.
    """
    backbone = 'swin_B_224_22k'  # Swin-B, 224px, ImageNet-22k pre-training
    dilation = False
    position_embedding = 'sine'
    hidden_dim = 256             # transformer embedding dim
    dropout = 0.1
    nheads = 8
    dim_feedforward = 2048
    enc_layers = 6
    dec_layers = 6
    pre_norm = False
    num_queries = 100            # DETR object queries per image
    aux_loss = False
    lr_backbone = 1e-5           # >0 marks the backbone as trainable in build_backbone
    masks = False
    dataset_file = 'endovis18'
    device = 'cuda'              # overwritten in load_models() with the real device
    backbone_dir = './swin_backbone'  # where download_weights() puts the Swin weights
def load_models():
    """Load DETR and SAM models.

    Downloads the checkpoints, builds the DETR detector and the SAM
    segmentation head, and stores them in the module-level globals
    ``model``, ``seg_model`` and ``device``.
    """
    global model, seg_model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    desam_path, sam_path = download_weights()
    args = Args()
    args.device = str(device)
    backbone = build_backbone(args)
    transformer = build_transformer(args)
    model = DETR(
        backbone,
        transformer,
        num_classes=9,  # presumably 8 instrument classes + no-object -- TODO confirm
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects;
    # acceptable only because the checkpoint comes from the trusted MODEL_REPO.
    checkpoint = torch.load(desam_path, map_location='cpu', weights_only=False)
    # strict=False silently ignores missing/unexpected keys -- verify the
    # checkpoint really matches this architecture.
    model.load_state_dict(checkpoint['model'], strict=False)
    model.to(device)
    model.eval()
    seg_model = SAMModel(device=device, ckpt_path=sam_path)
    # Fine-tuned SAM weights are optional in the checkpoint.
    if 'seg_model' in checkpoint:
        seg_model.load_state_dict(checkpoint['seg_model'])
    seg_model.to(device)
    seg_model.eval()
    print("Models loaded successfully!")
def preprocess_frame(frame):
    """Resize a frame to 1024x1024, apply ImageNet normalization, and
    return it as a CHW float tensor."""
    scaled = cv2.resize(frame, (1024, 1024)).astype(np.float32) / 255.0
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (scaled - imagenet_mean) / imagenet_std
    # HWC -> CHW for PyTorch.
    return torch.from_numpy(normalized.transpose(2, 0, 1)).float()
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last dim."""
    cx, cy, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
def process_single_frame(frame_rgb, h, w):
    """Detect and segment instruments in one RGB frame.

    Args:
        frame_rgb: HxWx3 uint8 RGB image.
        h, w: target height/width used to scale boxes and masks back to
            the original frame size.

    Returns:
        A copy of the frame with masks, boxes and labels drawn on it, or
        the unmodified input when nothing scores above the threshold.
    """
    global model, seg_model, device
    img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device)
    # DETR-style NestedTensor; an all-False mask means "no padding".
    mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
    samples = NestedTensor(img_tensor, mask)
    with torch.no_grad():
        outputs, image_embeddings = model(samples)
    # Drop the last logit column (presumably the no-object class) before
    # thresholding -- TODO confirm against DETR head definition.
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.3
    if not keep.any():
        return frame_rgb  # No detections
    boxes = outputs['pred_boxes'][0, keep]
    scores = probas[keep].max(-1).values.cpu().numpy()
    labels = probas[keep].argmax(-1).cpu().numpy()
    # Boxes are normalized cxcywh; scale into original-frame pixels.
    boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device)
    boxes_np = boxes_scaled.cpu().numpy()
    # SAM head is prompted with the (still normalized) DETR boxes plus the
    # shared image embeddings from the detector forward pass.
    low_res_masks, pred_masks, _ = seg_model(
        img_tensor, boxes, image_embeddings,
        sizes=(1024, 1024), add_noise=False
    )
    # assumes pred_masks is (num_kept, 1024, 1024) float -- TODO confirm
    masks_np = pred_masks.cpu().numpy()
    # Draw on frame
    result = frame_rgb.copy()
    for i, (box, label, mask_pred, score) in enumerate(zip(boxes_np, labels, masks_np, scores)):
        if score < 0.3:  # redundant with `keep` above; kept as a safety net
            continue
        color = COLORS[label % len(COLORS)]
        # Draw mask: blend the class color into the masked region only
        # (outside the mask, 0.6*result + 0.4*result == result).
        mask_resized = cv2.resize(mask_pred, (w, h))
        mask_bool = mask_resized > 0.5
        overlay = result.copy()
        overlay[mask_bool] = color
        result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0)
        # Draw box
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)
        # Draw label
        label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}"
        cv2.putText(result, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return result
def predict_image(image):
    """Segment surgical instruments in a single PIL image."""
    global model, seg_model, device
    # Lazy model initialization on first request.
    if model is None:
        load_models()
    if image is None:
        return None
    frame = np.array(image)
    height, width = frame.shape[:2]
    annotated = process_single_frame(frame, height, width)
    return Image.fromarray(annotated)
def predict_video(video_path, progress=gr.Progress()):
    """Run per-frame instrument segmentation over a whole video.

    Args:
        video_path: path to the input video file, or None.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Path to the rendered output .mp4, or None when no input is given.
    """
    global model, seg_model, device
    if model is None:
        progress(0, desc="Loading models...")
        load_models()
    if video_path is None:
        return None
    # Open video
    cap = cv2.VideoCapture(video_path)
    # Some containers report 0 FPS; fall back to 25 so VideoWriter stays valid.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Output video: mkstemp instead of the deprecated, race-prone mktemp.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # VideoWriter opens the path itself
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Process frame
            result_rgb = process_single_frame(frame_rgb, height, width)
            # RGB to BGR for output
            result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
            out.write(result_bgr)
            frame_count += 1
            # CAP_PROP_FRAME_COUNT can be 0/unreliable; avoid ZeroDivisionError.
            if total_frames > 0:
                progress(frame_count / total_frames,
                         desc=f"Processing frame {frame_count}/{total_frames}")
    finally:
        # Release capture/writer even if a frame raises mid-stream.
        cap.release()
        out.release()
    return output_path
| # Create Gradio interface | |
| with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 🔬 Surgical-DeSAM") | |
| gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.") | |
| with gr.Tabs(): | |
| # Video Tab | |
| with gr.TabItem("🎬 Video Segmentation"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_video = gr.Video(label="Input Video") | |
| video_btn = gr.Button("Segment Video", variant="primary") | |
| with gr.Column(): | |
| output_video = gr.Video(label="Segmentation Result") | |
| video_btn.click(fn=predict_video, inputs=input_video, outputs=output_video) | |
| gr.Examples( | |
| examples=["examples/surgical_demo.mp4", | |
| "examples/output.mp4"], | |
| inputs=input_video, | |
| label="Example Surgical Video" | |
| ) | |
| # Image Tab | |
| with gr.TabItem("🖼️ Image Segmentation"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_image = gr.Image(type="pil", label="Input Image") | |
| image_btn = gr.Button("Segment Image", variant="primary") | |
| with gr.Column(): | |
| output_image = gr.Image(type="pil", label="Segmentation Result") | |
| image_btn.click(fn=predict_image, inputs=input_image, outputs=output_image) | |
| gr.Examples( | |
| examples=[ | |
| "examples/example_2.png", | |
| "examples/example_3.png", | |
| "examples/example_4.png", | |
| ], | |
| inputs=input_image, | |
| label="Example Surgical Images" | |
| ) | |
| gr.Markdown(""" | |
| ## Detected Classes | |
| Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors | | |
| Ultrasound Probe | Suction | Clip Applier | Stapler | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch() | |