| | |
| | |
| | import copy |
| | import numpy as np |
| | from typing import Dict |
| | import torch |
| | from scipy.optimize import linear_sum_assignment |
| |
|
| | from detectron2.config import configurable |
| | from detectron2.structures import Boxes, Instances |
| |
|
| | from ..config.config import CfgNode as CfgNode_ |
| | from .base_tracker import BaseTracker |
| |
|
| |
|
| | class BaseHungarianTracker(BaseTracker): |
| | """ |
| | A base class for all Hungarian trackers |
| | """ |
| |
|
| | @configurable |
| | def __init__( |
| | self, |
| | video_height: int, |
| | video_width: int, |
| | max_num_instances: int = 200, |
| | max_lost_frame_count: int = 0, |
| | min_box_rel_dim: float = 0.02, |
| | min_instance_period: int = 1, |
| | **kwargs |
| | ): |
| | """ |
| | Args: |
| | video_height: height the video frame |
| | video_width: width of the video frame |
| | max_num_instances: maximum number of id allowed to be tracked |
| | max_lost_frame_count: maximum number of frame an id can lost tracking |
| | exceed this number, an id is considered as lost |
| | forever |
| | min_box_rel_dim: a percentage, smaller than this dimension, a bbox is |
| | removed from tracking |
| | min_instance_period: an instance will be shown after this number of period |
| | since its first showing up in the video |
| | """ |
| | super().__init__(**kwargs) |
| | self._video_height = video_height |
| | self._video_width = video_width |
| | self._max_num_instances = max_num_instances |
| | self._max_lost_frame_count = max_lost_frame_count |
| | self._min_box_rel_dim = min_box_rel_dim |
| | self._min_instance_period = min_instance_period |
| |
|
| | @classmethod |
| | def from_config(cls, cfg: CfgNode_) -> Dict: |
| | raise NotImplementedError("Calling HungarianTracker::from_config") |
| |
|
| | def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray: |
| | raise NotImplementedError("Calling HungarianTracker::build_matrix") |
| |
|
| | def update(self, instances: Instances) -> Instances: |
| | if instances.has("pred_keypoints"): |
| | raise NotImplementedError("Need to add support for keypoints") |
| | instances = self._initialize_extra_fields(instances) |
| | if self._prev_instances is not None: |
| | self._untracked_prev_idx = set(range(len(self._prev_instances))) |
| | cost_matrix = self.build_cost_matrix(instances, self._prev_instances) |
| | matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix) |
| | instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx) |
| | instances = self._process_unmatched_idx(instances, matched_idx) |
| | instances = self._process_unmatched_prev_idx(instances, matched_prev_idx) |
| | self._prev_instances = copy.deepcopy(instances) |
| | return instances |
| |
|
| | def _initialize_extra_fields(self, instances: Instances) -> Instances: |
| | """ |
| | If input instances don't have ID, ID_period, lost_frame_count fields, |
| | this method is used to initialize these fields. |
| | |
| | Args: |
| | instances: D2 Instances, for predictions of the current frame |
| | Return: |
| | D2 Instances with extra fields added |
| | """ |
| | if not instances.has("ID"): |
| | instances.set("ID", [None] * len(instances)) |
| | if not instances.has("ID_period"): |
| | instances.set("ID_period", [None] * len(instances)) |
| | if not instances.has("lost_frame_count"): |
| | instances.set("lost_frame_count", [None] * len(instances)) |
| | if self._prev_instances is None: |
| | instances.ID = list(range(len(instances))) |
| | self._id_count += len(instances) |
| | instances.ID_period = [1] * len(instances) |
| | instances.lost_frame_count = [0] * len(instances) |
| | return instances |
| |
|
| | def _process_matched_idx( |
| | self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray |
| | ) -> Instances: |
| | assert matched_idx.size == matched_prev_idx.size |
| | for i in range(matched_idx.size): |
| | instances.ID[matched_idx[i]] = self._prev_instances.ID[matched_prev_idx[i]] |
| | instances.ID_period[matched_idx[i]] = ( |
| | self._prev_instances.ID_period[matched_prev_idx[i]] + 1 |
| | ) |
| | instances.lost_frame_count[matched_idx[i]] = 0 |
| | return instances |
| |
|
| | def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances: |
| | untracked_idx = set(range(len(instances))).difference(set(matched_idx)) |
| | for idx in untracked_idx: |
| | instances.ID[idx] = self._id_count |
| | self._id_count += 1 |
| | instances.ID_period[idx] = 1 |
| | instances.lost_frame_count[idx] = 0 |
| | return instances |
| |
|
| | def _process_unmatched_prev_idx( |
| | self, instances: Instances, matched_prev_idx: np.ndarray |
| | ) -> Instances: |
| | untracked_instances = Instances( |
| | image_size=instances.image_size, |
| | pred_boxes=[], |
| | pred_masks=[], |
| | pred_classes=[], |
| | scores=[], |
| | ID=[], |
| | ID_period=[], |
| | lost_frame_count=[], |
| | ) |
| | prev_bboxes = list(self._prev_instances.pred_boxes) |
| | prev_classes = list(self._prev_instances.pred_classes) |
| | prev_scores = list(self._prev_instances.scores) |
| | prev_ID_period = self._prev_instances.ID_period |
| | if instances.has("pred_masks"): |
| | prev_masks = list(self._prev_instances.pred_masks) |
| | untracked_prev_idx = set(range(len(self._prev_instances))).difference(set(matched_prev_idx)) |
| | for idx in untracked_prev_idx: |
| | x_left, y_top, x_right, y_bot = prev_bboxes[idx] |
| | if ( |
| | (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim) |
| | or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim) |
| | or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count |
| | or prev_ID_period[idx] <= self._min_instance_period |
| | ): |
| | continue |
| | untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy())) |
| | untracked_instances.pred_classes.append(int(prev_classes[idx])) |
| | untracked_instances.scores.append(float(prev_scores[idx])) |
| | untracked_instances.ID.append(self._prev_instances.ID[idx]) |
| | untracked_instances.ID_period.append(self._prev_instances.ID_period[idx]) |
| | untracked_instances.lost_frame_count.append( |
| | self._prev_instances.lost_frame_count[idx] + 1 |
| | ) |
| | if instances.has("pred_masks"): |
| | untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8)) |
| |
|
| | untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes)) |
| | untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes) |
| | untracked_instances.scores = torch.FloatTensor(untracked_instances.scores) |
| | if instances.has("pred_masks"): |
| | untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks) |
| | else: |
| | untracked_instances.remove("pred_masks") |
| |
|
| | return Instances.cat( |
| | [ |
| | instances, |
| | untracked_instances, |
| | ] |
| | ) |
| |
|