| """ | |
| SurgiTrack Demo - Surgical Tool Tracking | |
| Based on CholecTrack20 dataset (Nwoye et al., CVPR 2025) | |
| """ | |
import os
import gradio as gr
import cv2
import numpy as np
import torch
from pathlib import Path
from collections import deque
# Model handles; populated at startup by load_models()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
YOLO_MODEL = None
DIRECTION_MODEL = None
TRACKER = None

# Seven tool classes from the Cholec80/CholecTrack20 taxonomy;
# the order must match the class indices used to train the YOLO model.
CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
COLORS = {
    'grasper': (255, 100, 100),
    'bipolar': (100, 255, 100),
    'hook': (100, 100, 255),
    'scissors': (255, 255, 100),
    'clipper': (255, 100, 255),
    'irrigator': (100, 255, 255),
    'specimenbag': (200, 200, 200),
}
OPERATOR_COLORS = {
    0: (0, 255, 0),      # MSLH - green (BGR)
    1: (0, 0, 255),      # MSRH - red (BGR)
    2: (0, 165, 255),    # ASRH - orange (BGR)
    3: (128, 128, 128),  # NULL - gray
}
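# Operator identities follow the CholecTrack20 annotation scheme: MSLH = main
# surgeon left hand, MSRH = main surgeon right hand, ASRH = assistant surgeon
# right hand, NULL = no attributed operator. Per the method table below, the
# tracker keys its grasper slots on these identities.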
def load_models():
    """Load the YOLO detector and the Direction Estimator."""
    global YOLO_MODEL, DIRECTION_MODEL, TRACKER
    from ultralytics import YOLO
    from tracker import DirectionEstimator, OperatorBasedTracker

    # Load YOLO
    yolo_path = "weights/best.pt"
    if os.path.exists(yolo_path):
        YOLO_MODEL = YOLO(yolo_path)
        print(f"YOLO model loaded from {yolo_path}")
    else:
        print(f"Warning: YOLO model not found at {yolo_path}")
        return False

    # Load Direction Estimator (optional)
    direction_path = "weights/direction_estimator.pth"
    if os.path.exists(direction_path):
        DIRECTION_MODEL = DirectionEstimator(num_classes=4, pretrained=False)
        checkpoint = torch.load(direction_path, map_location=DEVICE, weights_only=False)
        DIRECTION_MODEL.load_state_dict(checkpoint['model_state_dict'])
        DIRECTION_MODEL.to(DEVICE)
        DIRECTION_MODEL.eval()
        print(f"Direction model loaded from {direction_path}")
    else:
        print(f"Warning: Direction model not found at {direction_path}")
        DIRECTION_MODEL = None

    # Initialize tracker
    TRACKER = OperatorBasedTracker(
        direction_model=DIRECTION_MODEL,
        max_inactive_frames=150,
        iou_threshold=0.2,
        direction_confidence_threshold=0.4,
        device=DEVICE,
    )
    return True
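# load_models() expects the direction checkpoint to be a dict holding the
# weights under 'model_state_dict', i.e. produced by something along the
# lines of (illustrative; the training/export script is not part of this file):
#   torch.save({'model_state_dict': model.state_dict()}, 'weights/direction_estimator.pth')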
def draw_tracking_results(frame, slots, trajectories, frame_count):
    """Draw bounding boxes, track IDs, operator markers, and trajectories."""
    for slot in slots:
        if slot.bbox is None:
            continue
        x1, y1, x2, y2 = slot.bbox.astype(int)
        track_id = slot.track_id
        class_name = slot.class_name

        # Update trajectory (keep the last 30 box centers per track)
        center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
        if track_id not in trajectories:
            trajectories[track_id] = deque(maxlen=30)
        trajectories[track_id].append(center)

        # Get colors
        bbox_color = COLORS.get(class_name, (255, 255, 255))
        op_color = OPERATOR_COLORS.get(slot.operator_id, (128, 128, 128))

        # Draw bbox
        cv2.rectangle(frame, (x1, y1), (x2, y2), bbox_color, 2)

        # Draw operator indicator (filled dot in the top-right corner of the box)
        cv2.circle(frame, (x2 - 10, y1 + 10), 8, op_color, -1)
        cv2.circle(frame, (x2 - 10, y1 + 10), 8, (0, 0, 0), 1)

        # Draw label (clamped so it stays visible when the box touches the top edge)
        label = f"ID:{track_id} {class_name}"
        (lw, lh), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        ty = max(y1, lh + 8)
        cv2.rectangle(frame, (x1, ty - lh - 8), (x1 + lw + 4, ty), bbox_color, -1)
        cv2.putText(frame, label, (x1 + 2, ty - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        # Draw trajectory with a fading tail (older segments thinner and darker)
        traj = list(trajectories[track_id])
        for i in range(1, len(traj)):
            alpha = i / len(traj)
            thickness = max(1, int(alpha * 3))
            color = tuple(int(c * alpha) for c in bbox_color)
            cv2.line(frame, traj[i - 1], traj[i], color, thickness)

    # Draw frame counter
    cv2.putText(frame, f"Frame: {frame_count}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return frame, trajectories
def process_video_live(video_path, confidence_threshold, progress=gr.Progress()):
    """Process an uploaded video with live inference."""
    global YOLO_MODEL, TRACKER
    if YOLO_MODEL is None:
        return None, "Error: Models not loaded"
    from tracker import Detection

    # Reset tracker state between videos
    TRACKER.reset()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: could not open video: {video_path}"
    fps = cap.get(cv2.CAP_PROP_FPS) or 25  # fall back if FPS metadata is missing
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)

    # Output video (mp4v is safe for OpenCV writers; browsers may need H.264 for inline playback)
    output_path = "/tmp/output_tracked.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    trajectories = {}
    frame_count = 0
    total_detections = 0
    unique_tracks = set()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # YOLO detection
        results = YOLO_MODEL.predict(frame, conf=confidence_threshold, verbose=False)
        detections = []
        if len(results) > 0 and results[0].boxes is not None:
            boxes = results[0].boxes
            for i in range(len(boxes)):
                class_id = int(boxes.cls[i])
                detections.append(Detection(
                    bbox=boxes.xyxy[i].cpu().numpy(),
                    class_id=class_id,
                    class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
                    confidence=float(boxes.conf[i]),
                    frame_id=frame_count,
                ))
        total_detections += len(detections)

        # Update tracker
        slots = TRACKER.update(frame, detections)
        for slot in slots:
            unique_tracks.add(slot.track_id)

        # Draw results
        frame, trajectories = draw_tracking_results(frame, slots, trajectories, frame_count)
        writer.write(frame)

        frame_count += 1
        progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")

    cap.release()
    writer.release()

    if frame_count == 0:
        return None, "Error: no frames could be read from the video"

    # Stats
    stats = f"""
**Processing Complete**
- Total frames: {frame_count}
- Total detections: {total_detections}
- Unique tracks: {len(unique_tracks)}
- Average detections/frame: {total_detections / frame_count:.2f}
- Device: {DEVICE}
"""
    return output_path, stats
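# For reference, the per-frame pipeline used above, reduced to its essence.
# This is an illustrative sketch (not called by the app); it assumes
# load_models() has succeeded and reuses the module-level YOLO_MODEL/TRACKER.
def track_single_frame(frame, frame_id=0, conf=0.25):
    """Run detection + operator-based tracking on one BGR frame."""
    from tracker import Detection
    results = YOLO_MODEL.predict(frame, conf=conf, verbose=False)
    detections = []
    if len(results) > 0 and results[0].boxes is not None:
        boxes = results[0].boxes
        for i in range(len(boxes)):
            class_id = int(boxes.cls[i])
            detections.append(Detection(
                bbox=boxes.xyxy[i].cpu().numpy(),
                class_id=class_id,
                class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
                confidence=float(boxes.conf[i]),
                frame_id=frame_id,
            ))
    # The tracker assigns persistent IDs (operator-based slots for graspers)
    return TRACKER.update(frame, detections)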
# Mapping from display name to pre-rendered video path; shared by the
# dropdown and the playback handler so the two can never drift apart.
DEMO_VIDEOS = {
    "Demo 1 - Multi-tool tracking": "demos/demo1_tracked.mp4",
    "Demo 2 - Occlusion handling": "demos/demo2_tracked.mp4",
    "Demo 3 - Tool re-identification": "demos/demo3_tracked.mp4",
}

def show_precomputed_demo(demo_name):
    """Show a pre-computed demo video."""
    video_path = DEMO_VIDEOS.get(demo_name)
    if video_path and os.path.exists(video_path):
        stats = f"""
**{demo_name}**

Pre-computed tracking results using:
- YOLOv11x for detection
- Direction Estimator for operator prediction
- Operator-based tracker for multi-tool tracking

*Results computed on GPU, displayed instantly.*
"""
        return video_path, stats
    return None, f"Demo video not found for: {demo_name}"

def get_available_demos():
    """List demo names whose video files exist on disk."""
    available = [name for name, path in DEMO_VIDEOS.items() if os.path.exists(path)]
    # Fall back to the full list so the dropdown is never empty
    return available or list(DEMO_VIDEOS)
# Build Gradio interface
def create_interface():
    with gr.Blocks(
        title="SurgiTrack - Surgical Tool Tracking",
        theme=gr.themes.Base(
            primary_hue="purple",
            secondary_hue="gray",
            neutral_hue="gray",
        ).set(
            body_background_fill="#0a0a0f",
            body_background_fill_dark="#0a0a0f",
            block_background_fill="#12121a",
            block_background_fill_dark="#12121a",
            block_border_color="#2a2a3a",
            block_border_color_dark="#2a2a3a",
            button_primary_background_fill="#a855f7",
            button_primary_background_fill_hover="#9333ea",
        ),
        css="""
        .gradio-container { max-width: 1200px !important; }
        .gr-button { font-weight: 500; }
        footer { display: none !important; }
        """,
    ) as demo:
| gr.Markdown(""" | |
| # π¬ SurgiTrack - Surgical Tool Tracking | |
| Multi-class multi-tool tracking in laparoscopic surgery videos. | |
| Based on the [SurgiTrack paper](https://arxiv.org/abs/2312.07352) and trained on CholecTrack20 dataset. | |
| **Pipeline:** YOLOv11x Detection β Direction Estimation β Operator-based Tracking | |
| --- | |
| """) | |
        with gr.Tabs():
            # Tab 1: Pre-computed demos (instant)
            with gr.TabItem("📽️ Demo Videos (Instant)"):
                gr.Markdown("""
### Pre-computed Results
Watch tracking results instantly. These videos were processed on GPU with the full pipeline.
""")
                available_demos = get_available_demos()
                with gr.Row():
                    demo_dropdown = gr.Dropdown(
                        choices=available_demos,
                        label="Select Demo",
                        value=available_demos[0] if available_demos else None,
                    )
                    demo_btn = gr.Button("▶️ Show Demo", variant="primary")
                with gr.Row():
                    demo_video = gr.Video(label="Tracking Result")
                    demo_stats = gr.Markdown(label="Statistics")
                demo_btn.click(
                    fn=show_precomputed_demo,
                    inputs=[demo_dropdown],
                    outputs=[demo_video, demo_stats],
                )
            # Tab 2: Live inference (slower, but runs the real pipeline)
            with gr.TabItem("🔴 Live Inference (CPU)"):
                gr.Markdown("""
### Live Processing
Upload a short video clip (5-15 seconds recommended) for live tracking.

⚠️ **Note:** Running on CPU - processing may take a few minutes.
""")
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Upload Video")
                        confidence_slider = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.25, step=0.05,
                            label="Detection Confidence Threshold",
                        )
                        process_btn = gr.Button("🚀 Run Tracking", variant="primary")
                    with gr.Column():
                        output_video = gr.Video(label="Tracked Video")
                        output_stats = gr.Markdown(label="Statistics")
                process_btn.click(
                    fn=process_video_live,
                    inputs=[input_video, confidence_slider],
                    outputs=[output_video, output_stats],
                )
| gr.Markdown(""" | |
| --- | |
| ### π Method Overview | |
| | Component | Description | | |
| |-----------|-------------| | |
| | **Detection** | YOLOv11x trained on CholecTrack20 (7 tool classes) | | |
| | **Direction Estimator** | EfficientNet-B0 + Coordinate Attention β Operator prediction | | |
| | **Tracker** | Operator-based slots for graspers, fixed IDs for other tools | | |
| ### π Results on CholecTrack20 Test Set | |
| | Metric | Score | | |
| |--------|-------| | |
| | **HOTA** | 64.48% | | |
| | **AssA** | 71.19% | | |
| | **DetA** | 58.51% | | |
| --- | |
| **Dataset:** [CholecTrack20](https://arxiv.org/abs/2312.07352) (Nwoye et al., CVPR 2025) | |
| **Author:** [Djalil Khelladi](https://github.com/akhellad) | |
| """) | |
| return demo | |
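# Sketch of how the Direction Estimator is consumed inside the tracker.
# This interface is assumed (the model is never called directly in this file):
# per the Method Overview markdown above, it maps a tool crop to a distribution
# over the four operator classes; the actual preprocessing lives in tracker.py,
# and `preprocess` below is a hypothetical stand-in for it.
#
#   import torch.nn.functional as F
#   crop = frame[y1:y2, x1:x2]                         # tool region (BGR)
#   tensor = preprocess(crop).unsqueeze(0).to(DEVICE)  # hypothetical helper
#   with torch.no_grad():
#       logits = DIRECTION_MODEL(tensor)
#   probs = F.softmax(logits, dim=1)                   # P(MSLH/MSRH/ASRH/NULL)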
if __name__ == "__main__":
    print(f"Starting SurgiTrack Demo on {DEVICE}...")

    # Try to load models
    models_loaded = load_models()
    if not models_loaded:
        print("Warning: Models not loaded. Only pre-computed demos will work.")

    # Create and launch interface
    demo = create_interface()
    demo.launch()
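    # Note: when running outside Hugging Face Spaces, demo.launch(share=True)
    # creates a temporary public link, and server_name="0.0.0.0" exposes the
    # app on the local network (both are standard gr.Blocks.launch() options).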