"""
SurgiTrack - Tracker Module (Simplified for HF Space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
from scipy.optimize import linear_sum_assignment
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import cv2

# Detector class names; the list index is the class_id used throughout.
CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
# Operator labels predicted by DirectionEstimator; index 3 ('NULL') is used as the
# "no / unknown operator" value by the tracker below.
OPERATORS = ['MSLH', 'MSRH', 'ASRH', 'NULL']


class CoordinateAttention(nn.Module):
    """Re-weight a 4-D feature map with separate height-wise and width-wise gates.

    The input is average-pooled along each spatial axis, both pooled strips are
    passed through a shared 1x1 bottleneck, then split again and turned into
    sigmoid gates that multiply the input.

    Args:
        in_channels: channel count of the input feature map.
        reduction: divisor used to size the bottleneck (never below 8 channels).
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # keep H, collapse W
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # collapse H, keep W
        mid_channels = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.act = nn.ReLU(inplace=True)
        self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` scaled by the learned height and width attention gates."""
        B, C, H, W = x.shape
        x_h = self.pool_h(x)                        # (B, C, H, 1)
        # (B, C, 1, W) -> (B, C, W, 1) so it can be concatenated with x_h on dim 2.
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)            # (B, C, H + W, 1)
        y = self.act(self.bn1(self.conv1(y)))
        x_h, x_w = torch.split(y, [H, W], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)               # back to (B, mid, 1, W)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        return x * a_h * a_w


class DirectionEstimator(nn.Module):
    """EfficientNet-B0 classifier that predicts which operator holds an instrument.

    The backbone's own classifier is replaced; features go through coordinate
    attention, global average pooling, an embedding head (L2-normalised output)
    and finally a small classification head over ``num_classes`` logits.

    Args:
        num_classes: number of output logits (4 matches ``OPERATORS``).
        embedding_dim: size of the L2-normalised embedding.
        pretrained: load torchvision's ImageNet weights for the backbone
            (may trigger a download).
    """

    def __init__(self, num_classes=4, embedding_dim=128, pretrained=True):
        super().__init__()
        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        backbone_out = self.backbone.classifier[1].in_features
        # The original classifier is not used; forward() calls features/avgpool directly.
        self.backbone.classifier = nn.Identity()
        self.coord_attention = CoordinateAttention(backbone_out)
        self.embedding_head = nn.Sequential(
            nn.Linear(backbone_out, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
        )
        self.direction_head = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )
        self.embedding_dim = embedding_dim

    def forward(self, x, return_embedding=False):
        """Compute operator logits for a batch of image crops.

        Returns:
            ``logits`` of shape (B, num_classes), or ``(logits, embedding)`` when
            ``return_embedding`` is true (embedding is L2-normalised along dim 1).
        """
        features = self.backbone.features(x)
        features = self.coord_attention(features)
        features = self.backbone.avgpool(features)
        features = features.flatten(1)
        embedding = self.embedding_head(features)
        embedding = F.normalize(embedding, p=2, dim=1)
        direction = self.direction_head(embedding)
        if return_embedding:
            return direction, embedding
        return direction


@dataclass
class Detection:
    """One detector output for a single frame.

    ``bbox`` is read as pixel corners ``(x1, y1, x2, y2)`` by the tracker
    (see ``_compute_iou`` / ``_get_direction_prediction``).
    """
    bbox: np.ndarray
    class_id: int  # index into CLASS_NAMES
    class_name: str
    confidence: float
    frame_id: int


@dataclass
class OperatorSlot:
    """A persistent tracking slot holding the latest state of one instrument track."""
    operator_id: int  # -1 = not yet assigned (graspers); 3 = 'NULL' for single-instance classes
    operator_name: str
    track_id: int
    active: bool = False
    class_id: int = -1
    class_name: str = ""
    bbox: Optional[np.ndarray] = None
    confidence: float = 0.0
    # NOTE(review): the tracker currently stores the direction probability vector here,
    # not the DirectionEstimator embedding.
    embedding: Optional[np.ndarray] = None
    last_seen_frame: int = -1  # -1 = never seen
    total_detections: int = 0
    bbox_history: List[np.ndarray] = field(default_factory=list)
    class_history: List[int] = field(default_factory=list)

    def update(self, detection: Detection, embedding: np.ndarray, frame_id: int):
        """Copy ``detection`` into this slot, mark it active and append to the
        bounded (last 100 entries) bbox/class history."""
        self.active = True
        self.bbox = detection.bbox
        self.class_id = detection.class_id
        self.class_name = detection.class_name
        self.confidence = detection.confidence
        self.embedding = embedding
        self.last_seen_frame = frame_id
        self.total_detections += 1
        self.bbox_history.append(detection.bbox.copy())
        self.class_history.append(detection.class_id)
        if len(self.bbox_history) > 100:
            self.bbox_history.pop(0)
            self.class_history.pop(0)

    def mark_inactive(self):
        """Flag the slot as inactive; its last state is kept."""
        self.active = False

    def frames_since_seen(self, current_frame: int) -> float:
        """Frames elapsed since the last update, or ``inf`` if never updated."""
        if self.last_seen_frame < 0:
            return float('inf')
        return current_frame - self.last_seen_frame


class OperatorBasedTracker:
    """Greedy slot-based tracker for surgical instruments.

    Up to ``MAX_GRASPERS`` grasper tracks are kept in ``grasper_slots``; every
    other known class gets exactly one slot in ``class_slots``.  Each frame,
    detections are processed in descending confidence order and assigned to a
    slot by IoU / centre distance, with the optional ``DirectionEstimator``
    prediction used as a tie-breaker for graspers.

    Args:
        direction_model: optional operator classifier; if ``None`` every
            detection is treated as operator 3 ('NULL') with uniform probabilities.
        max_inactive_frames: stored but currently unused (see ``update``).
        iou_threshold: minimum IoU for a grasper detection to match a slot by overlap.
        direction_confidence_threshold: minimum probability for a direction
            prediction to be considered confident.
        device: torch device for ``direction_model`` inference.
    """

    MAX_GRASPERS = 3
    GRASPER_CLASS_ID = 0
    SINGLE_INSTANCE_CLASSES = {1, 2, 3, 4, 5, 6}

    def __init__(
        self,
        direction_model: DirectionEstimator = None,
        max_inactive_frames: int = 300,
        iou_threshold: float = 0.3,
        direction_confidence_threshold: float = 0.5,
        device: str = "cuda",
    ):
        self.direction_model = direction_model
        self.max_inactive_frames = max_inactive_frames
        self.iou_threshold = iou_threshold
        self.direction_confidence_threshold = direction_confidence_threshold
        self.device = device
        self.grasper_slots: List[OperatorSlot] = []
        self.class_slots: Dict[int, OperatorSlot] = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()
        if self.direction_model is not None:
            self.direction_model.to(device)
            self.direction_model.eval()

    def _initialize_slots(self):
        """Create the grasper slots and one slot per single-instance class,
        each with a fresh track id."""
        for i in range(self.MAX_GRASPERS):
            slot = OperatorSlot(
                operator_id=-1,
                operator_name=f"grasper_{i+1}",
                track_id=self.next_track_id,
            )
            slot.class_id = self.GRASPER_CLASS_ID
            slot.class_name = 'grasper'
            self.next_track_id += 1
            self.grasper_slots.append(slot)
        for class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = OperatorSlot(
                operator_id=3,
                operator_name=f"CLASS_{CLASS_NAMES[class_id]}",
                track_id=self.next_track_id,
            )
            slot.class_id = class_id
            slot.class_name = CLASS_NAMES[class_id]
            self.next_track_id += 1
            self.class_slots[class_id] = slot

    def _get_direction_prediction(self, frame: np.ndarray, bbox: np.ndarray):
        """Predict the operator for the instrument inside ``bbox``.

        The box is padded (30% of width horizontally, 50% of height vertically),
        clipped to the frame, resized to 224x224 and normalised with the
        ImageNet mean/std before being fed to ``direction_model``.

        Returns:
            ``(predicted_index, probabilities)``.  Falls back to
            ``(3, uniform)`` when there is no model or the crop is empty.
        """
        if self.direction_model is None:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])
        x1, y1, x2, y2 = bbox.astype(int)
        h, w = frame.shape[:2]
        pad_x = int((x2 - x1) * 0.3)
        pad_y = int((y2 - y1) * 0.5)
        x1 = max(0, x1 - pad_x)
        y1 = max(0, y1 - pad_y)
        x2 = min(w, x2 + pad_x)
        y2 = min(h, y2 + pad_y)
        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])
        crop = cv2.resize(crop, (224, 224))
        crop = crop.astype(np.float32) / 255.0
        # NOTE(review): ImageNet statistics are in RGB order; this code does not
        # convert channels, so callers must pass frames in the channel order the
        # model was trained on (cv2-decoded frames are BGR) — confirm.
        crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        crop = torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
        with torch.no_grad():
            logits, embedding = self.direction_model(crop, return_embedding=True)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        return np.argmax(probs), probs

    def _compute_iou(self, bbox1: np.ndarray, bbox2: np.ndarray) -> float:
        """Intersection-over-union of two ``(x1, y1, x2, y2)`` boxes; 0.0 if either is None."""
        if bbox1 is None or bbox2 is None:
            return 0.0
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        inter = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - inter
        return inter / (union + 1e-6)

    def _find_best_slot(self, detection: Detection, predicted_op: int,
                        direction_probs: np.ndarray) -> Optional[OperatorSlot]:
        """Choose the slot a detection should be written to, or ``None``.

        Single-instance classes always map to their class slot (with a new
        track id if that slot is inactive).  Graspers are matched, in order, by
        IoU / centre distance against recently seen slots, then by the
        direction-based recovery rules, then to the first inactive slot, and
        finally by replacing the lowest-IoU slot that has not already received
        a detection in the current frame.  The last two paths start a new
        track id on the returned slot.
        """
        class_id = detection.class_id

        if class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = self.class_slots.get(class_id)
            if slot is not None:
                recency = slot.frames_since_seen(self.frame_count)
                if not slot.active and recency >= 75:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                return slot

        if class_id == self.GRASPER_CLASS_ID:
            # 3 is 'NULL' in OPERATORS, so it never counts as a confident prediction.
            direction_confident = (
                predicted_op < 3
                and direction_probs[predicted_op] > self.direction_confidence_threshold
            )

            # 1) Spatial match against slots seen within the last 75 frames.
            best_slot = None
            best_score = -1
            for slot in self.grasper_slots:
                if slot.bbox is None:
                    continue
                recency = slot.frames_since_seen(self.frame_count)
                if recency >= 75:
                    continue
                iou = self._compute_iou(detection.bbox, slot.bbox)
                det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                dist = np.linalg.norm(det_center - slot_center)
                if iou > self.iou_threshold:
                    score = iou + (0.2 if slot.operator_id == predicted_op else 0)
                elif dist < 150 and recency < 30:
                    score = 0.1 + (0.2 if slot.operator_id == predicted_op else 0)
                else:
                    continue
                if score > best_score:
                    best_score = score
                    best_slot = slot
            if best_slot is not None:
                return best_slot

            # 2) Recovery of inactive slots.
            # NOTE(review): within this class a slot only becomes inactive after
            # >150 frames unseen (see update()), so "inactive AND seen < 75/30
            # frames ago" appears unreachable — confirm the intended meaning of
            # `active` before relying on these branches.
            if direction_confident:
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if (slot.operator_id == predicted_op
                            and slot.frames_since_seen(self.frame_count) < 75):
                        return slot
            if not direction_confident:
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if slot.frames_since_seen(self.frame_count) < 30:
                        det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                        slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                        dist = np.linalg.norm(det_center - slot_center)
                        if dist < 100:
                            return slot

            # 3) Start a new track in the first inactive slot.
            for slot in self.grasper_slots:
                if not slot.active:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                    return slot

            # 4) Last resort: recycle the slot overlapping least with this detection.
            # Slots already filled during this update() call are skipped; otherwise
            # we would re-id and overwrite a detection assigned moments ago.
            worst_slot = None
            worst_iou = 1.0
            for slot in self.grasper_slots:
                if slot.last_seen_frame == self.frame_count:
                    continue
                iou = self._compute_iou(detection.bbox, slot.bbox)
                if iou < worst_iou:
                    worst_iou = iou
                    worst_slot = slot
            if worst_slot is not None:
                worst_slot.track_id = self.next_track_id
                self.next_track_id += 1
                return worst_slot

        return None

    def update(self, frame: np.ndarray, detections: List[Detection]) -> List[OperatorSlot]:
        """Advance one frame, assign ``detections`` to slots and return the
        slots that were updated in this frame.

        Args:
            frame: image the detections come from (indexed as ``frame[y, x]``);
                only used for direction-model crops.
            detections: detections for this frame, in any order.
        """
        self.frame_count += 1

        # NOTE(review): the threshold here is a literal 150; `self.max_inactive_frames`
        # is never consulted — confirm which one is intended.
        all_slots = self.grasper_slots + list(self.class_slots.values())
        for slot in all_slots:
            if slot.active and slot.frames_since_seen(self.frame_count) > 150:
                slot.mark_inactive()

        if len(detections) == 0:
            return self._get_active_slots()

        detection_info = []
        for det in detections:
            pred_op, probs = self._get_direction_prediction(frame, det.bbox)
            detection_info.append((det, pred_op, probs))
        # Highest-confidence detections claim slots first.
        detection_info.sort(key=lambda x: -x[0].confidence)

        for det, pred_op, probs in detection_info:
            slot = self._find_best_slot(det, pred_op, probs)
            # A slot whose last_seen_frame equals the current frame already took a
            # (higher-confidence) detection in this call; keep that assignment.
            # This is stable even if a slot's track_id is reassigned.
            if slot is not None and slot.last_seen_frame != self.frame_count:
                slot.update(det, probs, self.frame_count)
                if det.class_id == self.GRASPER_CLASS_ID:
                    slot.operator_id = pred_op

        return self._get_active_slots()

    def _get_active_slots(self) -> List[OperatorSlot]:
        """Active slots updated in the current frame (graspers first, then class slots)."""
        active = []
        for slot in self.grasper_slots:
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        for slot in self.class_slots.values():
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        return active

    def reset(self):
        """Discard all tracks and restart track ids and the frame counter."""
        self.grasper_slots = []
        self.class_slots = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()