Initial commit
- README.md +52 -5
- app.py +375 -0
- requirements.txt +7 -0
- tracker.py +379 -0
README.md
CHANGED
@@ -1,12 +1,59 @@
---
title: SurgiTrack - Surgical Tool Tracking
emoji: 🔬
colorFrom: purple
colorTo: indigo
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
---

# SurgiTrack - Surgical Tool Tracking

Multi-class multi-tool tracking system for laparoscopic surgery videos.

## Overview

This demo implements the tracking pipeline from ["SurgiTrack: Fine-Grained Multi-Class Multi-Tool Tracking in Surgical Videos"](https://arxiv.org/abs/2312.07352), trained and evaluated on the CholecTrack20 dataset.

## Pipeline

1. **Detection**: YOLOv11x trained on 7 surgical tool classes
2. **Direction Estimation**: EfficientNet-B0 + Coordinate Attention predicts the operator (MSLH, MSRH, ASRH)
3. **Tracking**: Operator-based slot assignment for graspers, fixed IDs for other tools

## Results

| Metric | Score |
|--------|-------|
| HOTA | 64.48% |
| AssA | 71.19% |
| DetA | 58.51% |

## Tool Classes

- Grasper (tracked by operator)
- Bipolar
- Hook
- Scissors
- Clipper
- Irrigator
- Specimen Bag

## Citation

```bibtex
@InProceedings{nwoye2023cholectrack20,
  author    = {Nwoye, Chinedu Innocent and Elgohary, Kareem and Srinivas, Anvita and Zaid, Fauzan and Lavanchy, Joël L. and Padoy, Nicolas},
  title     = {CholecTrack20: A Multi-Perspective Tracking Dataset for Surgical Tools},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2025},
  month     = {June}
}
```

## Author

[Djalil Khelladi](https://github.com/akhellad)
app.py
ADDED
@@ -0,0 +1,375 @@
"""
SurgiTrack Demo - Surgical Tool Tracking
Based on CholecTrack20 dataset (Nwoye et al., CVPR 2025)
"""

import os
import gradio as gr
import cv2
import numpy as np
import torch
from pathlib import Path
from collections import deque

# Import models (will be loaded on startup)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
YOLO_MODEL = None
DIRECTION_MODEL = None
TRACKER = None

CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']

COLORS = {
    'grasper': (255, 100, 100),
    'bipolar': (100, 255, 100),
    'hook': (100, 100, 255),
    'scissors': (255, 255, 100),
    'clipper': (255, 100, 255),
    'irrigator': (100, 255, 255),
    'specimenbag': (200, 200, 200),
}

OPERATOR_COLORS = {
    0: (0, 255, 0),      # MSLH - Green
    1: (0, 0, 255),      # MSRH - Red
    2: (255, 165, 0),    # ASRH - Orange
    3: (128, 128, 128)   # NULL - Gray
}


def load_models():
    """Load YOLO and Direction Estimator models"""
    global YOLO_MODEL, DIRECTION_MODEL, TRACKER

    from ultralytics import YOLO
    from tracker import DirectionEstimator, OperatorBasedTracker

    # Load YOLO
    yolo_path = "weights/best.pt"
    if os.path.exists(yolo_path):
        YOLO_MODEL = YOLO(yolo_path)
        print(f"YOLO model loaded from {yolo_path}")
    else:
        print(f"Warning: YOLO model not found at {yolo_path}")
        return False

    # Load Direction Estimator
    direction_path = "weights/direction_estimator.pth"
    if os.path.exists(direction_path):
        DIRECTION_MODEL = DirectionEstimator(num_classes=4, pretrained=False)
        checkpoint = torch.load(direction_path, map_location=DEVICE, weights_only=False)
        DIRECTION_MODEL.load_state_dict(checkpoint['model_state_dict'])
        DIRECTION_MODEL.to(DEVICE)
        DIRECTION_MODEL.eval()
        print(f"Direction model loaded from {direction_path}")
    else:
        print(f"Warning: Direction model not found at {direction_path}")
        DIRECTION_MODEL = None

    # Initialize tracker
    TRACKER = OperatorBasedTracker(
        direction_model=DIRECTION_MODEL,
        max_inactive_frames=150,
        iou_threshold=0.2,
        direction_confidence_threshold=0.4,
        device=DEVICE
    )

    return True


def draw_tracking_results(frame, slots, trajectories, frame_count):
    """Draw bounding boxes, IDs, and trajectories on frame"""
    for slot in slots:
        if slot.bbox is None:
            continue

        x1, y1, x2, y2 = slot.bbox.astype(int)
        track_id = slot.track_id
        class_name = slot.class_name

        # Update trajectory
        center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
        if track_id not in trajectories:
            trajectories[track_id] = deque(maxlen=30)
        trajectories[track_id].append(center)

        # Get colors
        bbox_color = COLORS.get(class_name, (255, 255, 255))
        op_color = OPERATOR_COLORS.get(slot.operator_id, (128, 128, 128))

        # Draw bbox
        cv2.rectangle(frame, (x1, y1), (x2, y2), bbox_color, 2)

        # Draw operator indicator
        cv2.circle(frame, (x2 - 10, y1 + 10), 8, op_color, -1)
        cv2.circle(frame, (x2 - 10, y1 + 10), 8, (0, 0, 0), 1)

        # Draw label
        label = f"ID:{track_id} {class_name}"
        (lw, lh), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(frame, (x1, y1 - lh - 8), (x1 + lw + 4, y1), bbox_color, -1)
        cv2.putText(frame, label, (x1 + 2, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        # Draw trajectory
        traj = list(trajectories[track_id])
        for i in range(1, len(traj)):
            alpha = i / len(traj)
            thickness = max(1, int(alpha * 3))
            color = tuple(int(c * alpha) for c in bbox_color)
            cv2.line(frame, traj[i-1], traj[i], color, thickness)

    # Draw frame counter
    cv2.putText(frame, f"Frame: {frame_count}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    return frame, trajectories


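# Live-inference path: decode the upload frame by frame, run YOLO, wrap each box as a
# tracker.Detection, pass the batch to OperatorBasedTracker.update(), then overlay the
# returned slots and write the annotated frame to /tmp/output_tracked.mp4.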
def process_video_live(video_path, confidence_threshold, progress=gr.Progress()):
    """Process video with live inference"""
    global YOLO_MODEL, TRACKER

    if YOLO_MODEL is None:
        return None, "Error: Models not loaded"

    from tracker import Detection

    # Reset tracker
    TRACKER.reset()

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Output video
    output_path = "/tmp/output_tracked.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    trajectories = {}
    frame_count = 0
    total_detections = 0
    unique_tracks = set()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # YOLO detection
        results = YOLO_MODEL.predict(frame, conf=confidence_threshold, verbose=False)

        detections = []
        if len(results) > 0 and results[0].boxes is not None:
            boxes = results[0].boxes
            for i in range(len(boxes)):
                class_id = int(boxes.cls[i])
                detections.append(Detection(
                    bbox=boxes.xyxy[i].cpu().numpy(),
                    class_id=class_id,
                    class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
                    confidence=float(boxes.conf[i]),
                    frame_id=frame_count
                ))

        total_detections += len(detections)

        # Update tracker
        slots = TRACKER.update(frame, detections)

        for slot in slots:
            unique_tracks.add(slot.track_id)

        # Draw results
        frame, trajectories = draw_tracking_results(frame, slots, trajectories, frame_count)

        writer.write(frame)
        frame_count += 1

        progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")

    cap.release()
    writer.release()

    # Stats
    stats = f"""
**Processing Complete**
- Total frames: {frame_count}
- Total detections: {total_detections}
- Unique tracks: {len(unique_tracks)}
- Average detections/frame: {total_detections/frame_count:.2f}
- Device: {DEVICE}
"""

    return output_path, stats


def show_precomputed_demo(demo_name):
    """Show a precomputed demo video"""
    demo_videos = {
        "Demo 1 - Multi-tool tracking": "demos/demo1_tracked.mp4",
        "Demo 2 - Occlusion handling": "demos/demo2_tracked.mp4",
        "Demo 3 - Tool re-identification": "demos/demo3_tracked.mp4",
    }

    video_path = demo_videos.get(demo_name)

    if video_path and os.path.exists(video_path):
        # Get stats from companion json if exists
        stats = f"""
**{demo_name}**

Pre-computed tracking results using:
- YOLOv11x for detection
- Direction Estimator for operator prediction
- Operator-based tracker for multi-tool tracking

*Results computed on GPU, displayed instantly.*
"""
        return video_path, stats
    else:
        return None, f"Demo video not found: {video_path}"


def get_available_demos():
    """Get list of available demo videos"""
    demos_dir = Path("demos")
    if demos_dir.exists():
        return [f.stem.replace("_tracked", "") for f in demos_dir.glob("*_tracked.mp4")]
    return ["Demo 1 - Multi-tool tracking", "Demo 2 - Occlusion handling", "Demo 3 - Tool re-identification"]


# Build Gradio interface
def create_interface():
    with gr.Blocks(
        title="SurgiTrack - Surgical Tool Tracking",
        theme=gr.themes.Base(
            primary_hue="purple",
            secondary_hue="gray",
            neutral_hue="gray",
        ).set(
            body_background_fill="#0a0a0f",
            body_background_fill_dark="#0a0a0f",
            block_background_fill="#12121a",
            block_background_fill_dark="#12121a",
            block_border_color="#2a2a3a",
            block_border_color_dark="#2a2a3a",
            button_primary_background_fill="#a855f7",
            button_primary_background_fill_hover="#9333ea",
        ),
        css="""
.gradio-container { max-width: 1200px !important; }
.gr-button { font-weight: 500; }
footer { display: none !important; }
"""
    ) as demo:

        gr.Markdown("""
# 🔬 SurgiTrack - Surgical Tool Tracking

Multi-class multi-tool tracking in laparoscopic surgery videos.
Based on the [SurgiTrack paper](https://arxiv.org/abs/2312.07352) and trained on CholecTrack20 dataset.

**Pipeline:** YOLOv11x Detection → Direction Estimation → Operator-based Tracking

---
""")

        with gr.Tabs():
            # Tab 1: Pre-computed demos (instant)
            with gr.TabItem("📽️ Demo Videos (Instant)"):
                gr.Markdown("""
### Pre-computed Results
Watch tracking results instantly. These videos were processed on GPU with full pipeline.
""")

                with gr.Row():
                    demo_dropdown = gr.Dropdown(
                        choices=get_available_demos(),
                        label="Select Demo",
                        value=get_available_demos()[0] if get_available_demos() else None
                    )
                    demo_btn = gr.Button("▶️ Show Demo", variant="primary")

                with gr.Row():
                    demo_video = gr.Video(label="Tracking Result")
                    demo_stats = gr.Markdown(label="Statistics")

                demo_btn.click(
                    fn=show_precomputed_demo,
                    inputs=[demo_dropdown],
                    outputs=[demo_video, demo_stats]
                )

            # Tab 2: Live inference (slower but real)
            with gr.TabItem("🔄 Live Inference (CPU)"):
                gr.Markdown("""
### Real-time Processing
Upload a short video clip (5-15 seconds recommended) for live tracking.

⚠️ **Note:** Running on CPU - processing may take a few minutes.
""")

                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Upload Video")
                        confidence_slider = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.25, step=0.05,
                            label="Detection Confidence Threshold"
                        )
                        process_btn = gr.Button("🚀 Run Tracking", variant="primary")

                    with gr.Column():
                        output_video = gr.Video(label="Tracked Video")
                        output_stats = gr.Markdown(label="Statistics")

                process_btn.click(
                    fn=process_video_live,
                    inputs=[input_video, confidence_slider],
                    outputs=[output_video, output_stats]
                )

        gr.Markdown("""
---

### 📊 Method Overview

| Component | Description |
|-----------|-------------|
| **Detection** | YOLOv11x trained on CholecTrack20 (7 tool classes) |
| **Direction Estimator** | EfficientNet-B0 + Coordinate Attention → Operator prediction |
| **Tracker** | Operator-based slots for graspers, fixed IDs for other tools |

### 📈 Results on CholecTrack20 Test Set

| Metric | Score |
|--------|-------|
| **HOTA** | 64.48% |
| **AssA** | 71.19% |
| **DetA** | 58.51% |

---

**Dataset:** [CholecTrack20](https://arxiv.org/abs/2312.07352) (Nwoye et al., CVPR 2025)

**Author:** [Djalil Khelladi](https://github.com/akhellad)
""")

    return demo


if __name__ == "__main__":
    print(f"Starting SurgiTrack Demo on {DEVICE}...")

    # Try to load models
    models_loaded = load_models()

    if not models_loaded:
        print("Warning: Models not loaded. Only pre-computed demos will work.")

    # Create and launch interface
    demo = create_interface()
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio>=4.0.0
torch>=2.0.0
torchvision>=0.15.0
ultralytics>=8.0.0
opencv-python-headless>=4.8.0
numpy>=1.24.0
scipy>=1.10.0
tracker.py
ADDED
@@ -0,0 +1,379 @@
"""
SurgiTrack - Tracker Module (Simplified for HF Space)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
from scipy.optimize import linear_sum_assignment
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import cv2


CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
OPERATORS = ['MSLH', 'MSRH', 'ASRH', 'NULL']


class CoordinateAttention(nn.Module):
    def __init__(self, in_channels, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mid_channels = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.act = nn.ReLU(inplace=True)

        self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape

        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.act(self.bn1(self.conv1(y)))

        x_h, x_w = torch.split(y, [H, W], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        return x * a_h * a_w


class DirectionEstimator(nn.Module):
    def __init__(self, num_classes=4, embedding_dim=128, pretrained=True):
        super().__init__()

        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        backbone_out = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.coord_attention = CoordinateAttention(backbone_out)

        self.embedding_head = nn.Sequential(
            nn.Linear(backbone_out, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim)
        )

        self.direction_head = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

        self.embedding_dim = embedding_dim

    def forward(self, x, return_embedding=False):
        features = self.backbone.features(x)
        features = self.coord_attention(features)
        features = self.backbone.avgpool(features)
        features = features.flatten(1)

        embedding = self.embedding_head(features)
        embedding = F.normalize(embedding, p=2, dim=1)

        direction = self.direction_head(embedding)

        if return_embedding:
            return direction, embedding
        return direction


@dataclass
class Detection:
    bbox: np.ndarray
    class_id: int
    class_name: str
    confidence: float
    frame_id: int


@dataclass
class OperatorSlot:
    operator_id: int
    operator_name: str
    track_id: int

    active: bool = False
    class_id: int = -1
    class_name: str = ""
    bbox: np.ndarray = None
    confidence: float = 0.0
    embedding: np.ndarray = None

    last_seen_frame: int = -1
    total_detections: int = 0
    bbox_history: List[np.ndarray] = field(default_factory=list)
    class_history: List[int] = field(default_factory=list)

    def update(self, detection: Detection, embedding: np.ndarray, frame_id: int):
        self.active = True
        self.bbox = detection.bbox
        self.class_id = detection.class_id
        self.class_name = detection.class_name
        self.confidence = detection.confidence
        self.embedding = embedding
        self.last_seen_frame = frame_id
        self.total_detections += 1

        self.bbox_history.append(detection.bbox.copy())
        self.class_history.append(detection.class_id)

        if len(self.bbox_history) > 100:
            self.bbox_history.pop(0)
            self.class_history.pop(0)

    def mark_inactive(self):
        self.active = False

    def frames_since_seen(self, current_frame: int) -> int:
        if self.last_seen_frame < 0:
            return float('inf')
        return current_frame - self.last_seen_frame


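# Slot-based tracking strategy (see the methods below): up to three persistent
# "grasper" slots keyed by the predicted operator, plus one fixed slot per
# single-instance tool class. A slot that stays unmatched long enough is
# reissued a fresh track_id, which is how re-identification gaps are handled.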
class OperatorBasedTracker:
    MAX_GRASPERS = 3
    GRASPER_CLASS_ID = 0
    SINGLE_INSTANCE_CLASSES = {1, 2, 3, 4, 5, 6}

    def __init__(
        self,
        direction_model: DirectionEstimator = None,
        max_inactive_frames: int = 300,
        iou_threshold: float = 0.3,
        direction_confidence_threshold: float = 0.5,
        device: str = "cuda"
    ):
        self.direction_model = direction_model
        self.max_inactive_frames = max_inactive_frames
        self.iou_threshold = iou_threshold
        self.direction_confidence_threshold = direction_confidence_threshold
        self.device = device

        self.grasper_slots: List[OperatorSlot] = []
        self.class_slots: Dict[int, OperatorSlot] = {}

        self.next_track_id = 1
        self.frame_count = 0

        self._initialize_slots()

        if self.direction_model is not None:
            self.direction_model.to(device)
            self.direction_model.eval()

    def _initialize_slots(self):
        for i in range(self.MAX_GRASPERS):
            slot = OperatorSlot(
                operator_id=-1,
                operator_name=f"grasper_{i+1}",
                track_id=self.next_track_id
            )
            slot.class_id = self.GRASPER_CLASS_ID
            slot.class_name = 'grasper'
            self.next_track_id += 1
            self.grasper_slots.append(slot)

        for class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = OperatorSlot(
                operator_id=3,
                operator_name=f"CLASS_{CLASS_NAMES[class_id]}",
                track_id=self.next_track_id
            )
            slot.class_id = class_id
            slot.class_name = CLASS_NAMES[class_id]
            self.next_track_id += 1
            self.class_slots[class_id] = slot

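    # The direction estimator sees a padded crop around each detection (30% extra
    # width, 50% extra height), resized to 224x224 and normalized with the standard
    # ImageNet mean/std before a single forward pass.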
    def _get_direction_prediction(self, frame: np.ndarray, bbox: np.ndarray):
        if self.direction_model is None:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])

        x1, y1, x2, y2 = bbox.astype(int)
        h, w = frame.shape[:2]

        pad_x = int((x2 - x1) * 0.3)
        pad_y = int((y2 - y1) * 0.5)

        x1 = max(0, x1 - pad_x)
        y1 = max(0, y1 - pad_y)
        x2 = min(w, x2 + pad_x)
        y2 = min(h, y2 + pad_y)

        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])

        crop = cv2.resize(crop, (224, 224))
        crop = crop.astype(np.float32) / 255.0
        crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        crop = torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)

        with torch.no_grad():
            logits, embedding = self.direction_model(crop, return_embedding=True)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]

        return np.argmax(probs), probs

    def _compute_iou(self, bbox1: np.ndarray, bbox2: np.ndarray) -> float:
        if bbox1 is None or bbox2 is None:
            return 0.0

        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        inter = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - inter

        return inter / (union + 1e-6)

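    # Matching cascade for a new detection: (1) single-instance classes map to their
    # fixed slot; (2) graspers first try IoU/centre-distance overlap with a recently
    # seen slot, then an inactive slot with the same predicted operator, then any
    # recently seen inactive slot nearby, then any free slot, and finally take over
    # the slot with the least overlap.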
    def _find_best_slot(self, detection: Detection, predicted_op: int, direction_probs: np.ndarray) -> Optional[OperatorSlot]:
        class_id = detection.class_id

        if class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = self.class_slots.get(class_id)
            if slot:
                recency = slot.frames_since_seen(self.frame_count)
                if not slot.active and recency >= 75:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                return slot

        if class_id == self.GRASPER_CLASS_ID:
            direction_confident = predicted_op < 3 and direction_probs[predicted_op] > self.direction_confidence_threshold

            best_slot = None
            best_score = -1
            for slot in self.grasper_slots:
                if slot.bbox is None:
                    continue

                recency = slot.frames_since_seen(self.frame_count)
                if recency >= 75:
                    continue

                iou = self._compute_iou(detection.bbox, slot.bbox)

                det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                dist = np.linalg.norm(det_center - slot_center)

                if iou > self.iou_threshold:
                    score = iou + (0.2 if slot.operator_id == predicted_op else 0)
                elif dist < 150 and recency < 30:
                    score = 0.1 + (0.2 if slot.operator_id == predicted_op else 0)
                else:
                    continue

                if score > best_score:
                    best_score = score
                    best_slot = slot

            if best_slot:
                return best_slot

            if direction_confident:
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if slot.operator_id == predicted_op and slot.frames_since_seen(self.frame_count) < 75:
                        return slot

            if not direction_confident:
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if slot.frames_since_seen(self.frame_count) < 30:
                        det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                        slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                        dist = np.linalg.norm(det_center - slot_center)
                        if dist < 100:
                            return slot

            for slot in self.grasper_slots:
                if not slot.active:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                    return slot

            worst_slot = None
            worst_iou = 1.0
            for slot in self.grasper_slots:
                iou = self._compute_iou(detection.bbox, slot.bbox)
                if iou < worst_iou:
                    worst_iou = iou
                    worst_slot = slot

            if worst_slot:
                worst_slot.track_id = self.next_track_id
                self.next_track_id += 1
                return worst_slot

        return None

    def update(self, frame: np.ndarray, detections: List[Detection]) -> List[OperatorSlot]:
        self.frame_count += 1

        all_slots = self.grasper_slots + list(self.class_slots.values())
        for slot in all_slots:
            if slot.active and slot.frames_since_seen(self.frame_count) > 150:
                slot.mark_inactive()

        if len(detections) == 0:
            return self._get_active_slots()

        detection_info = []
        for det in detections:
            pred_op, probs = self._get_direction_prediction(frame, det.bbox)
            detection_info.append((det, pred_op, probs))

        detection_info.sort(key=lambda x: -x[0].confidence)

        assigned_slots = set()

        for det, pred_op, probs in detection_info:
            slot = self._find_best_slot(det, pred_op, probs)

            if slot and slot.track_id not in assigned_slots:
                slot.update(det, probs, self.frame_count)
                if det.class_id == self.GRASPER_CLASS_ID:
                    slot.operator_id = pred_op
                assigned_slots.add(slot.track_id)

        return self._get_active_slots()

    def _get_active_slots(self) -> List[OperatorSlot]:
        active = []
        for slot in self.grasper_slots:
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        for slot in self.class_slots.values():
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        return active

    def reset(self):
        self.grasper_slots = []
        self.class_slots = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()