""" Surgical-DeSAM Gradio App for Hugging Face Spaces Supports both Image and Video segmentation with ZeroGPU """ import os import spaces import gradio as gr import torch import numpy as np import cv2 from PIL import Image from huggingface_hub import hf_hub_download import tempfile # Model imports from models.detr_seg import DETR, SAMModel from models.backbone import build_backbone from models.transformer import build_transformer from util.misc import NestedTensor # Configuration MODEL_REPO = os.environ.get("MODEL_REPO", "IFMedTech/surgical-desam-weights") HF_TOKEN = os.environ.get("HF_TOKEN") INSTRUMENT_CLASSES = ( 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver', 'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier', 'stapler' ) COLORS = [ [0, 114, 189], [217, 83, 25], [237, 177, 32], [126, 47, 142], [119, 172, 48], [77, 190, 238], [162, 20, 47], [76, 76, 76] ] # Global model variables model = None seg_model = None device = None def download_weights(): """Download model weights from private HF repo""" weights_dir = "weights" os.makedirs(weights_dir, exist_ok=True) desam_path = hf_hub_download( repo_id=MODEL_REPO, filename="surgical_desam_1024.pth", token=HF_TOKEN, local_dir=weights_dir ) sam_path = hf_hub_download( repo_id=MODEL_REPO, filename="sam_vit_b_01ec64.pth", token=HF_TOKEN, local_dir=weights_dir ) swin_dir = "swin_backbone" os.makedirs(swin_dir, exist_ok=True) hf_hub_download( repo_id=MODEL_REPO, filename="swin_base_patch4_window7_224_22kto1k.pth", token=HF_TOKEN, local_dir=swin_dir ) return desam_path, sam_path class Args: """Mock args for model building""" backbone = 'swin_B_224_22k' dilation = False position_embedding = 'sine' hidden_dim = 256 dropout = 0.1 nheads = 8 dim_feedforward = 2048 enc_layers = 6 dec_layers = 6 pre_norm = False num_queries = 100 aux_loss = False lr_backbone = 1e-5 masks = False dataset_file = 'endovis18' device = 'cuda' backbone_dir = './swin_backbone' def load_models(): """Load DETR and SAM models""" global model, seg_model, device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") desam_path, sam_path = download_weights() args = Args() args.device = str(device) backbone = build_backbone(args) transformer = build_transformer(args) model = DETR( backbone, transformer, num_classes=9, num_queries=args.num_queries, aux_loss=args.aux_loss, ) checkpoint = torch.load(desam_path, map_location='cpu', weights_only=False) model.load_state_dict(checkpoint['model'], strict=False) model.to(device) model.eval() seg_model = SAMModel(device=device, ckpt_path=sam_path) if 'seg_model' in checkpoint: seg_model.load_state_dict(checkpoint['seg_model']) seg_model.to(device) seg_model.eval() print("Models loaded successfully!") def preprocess_frame(frame): """Preprocess frame for model input""" img = cv2.resize(frame, (1024, 1024)) img = img.astype(np.float32) / 255.0 mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) img = (img - mean) / std img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() return img_tensor def box_cxcywh_to_xyxy(x): """Convert boxes from center format to corner format""" x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def process_single_frame(frame_rgb, h, w): """Process a single frame and return segmented result""" global model, seg_model, device img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device) mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device) samples = NestedTensor(img_tensor, mask) with torch.no_grad(): outputs, image_embeddings = model(samples) probas = outputs['pred_logits'].softmax(-1)[0, :, :-1] keep = probas.max(-1).values > 0.3 if not keep.any(): return frame_rgb # No detections boxes = outputs['pred_boxes'][0, keep] scores = probas[keep].max(-1).values.cpu().numpy() labels = probas[keep].argmax(-1).cpu().numpy() boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device) boxes_np = boxes_scaled.cpu().numpy() low_res_masks, pred_masks, _ = seg_model( img_tensor, boxes, image_embeddings, sizes=(1024, 1024), add_noise=False ) masks_np = pred_masks.cpu().numpy() # Draw on frame result = frame_rgb.copy() for i, (box, label, mask_pred, score) in enumerate(zip(boxes_np, labels, masks_np, scores)): if score < 0.3: continue color = COLORS[label % len(COLORS)] # Draw mask mask_resized = cv2.resize(mask_pred, (w, h)) mask_bool = mask_resized > 0.5 overlay = result.copy() overlay[mask_bool] = color result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0) # Draw box x1, y1, x2, y2 = box.astype(int) cv2.rectangle(result, (x1, y1), (x2, y2), color, 2) # Draw label label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}" cv2.putText(result, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) return result @spaces.GPU def predict_image(image): """Run inference on input image""" global model, seg_model, device if model is None: load_models() if image is None: return None frame_rgb = np.array(image) h, w = frame_rgb.shape[:2] result = process_single_frame(frame_rgb, h, w) return Image.fromarray(result) @spaces.GPU(duration=300) def predict_video(video_path, progress=gr.Progress()): """Process video and return segmented video""" global model, seg_model, device if model is None: progress(0, desc="Loading models...") load_models() if video_path is None: return None # Open video cap = cv2.VideoCapture(video_path) fps = int(cap.get(cv2.CAP_PROP_FPS)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Output video output_path = tempfile.mktemp(suffix=".mp4") fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) frame_count = 0 while True: ret, frame = cap.read() if not ret: break # BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Process frame result_rgb = process_single_frame(frame_rgb, height, width) # RGB to BGR for output result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) out.write(result_bgr) frame_count += 1 progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}") cap.release() out.release() return output_path # Create Gradio interface with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🔬 Surgical-DeSAM") gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.") with gr.Tabs(): # Video Tab with gr.TabItem("🎬 Video Segmentation"): with gr.Row(): with gr.Column(): input_video = gr.Video(label="Input Video") video_btn = gr.Button("Segment Video", variant="primary") with gr.Column(): output_video = gr.Video(label="Segmentation Result") video_btn.click(fn=predict_video, inputs=input_video, outputs=output_video) gr.Examples( examples=["examples/surgical_demo.mp4", "examples/output.mp4"], inputs=input_video, label="Example Surgical Video" ) # Image Tab with gr.TabItem("🖼️ Image Segmentation"): with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="Input Image") image_btn = gr.Button("Segment Image", variant="primary") with gr.Column(): output_image = gr.Image(type="pil", label="Segmentation Result") image_btn.click(fn=predict_image, inputs=input_image, outputs=output_image) gr.Examples( examples=[ "examples/example_2.png", "examples/example_3.png", "examples/example_4.png", ], inputs=input_image, label="Example Surgical Images" ) gr.Markdown(""" ## Detected Classes Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors | Ultrasound Probe | Suction | Clip Applier | Stapler """) if __name__ == "__main__": demo.launch()