"""
SurgiTrack Demo - Surgical Tool Tracking
Based on the CholecTrack20 dataset (Nwoye et al., CVPR 2025)
"""
import os
import gradio as gr
import cv2
import numpy as np
import torch
from pathlib import Path
from collections import deque
# Globals for the models, populated lazily by load_models() at startup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
YOLO_MODEL = None
DIRECTION_MODEL = None
TRACKER = None
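# The seven surgical tool classes annotated in CholecTrack20.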
CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
COLORS = {
'grasper': (255, 100, 100),
'bipolar': (100, 255, 100),
'hook': (100, 100, 255),
'scissors': (255, 255, 100),
'clipper': (255, 100, 255),
'irrigator': (100, 255, 255),
'specimenbag': (200, 200, 200),
}
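# Operator identities follow the CholecTrack20 convention: MSLH/MSRH are the
# main surgeon's left/right hand, ASRH the assistant surgeon's right hand, and
# NULL marks tools without an identifiable operator.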
OPERATOR_COLORS = {
0: (0, 255, 0), # MSLH - Green
1: (0, 0, 255), # MSRH - Red
2: (255, 165, 0), # ASRH - Orange
3: (128, 128, 128) # NULL - Gray
}
def load_models():
"""Load YOLO and Direction Estimator models"""
global YOLO_MODEL, DIRECTION_MODEL, TRACKER
from ultralytics import YOLO
from tracker import DirectionEstimator, OperatorBasedTracker
# Load YOLO
yolo_path = "weights/best.pt"
if os.path.exists(yolo_path):
YOLO_MODEL = YOLO(yolo_path)
print(f"YOLO model loaded from {yolo_path}")
else:
print(f"Warning: YOLO model not found at {yolo_path}")
return False
# Load Direction Estimator
direction_path = "weights/direction_estimator.pth"
if os.path.exists(direction_path):
DIRECTION_MODEL = DirectionEstimator(num_classes=4, pretrained=False)
checkpoint = torch.load(direction_path, map_location=DEVICE, weights_only=False)
DIRECTION_MODEL.load_state_dict(checkpoint['model_state_dict'])
DIRECTION_MODEL.to(DEVICE)
DIRECTION_MODEL.eval()
print(f"Direction model loaded from {direction_path}")
else:
print(f"Warning: Direction model not found at {direction_path}")
DIRECTION_MODEL = None
# Initialize tracker
TRACKER = OperatorBasedTracker(
direction_model=DIRECTION_MODEL,
max_inactive_frames=150,
iou_threshold=0.2,
direction_confidence_threshold=0.4,
device=DEVICE
)
return True
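# For reference: OperatorBasedTracker is configured above with iou_threshold=0.2,
# i.e. detections are associated to existing track slots by bounding-box overlap.
# The real matching logic lives in tracker.py; this is only a minimal sketch of
# the IoU computation such a matcher relies on (illustrative, not called below).
def _bbox_iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0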
def draw_tracking_results(frame, slots, trajectories, frame_count):
"""Draw bounding boxes, IDs, and trajectories on frame"""
for slot in slots:
if slot.bbox is None:
continue
x1, y1, x2, y2 = slot.bbox.astype(int)
track_id = slot.track_id
class_name = slot.class_name
# Update trajectory
center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
if track_id not in trajectories:
trajectories[track_id] = deque(maxlen=30)
trajectories[track_id].append(center)
# Get colors
bbox_color = COLORS.get(class_name, (255, 255, 255))
op_color = OPERATOR_COLORS.get(slot.operator_id, (128, 128, 128))
# Draw bbox
cv2.rectangle(frame, (x1, y1), (x2, y2), bbox_color, 2)
# Draw operator indicator
cv2.circle(frame, (x2 - 10, y1 + 10), 8, op_color, -1)
cv2.circle(frame, (x2 - 10, y1 + 10), 8, (0, 0, 0), 1)
        # Draw label (clamped so it stays visible for boxes at the top edge)
        label = f"ID:{track_id} {class_name}"
        (lw, lh), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        ty = max(y1, lh + 8)
        cv2.rectangle(frame, (x1, ty - lh - 8), (x1 + lw + 4, ty), bbox_color, -1)
        cv2.putText(frame, label, (x1 + 2, ty - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
# Draw trajectory
traj = list(trajectories[track_id])
for i in range(1, len(traj)):
alpha = i / len(traj)
thickness = max(1, int(alpha * 3))
color = tuple(int(c * alpha) for c in bbox_color)
cv2.line(frame, traj[i-1], traj[i], color, thickness)
# Draw frame counter
cv2.putText(frame, f"Frame: {frame_count}", (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
return frame, trajectories
def process_video_live(video_path, confidence_threshold, progress=gr.Progress()):
"""Process video with live inference"""
global YOLO_MODEL, TRACKER
    if video_path is None:
        return None, "Error: No video provided"
    if YOLO_MODEL is None or TRACKER is None:
        return None, "Error: Models not loaded"
from tracker import Detection
# Reset tracker
TRACKER.reset()
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, f"Error: Could not open video {video_path}"
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # some containers report 0 fps
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Output video
output_path = "/tmp/output_tracked.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
trajectories = {}
frame_count = 0
total_detections = 0
unique_tracks = set()
while True:
ret, frame = cap.read()
if not ret:
break
# YOLO detection
results = YOLO_MODEL.predict(frame, conf=confidence_threshold, verbose=False)
detections = []
if len(results) > 0 and results[0].boxes is not None:
boxes = results[0].boxes
for i in range(len(boxes)):
class_id = int(boxes.cls[i])
detections.append(Detection(
bbox=boxes.xyxy[i].cpu().numpy(),
class_id=class_id,
class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
confidence=float(boxes.conf[i]),
frame_id=frame_count
))
total_detections += len(detections)
# Update tracker
slots = TRACKER.update(frame, detections)
for slot in slots:
unique_tracks.add(slot.track_id)
# Draw results
frame, trajectories = draw_tracking_results(frame, slots, trajectories, frame_count)
writer.write(frame)
frame_count += 1
        if total_frames > 0:
            progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")
cap.release()
writer.release()
    # Stats
    if frame_count == 0:
        return None, "Error: No frames could be read from the video"
    stats = f"""
**Processing Complete**
- Total frames: {frame_count}
- Total detections: {total_detections}
- Unique tracks: {len(unique_tracks)}
- Average detections/frame: {total_detections/frame_count:.2f}
- Device: {DEVICE}
"""
return output_path, stats
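# Note: OpenCV's 'mp4v' stream does not play back in every browser. A common
# workaround (an optional helper sketched here, not wired into the app) is to
# re-encode the result to H.264 with ffmpeg when it is available on the host.
def reencode_for_browser(src, dst="/tmp/output_tracked_h264.mp4"):
    """Best-effort H.264 re-encode so the Gradio video player can show the file."""
    import shutil
    import subprocess
    if shutil.which("ffmpeg") is None:
        return src  # ffmpeg missing: fall back to the original file
    subprocess.run(
        ["ffmpeg", "-y", "-i", src, "-c:v", "libx264", "-pix_fmt", "yuv420p", dst],
        check=False, capture_output=True,
    )
    return dst if os.path.exists(dst) else src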
# Display name -> pre-computed video path, shared by the dropdown and the
# playback handler so the two cannot drift out of sync.
DEMO_VIDEOS = {
    "Demo 1 - Multi-tool tracking": "demos/demo1_tracked.mp4",
    "Demo 2 - Occlusion handling": "demos/demo2_tracked.mp4",
    "Demo 3 - Tool re-identification": "demos/demo3_tracked.mp4",
}
def show_precomputed_demo(demo_name):
    """Show a precomputed demo video"""
    video_path = DEMO_VIDEOS.get(demo_name)
    if video_path and os.path.exists(video_path):
        stats = f"""
**{demo_name}**
Pre-computed tracking results using:
- YOLOv11x for detection
- Direction Estimator for operator prediction
- Operator-based tracker for multi-tool tracking
*Results computed on GPU, displayed instantly.*
"""
        return video_path, stats
    else:
        return None, f"Demo video not found: {video_path}"
def get_available_demos():
    """List the demo display names whose videos exist on disk"""
    available = [name for name, path in DEMO_VIDEOS.items() if os.path.exists(path)]
    return available if available else list(DEMO_VIDEOS)
# Build Gradio interface
def create_interface():
with gr.Blocks(
title="SurgiTrack - Surgical Tool Tracking",
theme=gr.themes.Base(
primary_hue="purple",
secondary_hue="gray",
neutral_hue="gray",
).set(
body_background_fill="#0a0a0f",
body_background_fill_dark="#0a0a0f",
block_background_fill="#12121a",
block_background_fill_dark="#12121a",
block_border_color="#2a2a3a",
block_border_color_dark="#2a2a3a",
button_primary_background_fill="#a855f7",
button_primary_background_fill_hover="#9333ea",
),
css="""
.gradio-container { max-width: 1200px !important; }
.gr-button { font-weight: 500; }
footer { display: none !important; }
"""
) as demo:
gr.Markdown("""
# πŸ”¬ SurgiTrack - Surgical Tool Tracking
Multi-class multi-tool tracking in laparoscopic surgery videos.
Based on the SurgiTrack approach and trained on the [CholecTrack20](https://arxiv.org/abs/2312.07352) dataset.
**Pipeline:** YOLOv11x Detection β†’ Direction Estimation β†’ Operator-based Tracking
---
""")
with gr.Tabs():
# Tab 1: Pre-computed demos (instant)
with gr.TabItem("πŸ“½οΈ Demo Videos (Instant)"):
gr.Markdown("""
### Pre-computed Results
Watch tracking results instantly. These videos were processed on GPU with the full pipeline.
""")
                available_demos = get_available_demos()
                with gr.Row():
                    demo_dropdown = gr.Dropdown(
                        choices=available_demos,
                        label="Select Demo",
                        value=available_demos[0] if available_demos else None
                    )
demo_btn = gr.Button("▢️ Show Demo", variant="primary")
with gr.Row():
demo_video = gr.Video(label="Tracking Result")
demo_stats = gr.Markdown(label="Statistics")
demo_btn.click(
fn=show_precomputed_demo,
inputs=[demo_dropdown],
outputs=[demo_video, demo_stats]
)
# Tab 2: Live inference (slower but real)
with gr.TabItem("πŸ”„ Live Inference (CPU)"):
gr.Markdown("""
### Live Processing
Upload a short video clip (5-15 seconds recommended) for live tracking.
⚠️ **Note:** Running on CPU - processing may take a few minutes.
""")
with gr.Row():
with gr.Column():
input_video = gr.Video(label="Upload Video")
confidence_slider = gr.Slider(
minimum=0.1, maximum=0.9, value=0.25, step=0.05,
label="Detection Confidence Threshold"
)
process_btn = gr.Button("πŸš€ Run Tracking", variant="primary")
with gr.Column():
output_video = gr.Video(label="Tracked Video")
output_stats = gr.Markdown(label="Statistics")
process_btn.click(
fn=process_video_live,
inputs=[input_video, confidence_slider],
outputs=[output_video, output_stats]
)
gr.Markdown("""
---
### πŸ“Š Method Overview
| Component | Description |
|-----------|-------------|
| **Detection** | YOLOv11x trained on CholecTrack20 (7 tool classes) |
| **Direction Estimator** | EfficientNet-B0 + Coordinate Attention β†’ Operator prediction |
| **Tracker** | Operator-based slots for graspers, fixed IDs for other tools |
### πŸ“ˆ Results on CholecTrack20 Test Set
| Metric | Score |
|--------|-------|
| **HOTA** | 64.48% |
| **AssA** | 71.19% |
| **DetA** | 58.51% |
---
**Dataset:** [CholecTrack20](https://arxiv.org/abs/2312.07352) (Nwoye et al., CVPR 2025)
**Author:** [Djalil Khelladi](https://github.com/akhellad)
""")
return demo
if __name__ == "__main__":
print(f"Starting SurgiTrack Demo on {DEVICE}...")
# Try to load models
models_loaded = load_models()
if not models_loaded:
print("Warning: Models not loaded. Only pre-computed demos will work.")
# Create and launch interface
demo = create_interface()
demo.launch()
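    # For long CPU runs, Gradio's request queue avoids client timeouts; a
    # suggested variant (not the original behavior) would be:
    #     demo.queue().launch()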