| | |
| | |
| | import copy |
| | import numpy as np |
| | from typing import List |
| | import torch |
| |
|
| | from detectron2.config import configurable |
| | from detectron2.structures import Boxes, Instances |
| | from detectron2.structures.boxes import pairwise_iou |
| |
|
| | from ..config.config import CfgNode as CfgNode_ |
| | from .base_tracker import TRACKER_HEADS_REGISTRY, BaseTracker |
| |
|
| |
|
@TRACKER_HEADS_REGISTRY.register()
class BBoxIOUTracker(BaseTracker):
    """
    A bounding box tracker to assign ID based on IoU between current and previous instances
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of id allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose tracking;
                                  once this number is reached, an id is considered
                                  as lost forever
            min_box_rel_dim: a fraction of the frame size; a bbox whose width or
                             height is relatively smaller than this is removed
                             from tracking
            min_instance_period: an untracked instance is only carried over once its
                                 period (number of frames since it first showed up)
                                 exceeds this number
            track_iou_threshold: iou threshold, below this number a bbox pair is
                                 removed from tracking
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        # NOTE(review): stored but not read anywhere in this class.
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file
        Return:
            dictionary storing arguments for __init__ method
        """
        tracker_cfg = cfg.TRACKER_HEADS
        assert "VIDEO_HEIGHT" in tracker_cfg, "TRACKER_HEADS.VIDEO_HEIGHT is required"
        assert "VIDEO_WIDTH" in tracker_cfg, "TRACKER_HEADS.VIDEO_WIDTH is required"
        return {
            "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker",
            "video_height": tracker_cfg.get("VIDEO_HEIGHT"),
            "video_width": tracker_cfg.get("VIDEO_WIDTH"),
            "max_num_instances": tracker_cfg.get("MAX_NUM_INSTANCES", 200),
            "max_lost_frame_count": tracker_cfg.get("MAX_LOST_FRAME_COUNT", 0),
            "min_box_rel_dim": tracker_cfg.get("MIN_BOX_REL_DIM", 0.02),
            "min_instance_period": tracker_cfg.get("MIN_INSTANCE_PERIOD", 1),
            "track_iou_threshold": tracker_cfg.get("TRACK_IOU_THRESHOLD", 0.5),
        }

    def update(self, instances: Instances) -> Instances:
        """
        See BaseTracker description

        Matches the current predictions against the previous frame by bbox IoU,
        propagates IDs for matched boxes, assigns new IDs to unmatched ones and
        appends previous instances that are still worth tracking.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with ID, ID_period and lost_frame_count fields filled in
        """
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            iou_all = pairwise_iou(
                boxes1=instances.pred_boxes,
                boxes2=self._prev_instances.pred_boxes,
            )
            bbox_pairs = self._create_prediction_pairs(instances, iou_all)
            # Greedy assignment has to look at the strongest overlaps first,
            # otherwise a weak match that happens to come early in index order
            # can steal the ID of a much better match. The sort is stable, so
            # ties keep their original (idx, prev_idx) order.
            bbox_pairs = sorted(
                bbox_pairs, key=lambda pair: float(pair["IoU"]), reverse=True
            )
            self._reset_fields()
            for bbox_pair in bbox_pairs:
                idx = bbox_pair["idx"]
                prev_id = bbox_pair["prev_id"]
                if (
                    idx in self._matched_idx
                    or prev_id in self._matched_ID
                    or bbox_pair["IoU"] < self._track_iou_threshold
                ):
                    continue
                instances.ID[idx] = prev_id
                instances.ID_period[idx] = bbox_pair["prev_period"] + 1
                instances.lost_frame_count[idx] = 0
                self._matched_idx.add(idx)
                self._matched_ID.add(prev_id)
                self._untracked_prev_idx.remove(bbox_pair["prev_idx"])
            instances = self._assign_new_id(instances)
            instances = self._merge_untracked_instances(instances)
        # Snapshot the result so later in-place edits by the caller cannot
        # change what the next frame is matched against.
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _create_prediction_pairs(self, instances: Instances, iou_all: np.ndarray) -> List:
        """
        For all instances in previous and current frames, create pairs. For each
        pair, store index of the instance in current frame predictions, index in
        previous predictions, ID in previous predictions, IoU of the bboxes in this
        pair, period in previous predictions.

        Args:
            instances: D2 Instances, for predictions of the current frame
            iou_all: IoU for all bbox pairs, indexed as [current idx, previous idx]
                     (`update` passes the result of `pairwise_iou`)
        Return:
            A list of dicts, one per (current, previous) pair, with keys
            "idx", "prev_idx", "prev_id", "IoU" and "prev_period"
        """
        prev_ids = self._prev_instances.ID
        prev_periods = self._prev_instances.ID_period
        num_prev = len(self._prev_instances)
        return [
            {
                "idx": cur_idx,
                "prev_idx": prev_idx,
                "prev_id": prev_ids[prev_idx],
                "IoU": iou_all[cur_idx, prev_idx],
                "prev_period": prev_periods[prev_idx],
            }
            for cur_idx in range(len(instances))
            for prev_idx in range(num_prev)
        ]

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        num_instances = len(instances)
        for field_name in ("ID", "ID_period", "lost_frame_count"):
            if not instances.has(field_name):
                instances.set(field_name, [None] * num_instances)
        if self._prev_instances is None:
            # First frame seen by the tracker: every detection starts a new track.
            instances.ID = list(range(num_instances))
            self._id_count += num_instances
            instances.ID_period = [1] * num_instances
            instances.lost_frame_count = [0] * num_instances
        return instances

    def _reset_fields(self):
        """
        Before each update call, reset fields first
        """
        self._matched_idx = set()
        self._matched_ID = set()
        # Every previous instance is considered untracked until it gets matched.
        self._untracked_prev_idx = set(range(len(self._prev_instances)))

    def _assign_new_id(self, instances: Instances) -> Instances:
        """
        For each untracked instance, assign a new id

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with new ID assigned
        """
        untracked_idx = set(range(len(instances))).difference(self._matched_idx)
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _merge_untracked_instances(self, instances: Instances) -> Instances:
        """
        For untracked previous instances, under certain condition, still keep them
        in tracking and merge with the current instances.

        A previous instance is carried over only when all of the following hold:
        its box is at least `min_box_rel_dim` of the frame in both width and
        height, its lost_frame_count is below `max_lost_frame_count`, and its
        ID_period is above `min_instance_period`.

        Carried-over tensors are gathered by index from the previous frame, so
        they keep the dtype and device of the tensors they come from (no
        round-trip through numpy, which fails on non-CPU tensors and used to
        truncate keypoint coordinates and scores to 8-bit integers).

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances merging current instances and instances from previous
            frame decided to keep tracking
        """
        prev = self._prev_instances
        prev_box_tensor = prev.pred_boxes.tensor
        # Plain Python floats keep the size test free of per-element tensor ops.
        box_coords = prev_box_tensor.tolist()

        keep = []
        for idx in sorted(self._untracked_prev_idx):
            x_left, y_top, x_right, y_bot = box_coords[idx]
            too_small = (
                (x_right - x_left) / self._video_width < self._min_box_rel_dim
                or (y_bot - y_top) / self._video_height < self._min_box_rel_dim
            )
            lost_too_long = prev.lost_frame_count[idx] >= self._max_lost_frame_count
            too_recent = prev.ID_period[idx] <= self._min_instance_period
            if too_small or lost_too_long or too_recent:
                continue
            keep.append(idx)

        def _take(field):
            # An index tensor (rather than a Python list) also handles the
            # empty case, yielding a zero-length result of the right shape.
            index = torch.as_tensor(keep, dtype=torch.long, device=field.device)
            return field[index]

        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=Boxes(_take(prev_box_tensor)),
            pred_classes=_take(prev.pred_classes),
            scores=_take(prev.scores),
            ID=[prev.ID[idx] for idx in keep],
            ID_period=[prev.ID_period[idx] for idx in keep],
            lost_frame_count=[prev.lost_frame_count[idx] + 1 for idx in keep],
        )
        # Optional fields are mirrored only when the current frame has them, so
        # that both operands of Instances.cat expose the same set of fields.
        for field_name in ("pred_masks", "pred_keypoints", "pred_keypoint_heatmaps"):
            if instances.has(field_name):
                untracked_instances.set(field_name, _take(prev.get(field_name)))

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )
| |
|