""" SurgiTrack Demo - Surgical Tool Tracking Based on CholecTrack20 dataset (Nwoye et al., CVPR 2025) """ import os import gradio as gr import cv2 import numpy as np import torch from pathlib import Path from collections import deque # Import models (will be loaded on startup) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" YOLO_MODEL = None DIRECTION_MODEL = None TRACKER = None CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag'] COLORS = { 'grasper': (255, 100, 100), 'bipolar': (100, 255, 100), 'hook': (100, 100, 255), 'scissors': (255, 255, 100), 'clipper': (255, 100, 255), 'irrigator': (100, 255, 255), 'specimenbag': (200, 200, 200), } OPERATOR_COLORS = { 0: (0, 255, 0), # MSLH - Green 1: (0, 0, 255), # MSRH - Red 2: (255, 165, 0), # ASRH - Orange 3: (128, 128, 128) # NULL - Gray } def load_models(): """Load YOLO and Direction Estimator models""" global YOLO_MODEL, DIRECTION_MODEL, TRACKER from ultralytics import YOLO from tracker import DirectionEstimator, OperatorBasedTracker # Load YOLO yolo_path = "weights/best.pt" if os.path.exists(yolo_path): YOLO_MODEL = YOLO(yolo_path) print(f"YOLO model loaded from {yolo_path}") else: print(f"Warning: YOLO model not found at {yolo_path}") return False # Load Direction Estimator direction_path = "weights/direction_estimator.pth" if os.path.exists(direction_path): DIRECTION_MODEL = DirectionEstimator(num_classes=4, pretrained=False) checkpoint = torch.load(direction_path, map_location=DEVICE, weights_only=False) DIRECTION_MODEL.load_state_dict(checkpoint['model_state_dict']) DIRECTION_MODEL.to(DEVICE) DIRECTION_MODEL.eval() print(f"Direction model loaded from {direction_path}") else: print(f"Warning: Direction model not found at {direction_path}") DIRECTION_MODEL = None # Initialize tracker TRACKER = OperatorBasedTracker( direction_model=DIRECTION_MODEL, max_inactive_frames=150, iou_threshold=0.2, direction_confidence_threshold=0.4, device=DEVICE ) return True def draw_tracking_results(frame, slots, trajectories, frame_count): """Draw bounding boxes, IDs, and trajectories on frame""" for slot in slots: if slot.bbox is None: continue x1, y1, x2, y2 = slot.bbox.astype(int) track_id = slot.track_id class_name = slot.class_name # Update trajectory center = (int((x1 + x2) / 2), int((y1 + y2) / 2)) if track_id not in trajectories: trajectories[track_id] = deque(maxlen=30) trajectories[track_id].append(center) # Get colors bbox_color = COLORS.get(class_name, (255, 255, 255)) op_color = OPERATOR_COLORS.get(slot.operator_id, (128, 128, 128)) # Draw bbox cv2.rectangle(frame, (x1, y1), (x2, y2), bbox_color, 2) # Draw operator indicator cv2.circle(frame, (x2 - 10, y1 + 10), 8, op_color, -1) cv2.circle(frame, (x2 - 10, y1 + 10), 8, (0, 0, 0), 1) # Draw label label = f"ID:{track_id} {class_name}" (lw, lh), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) cv2.rectangle(frame, (x1, y1 - lh - 8), (x1 + lw + 4, y1), bbox_color, -1) cv2.putText(frame, label, (x1 + 2, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) # Draw trajectory traj = list(trajectories[track_id]) for i in range(1, len(traj)): alpha = i / len(traj) thickness = max(1, int(alpha * 3)) color = tuple(int(c * alpha) for c in bbox_color) cv2.line(frame, traj[i-1], traj[i], color, thickness) # Draw frame counter cv2.putText(frame, f"Frame: {frame_count}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) return frame, trajectories def process_video_live(video_path, confidence_threshold, progress=gr.Progress()): """Process 
def process_video_live(video_path, confidence_threshold, progress=gr.Progress()):
    """Process video with live inference"""
    global YOLO_MODEL, TRACKER
    if YOLO_MODEL is None:
        return None, "Error: Models not loaded"

    from tracker import Detection

    # Reset tracker
    TRACKER.reset()

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:  # some containers report 0 FPS; fall back to a sane default
        fps = 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Output video
    output_path = "/tmp/output_tracked.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    trajectories = {}
    frame_count = 0
    total_detections = 0
    unique_tracks = set()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # YOLO detection
        results = YOLO_MODEL.predict(frame, conf=confidence_threshold, verbose=False)
        detections = []
        if len(results) > 0 and results[0].boxes is not None:
            boxes = results[0].boxes
            for i in range(len(boxes)):
                class_id = int(boxes.cls[i])
                detections.append(Detection(
                    bbox=boxes.xyxy[i].cpu().numpy(),
                    class_id=class_id,
                    class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
                    confidence=float(boxes.conf[i]),
                    frame_id=frame_count
                ))
        total_detections += len(detections)

        # Update tracker
        slots = TRACKER.update(frame, detections)
        for slot in slots:
            unique_tracks.add(slot.track_id)

        # Draw results
        frame, trajectories = draw_tracking_results(frame, slots, trajectories, frame_count)
        writer.write(frame)

        frame_count += 1
        progress(frame_count / max(total_frames, 1),
                 desc=f"Processing frame {frame_count}/{total_frames}")

    cap.release()
    writer.release()

    # Stats (guard against empty or unreadable videos)
    avg_det = total_detections / frame_count if frame_count else 0.0
    stats = f"""
**Processing Complete**
- Total frames: {frame_count}
- Total detections: {total_detections}
- Unique tracks: {len(unique_tracks)}
- Average detections/frame: {avg_det:.2f}
- Device: {DEVICE}
"""
    return output_path, stats


def show_precomputed_demo(demo_name):
    """Show a precomputed demo video"""
    demo_videos = {
        "Demo 1 - Multi-tool tracking": "demos/demo1_tracked.mp4",
        "Demo 2 - Occlusion handling": "demos/demo2_tracked.mp4",
        "Demo 3 - Tool re-identification": "demos/demo3_tracked.mp4",
    }
    # Fall back to the file-stem naming produced by get_available_demos()
    video_path = demo_videos.get(demo_name, f"demos/{demo_name}_tracked.mp4")

    if video_path and os.path.exists(video_path):
        # Get stats from companion json if it exists
        stats = f"""
**{demo_name}**

Pre-computed tracking results using:
- YOLOv11x for detection
- Direction Estimator for operator prediction
- Operator-based tracker for multi-tool tracking

*Results computed on GPU, displayed instantly.*
"""
        return video_path, stats
    else:
        return None, f"Demo video not found: {video_path}"


def get_available_demos():
    """Get list of available demo videos"""
    demos_dir = Path("demos")
    if demos_dir.exists():
        demos = [f.stem.replace("_tracked", "") for f in demos_dir.glob("*_tracked.mp4")]
        if demos:
            return demos
    return ["Demo 1 - Multi-tool tracking", "Demo 2 - Occlusion handling",
            "Demo 3 - Tool re-identification"]
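
# Optional offline helper (not in the original app): regenerate the demos/
# videos by running the live pipeline on raw clips. The raw clip filenames
# below are assumptions; adjust them to your local files. A no-op progress
# callback stands in for gr.Progress() so this can run outside a Gradio event.
def precompute_demos(raw_clips=("demos/demo1.mp4", "demos/demo2.mp4", "demos/demo3.mp4"),
                     confidence_threshold=0.25):
    """Run the full pipeline on each clip and save a *_tracked.mp4 next to it (sketch)."""
    import shutil
    for clip in raw_clips:
        if not os.path.exists(clip):
            print(f"Skipping missing clip: {clip}")
            continue
        output, stats = process_video_live(clip, confidence_threshold,
                                           progress=lambda *a, **k: None)
        if output:
            tracked = str(Path(clip).with_name(Path(clip).stem + "_tracked.mp4"))
            shutil.move(output, tracked)
            print(f"Saved {tracked}\n{stats}")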
# Build Gradio interface
def create_interface():
    with gr.Blocks(
        title="SurgiTrack - Surgical Tool Tracking",
        theme=gr.themes.Base(
            primary_hue="purple",
            secondary_hue="gray",
            neutral_hue="gray",
        ).set(
            body_background_fill="#0a0a0f",
            body_background_fill_dark="#0a0a0f",
            block_background_fill="#12121a",
            block_background_fill_dark="#12121a",
            block_border_color="#2a2a3a",
            block_border_color_dark="#2a2a3a",
            button_primary_background_fill="#a855f7",
            button_primary_background_fill_hover="#9333ea",
        ),
        css="""
        .gradio-container { max-width: 1200px !important; }
        .gr-button { font-weight: 500; }
        footer { display: none !important; }
        """
    ) as demo:
        gr.Markdown("""
        # 🔬 SurgiTrack - Surgical Tool Tracking

        Multi-class multi-tool tracking in laparoscopic surgery videos.
        Based on the [SurgiTrack paper](https://arxiv.org/abs/2405.20333) and trained on the CholecTrack20 dataset.

        **Pipeline:** YOLOv11x Detection → Direction Estimation → Operator-based Tracking

        ---
        """)

        with gr.Tabs():
            # Tab 1: Pre-computed demos (instant)
            with gr.TabItem("📽️ Demo Videos (Instant)"):
                gr.Markdown("""
                ### Pre-computed Results
                Watch tracking results instantly. These videos were processed on GPU with the full pipeline.
                """)
                with gr.Row():
                    demo_dropdown = gr.Dropdown(
                        choices=get_available_demos(),
                        label="Select Demo",
                        value=get_available_demos()[0] if get_available_demos() else None
                    )
                    demo_btn = gr.Button("▶️ Show Demo", variant="primary")
                with gr.Row():
                    demo_video = gr.Video(label="Tracking Result")
                    demo_stats = gr.Markdown(label="Statistics")
                demo_btn.click(
                    fn=show_precomputed_demo,
                    inputs=[demo_dropdown],
                    outputs=[demo_video, demo_stats]
                )

            # Tab 2: Live inference (slower but real)
            with gr.TabItem("🔄 Live Inference (CPU)"):
                gr.Markdown("""
                ### Real-time Processing
                Upload a short video clip (5-15 seconds recommended) for live tracking.

                ⚠️ **Note:** Running on CPU - processing may take a few minutes.
                """)
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Upload Video")
                        confidence_slider = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.25, step=0.05,
                            label="Detection Confidence Threshold"
                        )
                        process_btn = gr.Button("🚀 Run Tracking", variant="primary")
                    with gr.Column():
                        output_video = gr.Video(label="Tracked Video")
                        output_stats = gr.Markdown(label="Statistics")
                process_btn.click(
                    fn=process_video_live,
                    inputs=[input_video, confidence_slider],
                    outputs=[output_video, output_stats]
                )

        gr.Markdown("""
        ---
        ### 📊 Method Overview

        | Component | Description |
        |-----------|-------------|
        | **Detection** | YOLOv11x trained on CholecTrack20 (7 tool classes) |
        | **Direction Estimator** | EfficientNet-B0 + Coordinate Attention → Operator prediction |
        | **Tracker** | Operator-based slots for graspers, fixed IDs for other tools |

        ### 📈 Results on CholecTrack20 Test Set

        | Metric | Score |
        |--------|-------|
        | **HOTA** | 64.48% |
        | **AssA** | 71.19% |
        | **DetA** | 58.51% |

        ---
        **Dataset:** [CholecTrack20](https://arxiv.org/abs/2312.07352) (Nwoye et al., CVPR 2025)
        **Author:** [Djalil Khelladi](https://github.com/akhellad)
        """)

    return demo


if __name__ == "__main__":
    print(f"Starting SurgiTrack Demo on {DEVICE}...")

    # Try to load models
    models_loaded = load_models()
    if not models_loaded:
        print("Warning: Models not loaded. Only pre-computed demos will work.")

    # Create and launch interface
    demo = create_interface()
    demo.launch()
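
# Usage notes (assumptions: this file is saved as app.py, with weights/ and
# demos/ sitting alongside it; the package list is inferred from the imports
# above, plus the local tracker module shipped with the repo):
#   pip install gradio ultralytics opencv-python torch numpy
#   python app.py
# If the weights are missing, load_models() returns False and only the
# pre-computed demo tab will work, as printed at startup.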