Spaces:
Runtime error
Runtime error
| """ | |
| SurgiTrack - Tracker Module (Simplified for HF Space) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models | |
| import numpy as np | |
| from scipy.optimize import linear_sum_assignment | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict, Optional | |
| import cv2 | |
# Tool classes recognized by the detector; list index == detector class id.
CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
# Operator labels output by the DirectionEstimator (4 classes); index 3 ('NULL')
# is the unknown/none fallback used throughout the tracker.
OPERATORS = ['MSLH', 'MSRH', 'ASRH', 'NULL']
class CoordinateAttention(nn.Module):
    """Coordinate attention: factorizes spatial attention into two pooled
    1-D directions (height and width), transforms them jointly, and
    re-weights the input feature map with the resulting per-axis gates."""

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        # Direction-wise pooling: (None, 1) collapses width, (1, None) collapses height.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mid_channels = max(8, in_channels // reduction)  # floor so tiny inputs still work
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.act = nn.ReLU(inplace=True)
        self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        pooled_h = self.pool_h(x)                       # (B, C, H, 1)
        pooled_w = self.pool_w(x).permute(0, 1, 3, 2)   # (B, C, W, 1)
        # Concatenate along the spatial axis so both directions share one transform.
        joint = torch.cat([pooled_h, pooled_w], dim=2)
        joint = self.act(self.bn1(self.conv1(joint)))
        part_h, part_w = torch.split(joint, [height, width], dim=2)
        gate_h = self.conv_h(part_h).sigmoid()
        gate_w = self.conv_w(part_w.permute(0, 1, 3, 2)).sigmoid()
        # Sigmoid gates are in (0, 1): output magnitude never exceeds the input's.
        return x * gate_h * gate_w
class DirectionEstimator(nn.Module):
    """EfficientNet-B0 backbone with coordinate attention that produces an
    L2-normalized embedding and an operator/direction classification.

    forward(x) returns the logits, or (logits, embedding) when
    return_embedding=True.
    """

    def __init__(self, num_classes=4, embedding_dim=128, pretrained=True):
        super().__init__()
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_b0(weights=weights)
        backbone_out = self.backbone.classifier[1].in_features
        # Drop the stock classifier; pooled features feed the heads below.
        self.backbone.classifier = nn.Identity()
        self.coord_attention = CoordinateAttention(backbone_out)
        self.embedding_head = nn.Sequential(
            nn.Linear(backbone_out, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim)
        )
        self.direction_head = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )
        self.embedding_dim = embedding_dim

    def forward(self, x, return_embedding=False):
        # Backbone feature map -> coordinate attention -> global pooling.
        attended = self.coord_attention(self.backbone.features(x))
        pooled = self.backbone.avgpool(attended).flatten(1)
        # Unit-norm embedding; classification operates on the normalized vector.
        embedding = F.normalize(self.embedding_head(pooled), p=2, dim=1)
        logits = self.direction_head(embedding)
        return (logits, embedding) if return_embedding else logits
@dataclass
class Detection:
    """One detector output for a single frame.

    The ``@dataclass`` decorator is required: the fields are annotations
    only, so without it the class has no generated ``__init__`` and every
    ``Detection(...)`` construction fails at runtime.

    Attributes:
        bbox: (4,) array ``[x1, y1, x2, y2]`` in pixel coordinates.
        class_id: index into CLASS_NAMES.
        class_name: human-readable class label.
        confidence: detector score (higher = more confident).
        frame_id: index of the frame this detection belongs to.
    """
    bbox: np.ndarray
    class_id: int
    class_name: str
    confidence: float
    frame_id: int
@dataclass
class OperatorSlot:
    """Persistent track slot bound to one operator (grasper) or one
    single-instance tool class.

    The ``@dataclass`` decorator is required: without it the keyword
    construction in the tracker fails (no generated ``__init__``) and the
    ``field(default_factory=list)`` entries remain bare ``Field`` objects
    shared as class attributes, so ``bbox_history.append`` raises.
    """
    operator_id: int                       # index into OPERATORS; -1 = not yet bound
    operator_name: str
    track_id: int
    active: bool = False                   # seen recently (not timed out)
    class_id: int = -1
    class_name: str = ""
    bbox: Optional[np.ndarray] = None      # last matched [x1, y1, x2, y2]
    confidence: float = 0.0
    embedding: Optional[np.ndarray] = None
    last_seen_frame: int = -1              # -1 = never updated
    total_detections: int = 0
    bbox_history: List[np.ndarray] = field(default_factory=list)
    class_history: List[int] = field(default_factory=list)

    def update(self, detection: "Detection", embedding: np.ndarray, frame_id: int):
        """Absorb a matched detection and record it in the rolling history
        (capped at the last 100 entries)."""
        self.active = True
        self.bbox = detection.bbox
        self.class_id = detection.class_id
        self.class_name = detection.class_name
        self.confidence = detection.confidence
        self.embedding = embedding
        self.last_seen_frame = frame_id
        self.total_detections += 1
        self.bbox_history.append(detection.bbox.copy())
        self.class_history.append(detection.class_id)
        if len(self.bbox_history) > 100:
            self.bbox_history.pop(0)
            self.class_history.pop(0)

    def mark_inactive(self):
        """Flag the slot as not currently visible; all other state is kept."""
        self.active = False

    def frames_since_seen(self, current_frame: int) -> float:
        """Frames elapsed since the last update, or inf if never updated.

        Annotated ``float`` (not ``int``) because of the ``inf`` sentinel.
        """
        if self.last_seen_frame < 0:
            return float('inf')
        return current_frame - self.last_seen_frame
class OperatorBasedTracker:
    """Slot-based surgical tool tracker.

    Keeps MAX_GRASPERS reusable slots for the 'grasper' class and one fixed
    slot for every other tool class (assumed to appear at most once per
    frame). Each detection is matched to a slot via IoU / centre distance
    and, for graspers, the operator predicted by an optional
    DirectionEstimator.
    """

    MAX_GRASPERS = 3        # number of concurrent grasper slots
    GRASPER_CLASS_ID = 0    # index of 'grasper' in CLASS_NAMES
    # Class ids that get exactly one dedicated slot each (all non-graspers).
    SINGLE_INSTANCE_CLASSES = {1, 2, 3, 4, 5, 6}

    def __init__(
        self,
        direction_model: DirectionEstimator = None,
        max_inactive_frames: int = 300,
        iou_threshold: float = 0.3,
        direction_confidence_threshold: float = 0.5,
        device: str = "cuda"
    ):
        """Create the tracker and its fixed slot layout.

        Args:
            direction_model: optional DirectionEstimator; when None the
                operator prediction falls back to index 3 ('NULL') with a
                uniform probability vector.
            max_inactive_frames: stored but never read by the matching logic
                below — NOTE(review): update() deactivates with a hard-coded
                150-frame limit; confirm which value is intended.
            iou_threshold: minimum IoU for an overlap-based slot match.
            direction_confidence_threshold: minimum probability required to
                trust the predicted operator when re-binding a grasper slot.
            device: torch device the direction model is moved to.
        """
        self.direction_model = direction_model
        self.max_inactive_frames = max_inactive_frames
        self.iou_threshold = iou_threshold
        self.direction_confidence_threshold = direction_confidence_threshold
        self.device = device
        self.grasper_slots: List[OperatorSlot] = []
        self.class_slots: Dict[int, OperatorSlot] = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()
        if self.direction_model is not None:
            self.direction_model.to(device)
            self.direction_model.eval()

    def _initialize_slots(self):
        """Build the fixed slots: MAX_GRASPERS grasper slots (operator
        unknown until a direction prediction binds one) plus one slot per
        single-instance class."""
        for i in range(self.MAX_GRASPERS):
            slot = OperatorSlot(
                operator_id=-1,
                operator_name=f"grasper_{i+1}",
                track_id=self.next_track_id
            )
            slot.class_id = self.GRASPER_CLASS_ID
            slot.class_name = 'grasper'
            self.next_track_id += 1
            self.grasper_slots.append(slot)
        for class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = OperatorSlot(
                operator_id=3,  # 'NULL' operator for non-grasper tools
                operator_name=f"CLASS_{CLASS_NAMES[class_id]}",
                track_id=self.next_track_id
            )
            slot.class_id = class_id
            slot.class_name = CLASS_NAMES[class_id]
            self.next_track_id += 1
            self.class_slots[class_id] = slot

    def _get_direction_prediction(self, frame: np.ndarray, bbox: np.ndarray):
        """Classify the operator for one detection crop.

        Returns ``(predicted_operator_index, probability_vector)``. Falls
        back to ``(3, uniform)`` when no model is loaded or the crop is
        empty.
        """
        if self.direction_model is None:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])
        x1, y1, x2, y2 = bbox.astype(int)
        h, w = frame.shape[:2]
        # Expand the box (30% horizontally, 50% vertically) so the crop
        # keeps context around the tool, clamped to the frame bounds.
        pad_x = int((x2 - x1) * 0.3)
        pad_y = int((y2 - y1) * 0.5)
        x1 = max(0, x1 - pad_x)
        y1 = max(0, y1 - pad_y)
        x2 = min(w, x2 + pad_x)
        y2 = min(h, y2 + pad_y)
        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            return 3, np.array([0.25, 0.25, 0.25, 0.25])
        # ImageNet preprocessing. NOTE(review): channels are used as-is; if
        # frames come from cv2 (BGR) this skips BGR->RGB — confirm the model
        # was trained with the same channel order.
        crop = cv2.resize(crop, (224, 224))
        crop = crop.astype(np.float32) / 255.0
        crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        crop = torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
        with torch.no_grad():
            # The visual embedding is computed but currently unused here.
            logits, embedding = self.direction_model(crop, return_embedding=True)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        return np.argmax(probs), probs

    def _compute_iou(self, bbox1: np.ndarray, bbox2: np.ndarray) -> float:
        """IoU of two [x1, y1, x2, y2] boxes; 0.0 when either is None."""
        if bbox1 is None or bbox2 is None:
            return 0.0
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        inter = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - inter
        # Epsilon guards against zero-area degenerate boxes.
        return inter / (union + 1e-6)

    def _find_best_slot(self, detection: Detection, predicted_op: int, direction_probs: np.ndarray) -> Optional[OperatorSlot]:
        """Pick the slot a detection should update, or None.

        Matching order for graspers:
          1. spatial match (IoU above threshold, or centre distance < 150 px
             for slots seen < 30 frames ago), with a +0.2 bonus when the
             slot's operator matches the prediction;
          2. operator-identity re-bind when the prediction is confident;
          3. nearby (< 100 px) recently-lost slot when it is not;
          4. any inactive slot (issued a fresh track id);
          5. otherwise evict the slot overlapping this detection least
             (also issued a fresh track id).
        Single-instance classes always use their dedicated slot; a new track
        id is issued after >= 75 frames of absence.
        """
        class_id = detection.class_id
        if class_id in self.SINGLE_INSTANCE_CLASSES:
            slot = self.class_slots.get(class_id)
            if slot:
                recency = slot.frames_since_seen(self.frame_count)
                # Long absence => treat the reappearance as a new track.
                if not slot.active and recency >= 75:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                return slot
        if class_id == self.GRASPER_CLASS_ID:
            # Confident only for a real operator (indices 0-2, not 'NULL').
            direction_confident = predicted_op < 3 and direction_probs[predicted_op] > self.direction_confidence_threshold
            best_slot = None
            best_score = -1
            for slot in self.grasper_slots:
                if slot.bbox is None:
                    continue
                recency = slot.frames_since_seen(self.frame_count)
                if recency >= 75:
                    continue  # too stale for a spatial match
                iou = self._compute_iou(detection.bbox, slot.bbox)
                det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                dist = np.linalg.norm(det_center - slot_center)
                if iou > self.iou_threshold:
                    score = iou + (0.2 if slot.operator_id == predicted_op else 0)
                elif dist < 150 and recency < 30:
                    score = 0.1 + (0.2 if slot.operator_id == predicted_op else 0)
                else:
                    continue
                if score > best_score:
                    best_score = score
                    best_slot = slot
            if best_slot:
                return best_slot
            if direction_confident:
                # Re-bind by operator identity after a longer gap.
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if slot.operator_id == predicted_op and slot.frames_since_seen(self.frame_count) < 75:
                        return slot
            if not direction_confident:
                # Fall back to the nearest recently-lost slot.
                for slot in self.grasper_slots:
                    if slot.active or slot.bbox is None:
                        continue
                    if slot.frames_since_seen(self.frame_count) < 30:
                        det_center = (detection.bbox[:2] + detection.bbox[2:]) / 2
                        slot_center = (slot.bbox[:2] + slot.bbox[2:]) / 2
                        dist = np.linalg.norm(det_center - slot_center)
                        if dist < 100:
                            return slot
            # Any free slot becomes a brand-new track.
            for slot in self.grasper_slots:
                if not slot.active:
                    slot.track_id = self.next_track_id
                    self.next_track_id += 1
                    return slot
            # All slots busy: evict the one overlapping this detection least.
            worst_slot = None
            worst_iou = 1.0
            for slot in self.grasper_slots:
                iou = self._compute_iou(detection.bbox, slot.bbox)
                if iou < worst_iou:
                    worst_iou = iou
                    worst_slot = slot
            if worst_slot:
                worst_slot.track_id = self.next_track_id
                self.next_track_id += 1
                return worst_slot
        return None

    def update(self, frame: np.ndarray, detections: List[Detection]) -> List[OperatorSlot]:
        """Advance one frame: expire stale slots, match detections (highest
        confidence first), and return the slots updated this frame."""
        self.frame_count += 1
        all_slots = self.grasper_slots + list(self.class_slots.values())
        for slot in all_slots:
            # Hard-coded 150-frame timeout (see __init__ note on
            # max_inactive_frames).
            if slot.active and slot.frames_since_seen(self.frame_count) > 150:
                slot.mark_inactive()
        if len(detections) == 0:
            return self._get_active_slots()
        detection_info = []
        for det in detections:
            pred_op, probs = self._get_direction_prediction(frame, det.bbox)
            detection_info.append((det, pred_op, probs))
        # Greedy assignment: most confident detections claim slots first.
        detection_info.sort(key=lambda x: -x[0].confidence)
        assigned_slots = set()
        for det, pred_op, probs in detection_info:
            slot = self._find_best_slot(det, pred_op, probs)
            if slot and slot.track_id not in assigned_slots:
                # NOTE(review): the direction probability vector is stored in
                # slot.embedding here, not the model's visual embedding.
                slot.update(det, probs, self.frame_count)
                if det.class_id == self.GRASPER_CLASS_ID:
                    slot.operator_id = pred_op
                assigned_slots.add(slot.track_id)
        return self._get_active_slots()

    def _get_active_slots(self) -> List[OperatorSlot]:
        """Return only the slots that were matched in the current frame."""
        active = []
        for slot in self.grasper_slots:
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        for slot in self.class_slots.values():
            if slot.active and slot.last_seen_frame == self.frame_count:
                active.append(slot)
        return active

    def reset(self):
        """Drop all tracking state and rebuild the fixed slot layout."""
        self.grasper_slots = []
        self.class_slots = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()