akhellad committed
Commit 26a3529 · 1 Parent(s): 9e5c05e

Initial commit

Files changed (4)
  1. README.md +52 -5
  2. app.py +375 -0
  3. requirements.txt +7 -0
  4. tracker.py +379 -0
README.md CHANGED
@@ -1,12 +1,59 @@
  ---
- title: SurgiTrackDemo
- emoji: 📊
- colorFrom: indigo
+ title: SurgiTrack - Surgical Tool Tracking
+ emoji: 🔬
+ colorFrom: purple
  colorTo: indigo
  sdk: gradio
- sdk_version: 6.1.0
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SurgiTrack - Surgical Tool Tracking
+
+ Multi-class multi-tool tracking system for laparoscopic surgery videos.
+
+ ## Overview
+
+ This demo implements the tracking pipeline from ["SurgiTrack: Fine-Grained Multi-Class Multi-Tool Tracking in Surgical Videos"](https://arxiv.org/abs/2312.07352), trained and evaluated on the CholecTrack20 dataset.
+
+ ## Pipeline
+
+ 1. **Detection**: YOLOv11x trained on 7 surgical tool classes
+ 2. **Direction Estimation**: EfficientNet-B0 + Coordinate Attention predicts the operator (MSLH, MSRH, ASRH, or NULL)
+ 3. **Tracking**: Operator-based slot assignment for graspers, fixed IDs for other tools
+
+ ## Results
+
+ | Metric | Score |
+ |--------|-------|
+ | HOTA | 64.48% |
+ | AssA | 71.19% |
+ | DetA | 58.51% |
+
+ ## Tool Classes
+
+ - Grasper (tracked by operator)
+ - Bipolar
+ - Hook
+ - Scissors
+ - Clipper
+ - Irrigator
+ - Specimen Bag
+
+ ## Citation
+
+ ```bibtex
+ @InProceedings{nwoye2023cholectrack20,
+     author    = {Nwoye, Chinedu Innocent and Elgohary, Kareem and Srinivas, Anvita and Zaid, Fauzan and Lavanchy, Joël L. and Padoy, Nicolas},
+     title     = {CholecTrack20: A Multi-Perspective Tracking Dataset for Surgical Tools},
+     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+     year      = {2025},
+     month     = {June}
+ }
+ ```
+
+ ## Author
+
+ [Djalil Khelladi](https://github.com/akhellad)
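
The three Pipeline stages above map directly onto the classes added in this commit. As a rough per-frame sketch (illustrative only, not part of the committed files; `frame.png` is a placeholder input, and passing `direction_model=None` makes the tracker fall back to the NULL operator):

```python
# Hedged sketch of the README pipeline using the interfaces from app.py / tracker.py below.
import cv2
from ultralytics import YOLO
from tracker import Detection, OperatorBasedTracker

CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']

detector = YOLO("weights/best.pt")                     # step 1: detection (7 tool classes)
tracker = OperatorBasedTracker(direction_model=None,   # step 2 skipped -> NULL operator fallback
                               device="cpu")           # step 3: operator-based slot tracking

frame = cv2.imread("frame.png")                        # placeholder frame, not a committed asset
boxes = detector.predict(frame, conf=0.25, verbose=False)[0].boxes
detections = [
    Detection(bbox=boxes.xyxy[i].cpu().numpy(),
              class_id=int(boxes.cls[i]),
              class_name=CLASS_NAMES[int(boxes.cls[i])],
              confidence=float(boxes.conf[i]),
              frame_id=0)
    for i in range(len(boxes))
]
for slot in tracker.update(frame, detections):
    print(slot.track_id, slot.class_name, slot.operator_id)  # stable ID per tool / operator slot
```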
app.py ADDED
@@ -0,0 +1,375 @@
+ """
+ SurgiTrack Demo - Surgical Tool Tracking
+ Based on CholecTrack20 dataset (Nwoye et al., CVPR 2025)
+ """
+
+ import os
+ import gradio as gr
+ import cv2
+ import numpy as np
+ import torch
+ from pathlib import Path
+ from collections import deque
+
+ # Import models (will be loaded on startup)
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ YOLO_MODEL = None
+ DIRECTION_MODEL = None
+ TRACKER = None
+
+ CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
+
+ COLORS = {
+     'grasper': (255, 100, 100),
+     'bipolar': (100, 255, 100),
+     'hook': (100, 100, 255),
+     'scissors': (255, 255, 100),
+     'clipper': (255, 100, 255),
+     'irrigator': (100, 255, 255),
+     'specimenbag': (200, 200, 200),
+ }
+
+ OPERATOR_COLORS = {
+     0: (0, 255, 0),      # MSLH - Green
+     1: (0, 0, 255),      # MSRH - Red
+     2: (255, 165, 0),    # ASRH - Orange
+     3: (128, 128, 128)   # NULL - Gray
+ }
+
+
+ def load_models():
+     """Load YOLO and Direction Estimator models"""
+     global YOLO_MODEL, DIRECTION_MODEL, TRACKER
+
+     from ultralytics import YOLO
+     from tracker import DirectionEstimator, OperatorBasedTracker
+
+     # Load YOLO
+     yolo_path = "weights/best.pt"
+     if os.path.exists(yolo_path):
+         YOLO_MODEL = YOLO(yolo_path)
+         print(f"YOLO model loaded from {yolo_path}")
+     else:
+         print(f"Warning: YOLO model not found at {yolo_path}")
+         return False
+
+     # Load Direction Estimator
+     direction_path = "weights/direction_estimator.pth"
+     if os.path.exists(direction_path):
+         DIRECTION_MODEL = DirectionEstimator(num_classes=4, pretrained=False)
+         checkpoint = torch.load(direction_path, map_location=DEVICE, weights_only=False)
+         DIRECTION_MODEL.load_state_dict(checkpoint['model_state_dict'])
+         DIRECTION_MODEL.to(DEVICE)
+         DIRECTION_MODEL.eval()
+         print(f"Direction model loaded from {direction_path}")
+     else:
+         print(f"Warning: Direction model not found at {direction_path}")
+         DIRECTION_MODEL = None
+
+     # Initialize tracker
+     TRACKER = OperatorBasedTracker(
+         direction_model=DIRECTION_MODEL,
+         max_inactive_frames=150,
+         iou_threshold=0.2,
+         direction_confidence_threshold=0.4,
+         device=DEVICE
+     )
+
+     return True
+
+
+ def draw_tracking_results(frame, slots, trajectories, frame_count):
+     """Draw bounding boxes, IDs, and trajectories on frame"""
+     for slot in slots:
+         if slot.bbox is None:
+             continue
+
+         x1, y1, x2, y2 = slot.bbox.astype(int)
+         track_id = slot.track_id
+         class_name = slot.class_name
+
+         # Update trajectory
+         center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
+         if track_id not in trajectories:
+             trajectories[track_id] = deque(maxlen=30)
+         trajectories[track_id].append(center)
+
+         # Get colors
+         bbox_color = COLORS.get(class_name, (255, 255, 255))
+         op_color = OPERATOR_COLORS.get(slot.operator_id, (128, 128, 128))
+
+         # Draw bbox
+         cv2.rectangle(frame, (x1, y1), (x2, y2), bbox_color, 2)
+
+         # Draw operator indicator
+         cv2.circle(frame, (x2 - 10, y1 + 10), 8, op_color, -1)
+         cv2.circle(frame, (x2 - 10, y1 + 10), 8, (0, 0, 0), 1)
+
+         # Draw label
+         label = f"ID:{track_id} {class_name}"
+         (lw, lh), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+         cv2.rectangle(frame, (x1, y1 - lh - 8), (x1 + lw + 4, y1), bbox_color, -1)
+         cv2.putText(frame, label, (x1 + 2, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+
+         # Draw trajectory
+         traj = list(trajectories[track_id])
+         for i in range(1, len(traj)):
+             alpha = i / len(traj)
+             thickness = max(1, int(alpha * 3))
+             color = tuple(int(c * alpha) for c in bbox_color)
+             cv2.line(frame, traj[i-1], traj[i], color, thickness)
+
+     # Draw frame counter
+     cv2.putText(frame, f"Frame: {frame_count}", (10, 30),
+                 cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
+     return frame, trajectories
+
+
+ def process_video_live(video_path, confidence_threshold, progress=gr.Progress()):
+     """Process video with live inference"""
+     global YOLO_MODEL, TRACKER
+
+     if YOLO_MODEL is None:
+         return None, "Error: Models not loaded"
+
+     from tracker import Detection
+
+     # Reset tracker
+     TRACKER.reset()
+
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     # Output video
+     output_path = "/tmp/output_tracked.mp4"
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     trajectories = {}
+     frame_count = 0
+     total_detections = 0
+     unique_tracks = set()
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # YOLO detection
+         results = YOLO_MODEL.predict(frame, conf=confidence_threshold, verbose=False)
+
+         detections = []
+         if len(results) > 0 and results[0].boxes is not None:
+             boxes = results[0].boxes
+             for i in range(len(boxes)):
+                 class_id = int(boxes.cls[i])
+                 detections.append(Detection(
+                     bbox=boxes.xyxy[i].cpu().numpy(),
+                     class_id=class_id,
+                     class_name=CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else "unknown",
+                     confidence=float(boxes.conf[i]),
+                     frame_id=frame_count
+                 ))
+
+         total_detections += len(detections)
+
+         # Update tracker
+         slots = TRACKER.update(frame, detections)
+
+         for slot in slots:
+             unique_tracks.add(slot.track_id)
+
+         # Draw results
+         frame, trajectories = draw_tracking_results(frame, slots, trajectories, frame_count)
+
+         writer.write(frame)
+         frame_count += 1
+
+         progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")
+
+     cap.release()
+     writer.release()
+
+     # Stats
+     stats = f"""
+ **Processing Complete**
+ - Total frames: {frame_count}
+ - Total detections: {total_detections}
+ - Unique tracks: {len(unique_tracks)}
+ - Average detections/frame: {total_detections/frame_count:.2f}
+ - Device: {DEVICE}
+ """
+
+     return output_path, stats
+
+
+ def show_precomputed_demo(demo_name):
+     """Show a precomputed demo video"""
+     demo_videos = {
+         "Demo 1 - Multi-tool tracking": "demos/demo1_tracked.mp4",
+         "Demo 2 - Occlusion handling": "demos/demo2_tracked.mp4",
+         "Demo 3 - Tool re-identification": "demos/demo3_tracked.mp4",
+     }
+
+     video_path = demo_videos.get(demo_name)
+
+     if video_path and os.path.exists(video_path):
+         # Get stats from companion json if exists
+         stats = f"""
+ **{demo_name}**
+
+ Pre-computed tracking results using:
+ - YOLOv11x for detection
+ - Direction Estimator for operator prediction
+ - Operator-based tracker for multi-tool tracking
+
+ *Results computed on GPU, displayed instantly.*
+ """
+         return video_path, stats
+     else:
+         return None, f"Demo video not found: {video_path}"
+
+
+ def get_available_demos():
+     """Get list of available demo videos"""
+     demos_dir = Path("demos")
+     if demos_dir.exists():
+         return [f.stem.replace("_tracked", "") for f in demos_dir.glob("*_tracked.mp4")]
+     return ["Demo 1 - Multi-tool tracking", "Demo 2 - Occlusion handling", "Demo 3 - Tool re-identification"]
+
+
+ # Build Gradio interface
+ def create_interface():
+     with gr.Blocks(
+         title="SurgiTrack - Surgical Tool Tracking",
+         theme=gr.themes.Base(
+             primary_hue="purple",
+             secondary_hue="gray",
+             neutral_hue="gray",
+         ).set(
+             body_background_fill="#0a0a0f",
+             body_background_fill_dark="#0a0a0f",
+             block_background_fill="#12121a",
+             block_background_fill_dark="#12121a",
+             block_border_color="#2a2a3a",
+             block_border_color_dark="#2a2a3a",
+             button_primary_background_fill="#a855f7",
+             button_primary_background_fill_hover="#9333ea",
+         ),
+         css="""
+         .gradio-container { max-width: 1200px !important; }
+         .gr-button { font-weight: 500; }
+         footer { display: none !important; }
+         """
+     ) as demo:
+
+         gr.Markdown("""
+         # 🔬 SurgiTrack - Surgical Tool Tracking
+
+         Multi-class multi-tool tracking in laparoscopic surgery videos.
+         Based on the [SurgiTrack paper](https://arxiv.org/abs/2312.07352) and trained on CholecTrack20 dataset.
+
+         **Pipeline:** YOLOv11x Detection → Direction Estimation → Operator-based Tracking
+
+         ---
+         """)
+
+         with gr.Tabs():
+             # Tab 1: Pre-computed demos (instant)
+             with gr.TabItem("📽️ Demo Videos (Instant)"):
+                 gr.Markdown("""
+                 ### Pre-computed Results
+                 Watch tracking results instantly. These videos were processed on GPU with full pipeline.
+                 """)
+
+                 with gr.Row():
+                     demo_dropdown = gr.Dropdown(
+                         choices=get_available_demos(),
+                         label="Select Demo",
+                         value=get_available_demos()[0] if get_available_demos() else None
+                     )
+                     demo_btn = gr.Button("▶️ Show Demo", variant="primary")
+
+                 with gr.Row():
+                     demo_video = gr.Video(label="Tracking Result")
+                     demo_stats = gr.Markdown(label="Statistics")
+
+                 demo_btn.click(
+                     fn=show_precomputed_demo,
+                     inputs=[demo_dropdown],
+                     outputs=[demo_video, demo_stats]
+                 )
+
+             # Tab 2: Live inference (slower but real)
+             with gr.TabItem("🔄 Live Inference (CPU)"):
+                 gr.Markdown("""
+                 ### Real-time Processing
+                 Upload a short video clip (5-15 seconds recommended) for live tracking.
+
+                 ⚠️ **Note:** Running on CPU - processing may take a few minutes.
+                 """)
+
+                 with gr.Row():
+                     with gr.Column():
+                         input_video = gr.Video(label="Upload Video")
+                         confidence_slider = gr.Slider(
+                             minimum=0.1, maximum=0.9, value=0.25, step=0.05,
+                             label="Detection Confidence Threshold"
+                         )
+                         process_btn = gr.Button("🚀 Run Tracking", variant="primary")
+
+                     with gr.Column():
+                         output_video = gr.Video(label="Tracked Video")
+                         output_stats = gr.Markdown(label="Statistics")
+
+                 process_btn.click(
+                     fn=process_video_live,
+                     inputs=[input_video, confidence_slider],
+                     outputs=[output_video, output_stats]
+                 )
+
+         gr.Markdown("""
+         ---
+
+         ### 📊 Method Overview
+
+         | Component | Description |
+         |-----------|-------------|
+         | **Detection** | YOLOv11x trained on CholecTrack20 (7 tool classes) |
+         | **Direction Estimator** | EfficientNet-B0 + Coordinate Attention → Operator prediction |
+         | **Tracker** | Operator-based slots for graspers, fixed IDs for other tools |
+
+         ### 📈 Results on CholecTrack20 Test Set
+
+         | Metric | Score |
+         |--------|-------|
+         | **HOTA** | 64.48% |
+         | **AssA** | 71.19% |
+         | **DetA** | 58.51% |
+
+         ---
+
+         **Dataset:** [CholecTrack20](https://arxiv.org/abs/2312.07352) (Nwoye et al., CVPR 2025)
+
+         **Author:** [Djalil Khelladi](https://github.com/akhellad)
+         """)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     print(f"Starting SurgiTrack Demo on {DEVICE}...")
+
+     # Try to load models
+     models_loaded = load_models()
+
+     if not models_loaded:
+         print("Warning: Models not loaded. Only pre-computed demos will work.")
+
+     # Create and launch interface
+     demo = create_interface()
+     demo.launch()
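
When the Space is run outside Hugging Face (for example locally), the same entry point applies. A hedged launch variant, assuming the `app` module above is importable and using standard Gradio Blocks options; the committed app.py itself simply calls `demo.launch()`:

```python
# Hypothetical local launch script; not part of this commit.
from app import load_models, create_interface

load_models()        # falls back to demo-only mode if weights/ is missing
demo = create_interface()
demo.queue().launch(server_name="0.0.0.0", server_port=7860)  # queue the slow CPU inference path
```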
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio>=4.0.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ ultralytics>=8.0.0
+ opencv-python-headless>=4.8.0
+ numpy>=1.24.0
+ scipy>=1.10.0
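
A quick, purely illustrative sanity check that the dependency set above resolves and that the device selection used in app.py behaves as expected (not a committed file):

```python
# Hypothetical smoke test for the requirements above.
import cv2, gradio, numpy, scipy, torch, torchvision, ultralytics

print("gradio", gradio.__version__)
print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("ultralytics", ultralytics.__version__, "| opencv", cv2.__version__)
```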
tracker.py ADDED
@@ -0,0 +1,379 @@
+ """
+ SurgiTrack - Tracker Module (Simplified for HF Space)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import models
+ import numpy as np
+ from scipy.optimize import linear_sum_assignment
+ from dataclasses import dataclass, field
+ from typing import List, Dict, Optional
+ import cv2
+
+
+ CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
+ OPERATORS = ['MSLH', 'MSRH', 'ASRH', 'NULL']
+
+
+ class CoordinateAttention(nn.Module):
+     def __init__(self, in_channels, reduction=32):
+         super().__init__()
+         self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
+         self.pool_w = nn.AdaptiveAvgPool2d((1, None))
+
+         mid_channels = max(8, in_channels // reduction)
+         self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
+         self.bn1 = nn.BatchNorm2d(mid_channels)
+         self.act = nn.ReLU(inplace=True)
+
+         self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
+         self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+
+         x_h = self.pool_h(x)
+         x_w = self.pool_w(x).permute(0, 1, 3, 2)
+
+         y = torch.cat([x_h, x_w], dim=2)
+         y = self.act(self.bn1(self.conv1(y)))
+
+         x_h, x_w = torch.split(y, [H, W], dim=2)
+         x_w = x_w.permute(0, 1, 3, 2)
+
+         a_h = self.conv_h(x_h).sigmoid()
+         a_w = self.conv_w(x_w).sigmoid()
+
+         return x * a_h * a_w
+
+
+ class DirectionEstimator(nn.Module):
+     def __init__(self, num_classes=4, embedding_dim=128, pretrained=True):
+         super().__init__()
+
+         self.backbone = models.efficientnet_b0(
+             weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
+         )
+         backbone_out = self.backbone.classifier[1].in_features
+         self.backbone.classifier = nn.Identity()
+
+         self.coord_attention = CoordinateAttention(backbone_out)
+
+         self.embedding_head = nn.Sequential(
+             nn.Linear(backbone_out, 512),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.3),
+             nn.Linear(512, embedding_dim)
+         )
+
+         self.direction_head = nn.Sequential(
+             nn.Linear(embedding_dim, 64),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.2),
+             nn.Linear(64, num_classes)
+         )
+
+         self.embedding_dim = embedding_dim
+
+     def forward(self, x, return_embedding=False):
+         features = self.backbone.features(x)
+         features = self.coord_attention(features)
+         features = self.backbone.avgpool(features)
+         features = features.flatten(1)
+
+         embedding = self.embedding_head(features)
+         embedding = F.normalize(embedding, p=2, dim=1)
+
+         direction = self.direction_head(embedding)
+
+         if return_embedding:
+             return direction, embedding
+         return direction
+
+
+ @dataclass
+ class Detection:
+     bbox: np.ndarray
+     class_id: int
+     class_name: str
+     confidence: float
+     frame_id: int
+
+
+ @dataclass
+ class OperatorSlot:
+     operator_id: int
+     operator_name: str
+     track_id: int
+
+     active: bool = False
+     class_id: int = -1
+     class_name: str = ""
+     bbox: np.ndarray = None
+     confidence: float = 0.0
+     embedding: np.ndarray = None
+
+     last_seen_frame: int = -1
+     total_detections: int = 0
+     bbox_history: List[np.ndarray] = field(default_factory=list)
+     class_history: List[int] = field(default_factory=list)
+
+     def update(self, detection: Detection, embedding: np.ndarray, frame_id: int):
+         self.active = True
+         self.bbox = detection.bbox
+         self.class_id = detection.class_id
+         self.class_name = detection.class_name
+         self.confidence = detection.confidence
+         self.embedding = embedding
+         self.last_seen_frame = frame_id
+         self.total_detections += 1
+
+         self.bbox_history.append(detection.bbox.copy())
+         self.class_history.append(detection.class_id)
+
+         if len(self.bbox_history) > 100:
+             self.bbox_history.pop(0)
+             self.class_history.pop(0)
+
+     def mark_inactive(self):
+         self.active = False
+
+     def frames_since_seen(self, current_frame: int) -> int:
+         if self.last_seen_frame < 0:
+             return float('inf')
+         return current_frame - self.last_seen_frame
+
+
+ class OperatorBasedTracker:
+     MAX_GRASPERS = 3
+     GRASPER_CLASS_ID = 0
+     SINGLE_INSTANCE_CLASSES = {1, 2, 3, 4, 5, 6}
+
+     def __init__(
+         self,
+         direction_model: DirectionEstimator = None,
+         max_inactive_frames: int = 300,
+         iou_threshold: float = 0.3,
+         direction_confidence_threshold: float = 0.5,
+         device: str = "cuda"
+     ):
+         self.direction_model = direction_model
+         self.max_inactive_frames = max_inactive_frames
+         self.iou_threshold = iou_threshold
+         self.direction_confidence_threshold = direction_confidence_threshold
+         self.device = device
+
+         self.grasper_slots: List[OperatorSlot] = []
+         self.class_slots: Dict[int, OperatorSlot] = {}
+
+         self.next_track_id = 1
+         self.frame_count = 0
+
+         self._initialize_slots()
+
+         if self.direction_model is not None:
+             self.direction_model.to(device)
+             self.direction_model.eval()
+
+     def _initialize_slots(self):
+         for i in range(self.MAX_GRASPERS):
+             slot = OperatorSlot(
+                 operator_id=-1,
+                 operator_name=f"grasper_{i+1}",
+                 track_id=self.next_track_id
+             )
+             slot.class_id = self.GRASPER_CLASS_ID
+             slot.class_name = 'grasper'
+             self.next_track_id += 1
+             self.grasper_slots.append(slot)
+
+         for class_id in self.SINGLE_INSTANCE_CLASSES:
+             slot = OperatorSlot(
+                 operator_id=3,
+                 operator_name=f"CLASS_{CLASS_NAMES[class_id]}",
+                 track_id=self.next_track_id
+             )
+             slot.class_id = class_id
+             slot.class_name = CLASS_NAMES[class_id]
+             self.next_track_id += 1
+             self.class_slots[class_id] = slot
+
+     def _get_direction_prediction(self, frame: np.ndarray, bbox: np.ndarray):
+         if self.direction_model is None:
+             return 3, np.array([0.25, 0.25, 0.25, 0.25])
+
+         x1, y1, x2, y2 = bbox.astype(int)
+         h, w = frame.shape[:2]
+
+         pad_x = int((x2 - x1) * 0.3)
+         pad_y = int((y2 - y1) * 0.5)
+
+         x1 = max(0, x1 - pad_x)
+         y1 = max(0, y1 - pad_y)
+         x2 = min(w, x2 + pad_x)
+         y2 = min(h, y2 + pad_y)
+
+         crop = frame[y1:y2, x1:x2]
+         if crop.size == 0:
+             return 3, np.array([0.25, 0.25, 0.25, 0.25])
+
+         crop = cv2.resize(crop, (224, 224))
+         crop = crop.astype(np.float32) / 255.0
+         crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
+         crop = torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
+
+         with torch.no_grad():
+             logits, embedding = self.direction_model(crop, return_embedding=True)
+             probs = F.softmax(logits, dim=1).cpu().numpy()[0]
+
+         return np.argmax(probs), probs
+
+     def _compute_iou(self, bbox1: np.ndarray, bbox2: np.ndarray) -> float:
+         if bbox1 is None or bbox2 is None:
+             return 0.0
+
+         x1 = max(bbox1[0], bbox2[0])
+         y1 = max(bbox1[1], bbox2[1])
+         x2 = min(bbox1[2], bbox2[2])
+         y2 = min(bbox1[3], bbox2[3])
+
+         inter = max(0, x2 - x1) * max(0, y2 - y1)
+         area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+         area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+         union = area1 + area2 - inter
+
+         return inter / (union + 1e-6)
+
+     def _find_best_slot(self, detection: Detection, predicted_op: int, direction_probs: np.ndarray) -> Optional[OperatorSlot]:
+         class_id = detection.class_id
+
+         if class_id in self.SINGLE_INSTANCE_CLASSES:
+             slot = self.class_slots.get(class_id)
+             if slot:
+                 recency = slot.frames_since_seen(self.frame_count)
+                 if not slot.active and recency >= 75:
+                     slot.track_id = self.next_track_id
+                     self.next_track_id += 1
+                 return slot
+
+         if class_id == self.GRASPER_CLASS_ID:
+             direction_confident = predicted_op < 3 and direction_probs[predicted_op] > self.direction_confidence_threshold
+
+             best_slot = None
+             best_score = -1
+             for slot in self.grasper_slots:
+                 if slot.bbox is None:
+                     continue
+
+                 recency = slot.frames_since_seen(self.frame_count)
+                 if recency >= 75:
+                     continue
+
+                 iou = self._compute_iou(detection.bbox, slot.bbox)
+
+                 det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
+                 slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
+                 dist = np.linalg.norm(det_center - slot_center)
+
+                 if iou > self.iou_threshold:
+                     score = iou + (0.2 if slot.operator_id == predicted_op else 0)
+                 elif dist < 150 and recency < 30:
+                     score = 0.1 + (0.2 if slot.operator_id == predicted_op else 0)
+                 else:
+                     continue
+
+                 if score > best_score:
+                     best_score = score
+                     best_slot = slot
+
+             if best_slot:
+                 return best_slot
+
+             if direction_confident:
+                 for slot in self.grasper_slots:
+                     if slot.active or slot.bbox is None:
+                         continue
+                     if slot.operator_id == predicted_op and slot.frames_since_seen(self.frame_count) < 75:
+                         return slot
+
+             if not direction_confident:
+                 for slot in self.grasper_slots:
+                     if slot.active or slot.bbox is None:
+                         continue
+                     if slot.frames_since_seen(self.frame_count) < 30:
+                         det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
+                         slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
+                         dist = np.linalg.norm(det_center - slot_center)
+                         if dist < 100:
+                             return slot
+
+             for slot in self.grasper_slots:
+                 if not slot.active:
+                     slot.track_id = self.next_track_id
+                     self.next_track_id += 1
+                     return slot
+
+             worst_slot = None
+             worst_iou = 1.0
+             for slot in self.grasper_slots:
+                 iou = self._compute_iou(detection.bbox, slot.bbox)
+                 if iou < worst_iou:
+                     worst_iou = iou
+                     worst_slot = slot
+
+             if worst_slot:
+                 worst_slot.track_id = self.next_track_id
+                 self.next_track_id += 1
+                 return worst_slot
+
+         return None
+
+     def update(self, frame: np.ndarray, detections: List[Detection]) -> List[OperatorSlot]:
+         self.frame_count += 1
+
+         all_slots = self.grasper_slots + list(self.class_slots.values())
+         for slot in all_slots:
+             if slot.active and slot.frames_since_seen(self.frame_count) > 150:
+                 slot.mark_inactive()
+
+         if len(detections) == 0:
+             return self._get_active_slots()
+
+         detection_info = []
+         for det in detections:
+             pred_op, probs = self._get_direction_prediction(frame, det.bbox)
+             detection_info.append((det, pred_op, probs))
+
+         detection_info.sort(key=lambda x: -x[0].confidence)
+
+         assigned_slots = set()
+
+         for det, pred_op, probs in detection_info:
+             slot = self._find_best_slot(det, pred_op, probs)
+
+             if slot and slot.track_id not in assigned_slots:
+                 slot.update(det, probs, self.frame_count)
+                 if det.class_id == self.GRASPER_CLASS_ID:
+                     slot.operator_id = pred_op
+                 assigned_slots.add(slot.track_id)
+
+         return self._get_active_slots()
+
+     def _get_active_slots(self) -> List[OperatorSlot]:
+         active = []
+         for slot in self.grasper_slots:
+             if slot.active and slot.last_seen_frame == self.frame_count:
+                 active.append(slot)
+         for slot in self.class_slots.values():
+             if slot.active and slot.last_seen_frame == self.frame_count:
+                 active.append(slot)
+         return active
+
+     def reset(self):
+         self.grasper_slots = []
+         self.class_slots = {}
+         self.next_track_id = 1
+         self.frame_count = 0
+         self._initialize_slots()
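
For reference, a minimal shape check of the direction estimator defined above, run on random input (illustrative only; the real model expects ImageNet-normalized 224×224 tool crops, as prepared in `_get_direction_prediction`):

```python
# Hedged usage sketch for DirectionEstimator / CoordinateAttention from tracker.py.
import torch
from tracker import DirectionEstimator, OPERATORS

model = DirectionEstimator(num_classes=4, pretrained=False).eval()
crop = torch.randn(1, 3, 224, 224)                 # stand-in for a normalized tool crop

with torch.no_grad():
    logits, embedding = model(crop, return_embedding=True)

print(logits.shape)                                # torch.Size([1, 4]) -> MSLH/MSRH/ASRH/NULL
print(embedding.shape, embedding.norm(dim=1))      # torch.Size([1, 128]), L2-normalized (~1.0)
print(OPERATORS[int(logits.argmax(dim=1))])        # predicted operator label
```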