File size: 2,030 Bytes
c6abe34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import torch
import cv2
import os
from typing import Tuple, List, Optional
from ultralytics import SAM

class SAM2Tracker:
    """
    Box-prompted SAM2 segmentation backed by the ultralytics ``SAM`` wrapper.

    Each call to :meth:`segment_frame` segments one frame independently from
    the supplied bounding boxes; this class carries no mask state between
    frames. The goal is to give downstream tracking the actual shape of each
    player (not just its box), which helps during occlusions and scrimmages.
    """

    def __init__(self, model_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """
        Load the SAM2 model through ultralytics.

        Args:
            model_path: Weights path/name passed to ``ultralytics.SAM``.
            device: Device string forwarded to every inference call. The
                default is evaluated once, when the class is defined:
                ``"cuda"`` if available, otherwise ``"cpu"``.

        Loading is best-effort: if it fails, a warning is printed and
        :meth:`segment_frame` returns an empty list for every frame.
        """
        self.device = device
        self.model_path = model_path
        # NOTE(review): predictor / _use_realtime are never read in this class;
        # presumably kept for external callers — verify before removing.
        self.predictor = None
        self._prompted = False
        self._use_realtime = False
        self._ultralytics_model = None

        # Build ultralytics model as default robust fallback
        try:
            print(f"Loading SAM2 (ultralytics): {model_path}")
            self._ultralytics_model = SAM(model_path)
        except Exception as e:
            # Deliberately non-fatal: callers get empty mask lists instead of a crash.
            print(f"⚠️ Warning: Could not load SAM2: {e}")

    def segment_frame(self, frame: np.ndarray, boxes: Optional[np.ndarray]) -> List[np.ndarray]:
        """
        Segment objects in a single frame given their bounding boxes.

        Args:
            frame: BGR image; only ``frame.shape[:2]`` (H, W) is read here.
            boxes: xyxy bounding boxes. Accepts an (N, 4) array, a single
                (4,) box, a nested sequence, or a torch tensor (on any
                device). Columns beyond the first four (e.g. confidence,
                class id) are ignored. ``None`` or an empty input yields
                no masks.

        Returns:
            List of (H, W) boolean masks, in the order the model returns
            them. Empty if the model failed to load, no boxes were given,
            or the model produced no masks.

        Raises:
            ValueError: if ``boxes`` is not 2-D after promotion or a row has
                fewer than four values.
        """
        if self._ultralytics_model is None or boxes is None:
            return []

        if isinstance(boxes, torch.Tensor):
            # np.asarray cannot read CUDA tensors directly.
            boxes = boxes.detach().cpu().numpy()
        boxes_xyxy = np.atleast_2d(np.asarray(boxes, dtype=np.float32))
        if boxes_xyxy.size == 0:
            return []
        if boxes_xyxy.ndim != 2 or boxes_xyxy.shape[1] < 4:
            raise ValueError(f"boxes must have shape (N, 4) in xyxy format, got {boxes_xyxy.shape}")
        boxes_xyxy = boxes_xyxy[:, :4]

        h, w = frame.shape[:2]
        # In YOLO/SAM, bboxes should be xyxy
        results = self._ultralytics_model(
            frame,
            bboxes=boxes_xyxy,
            device=self.device,
            verbose=False,
        )
        if not results or results[0].masks is None:
            return []

        # Single device->host transfer for the whole batch instead of one per mask.
        mask_stack = results[0].masks.data.cpu().numpy().astype(np.uint8)
        masks = []
        for mask in mask_stack:
            if mask.shape[:2] != (h, w):
                # Nearest-neighbour keeps the mask binary when rescaling to the frame size.
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            masks.append(mask.astype(bool))
        return masks

    def reset(self) -> None:
        """
        Clear the ``_prompted`` flag.

        NOTE(review): nothing in this class sets ``_prompted`` to True, so
        this only has an effect if external code sets or reads it — confirm.
        """
        self._prompted = False