|
|
|
|
| import copy
|
| import numpy as np
|
| from typing import List
|
| import torch
|
|
|
| from detectron2.config import configurable
|
| from detectron2.structures import Boxes, Instances
|
| from detectron2.structures.boxes import pairwise_iou
|
|
|
| from ..config.config import CfgNode as CfgNode_
|
| from .base_tracker import TRACKER_HEADS_REGISTRY, BaseTracker
|
|
|
|
|
@TRACKER_HEADS_REGISTRY.register()
class BBoxIOUTracker(BaseTracker):
    """
    A bounding box tracker that assigns IDs based on the IoU between the
    instances of the current frame and those of the previous frame.

    Matching is greedy: candidate (current, previous) pairs are visited from
    the highest IoU to the lowest, and a pair is accepted when neither side has
    been matched yet and its IoU reaches ``track_iou_threshold``.
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of IDs allowed to be tracked
                (stored on the tracker; not referenced by any other method of
                this class)
            max_lost_frame_count: maximum number of frames an ID may stay
                unmatched; once its lost-frame count reaches this number the
                ID is no longer carried over
            min_box_rel_dim: a fraction of the frame size; an unmatched box
                whose width or height is relatively smaller than this is
                removed from tracking
            min_instance_period: an unmatched instance is only carried over
                when its period is strictly greater than this number
            track_iou_threshold: IoU threshold; a bbox pair below this number
                is not matched
            **kwargs: forwarded unchanged to the base class constructor
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file. ``cfg.TRACKER_HEADS`` must contain
                ``VIDEO_HEIGHT`` and ``VIDEO_WIDTH``; every other key falls
                back to the same default as ``__init__``.
        Return:
            dictionary storing arguments for __init__ method
        """
        tracker_cfg = cfg.TRACKER_HEADS
        assert "VIDEO_HEIGHT" in tracker_cfg, "TRACKER_HEADS.VIDEO_HEIGHT is required"
        assert "VIDEO_WIDTH" in tracker_cfg, "TRACKER_HEADS.VIDEO_WIDTH is required"
        return {
            "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker",
            "video_height": tracker_cfg.get("VIDEO_HEIGHT"),
            "video_width": tracker_cfg.get("VIDEO_WIDTH"),
            "max_num_instances": tracker_cfg.get("MAX_NUM_INSTANCES", 200),
            "max_lost_frame_count": tracker_cfg.get("MAX_LOST_FRAME_COUNT", 0),
            "min_box_rel_dim": tracker_cfg.get("MIN_BOX_REL_DIM", 0.02),
            "min_instance_period": tracker_cfg.get("MIN_INSTANCE_PERIOD", 1),
            "track_iou_threshold": tracker_cfg.get("TRACK_IOU_THRESHOLD", 0.5),
        }

    def update(self, instances: Instances) -> Instances:
        """
        Assign tracking IDs to the predictions of the current frame.

        On the first call (no previous instances stored) every instance simply
        receives a fresh ID. On later calls, instances are matched against the
        previous frame by IoU; unmatched current instances get new IDs and
        eligible unmatched previous instances are appended to the result.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with ID, ID_period and lost_frame_count populated. A
            deep copy of the returned value is stored as the previous frame.
        """
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            iou_all = pairwise_iou(
                boxes1=instances.pred_boxes,
                boxes2=self._prev_instances.pred_boxes,
            )
            # Visit the best overlaps first. Without this ordering the greedy
            # loop below would bind a detection to the first previous box that
            # clears the threshold instead of the one it overlaps most.
            # sorted() is stable, so ties keep their (idx, prev_idx) order.
            bbox_pairs = sorted(
                self._create_prediction_pairs(instances, iou_all),
                key=lambda pair: float(pair["IoU"]),
                reverse=True,
            )
            self._reset_fields()
            for bbox_pair in bbox_pairs:
                idx = bbox_pair["idx"]
                prev_id = bbox_pair["prev_id"]
                if (
                    idx in self._matched_idx
                    or prev_id in self._matched_ID
                    or bbox_pair["IoU"] < self._track_iou_threshold
                ):
                    continue
                instances.ID[idx] = prev_id
                instances.ID_period[idx] = bbox_pair["prev_period"] + 1
                instances.lost_frame_count[idx] = 0
                self._matched_idx.add(idx)
                self._matched_ID.add(prev_id)
                self._untracked_prev_idx.remove(bbox_pair["prev_idx"])
            instances = self._assign_new_id(instances)
            instances = self._merge_untracked_instances(instances)
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _create_prediction_pairs(
        self, instances: Instances, iou_all: torch.Tensor
    ) -> List[dict]:
        """
        For all instances in previous and current frames, create pairs. For each
        pair, store index of the instance in current frame predictions, index in
        previous predictions, ID in previous predictions, IoU of the bboxes in this
        pair, period in previous predictions.

        Args:
            instances: D2 Instances, for predictions of the current frame
            iou_all: IoU for all bbox pairs, indexed as
                ``iou_all[current_idx, previous_idx]``
        Return:
            A list with one dict per (current, previous) pair, holding the keys
            "idx", "prev_idx", "prev_id", "IoU" and "prev_period". Pairs are
            emitted in row-major order (current index outermost).
        """
        prev_instances = self._prev_instances
        num_prev = len(prev_instances)
        return [
            {
                "idx": i,
                "prev_idx": j,
                "prev_id": prev_instances.ID[j],
                "IoU": iou_all[i, j],
                "prev_period": prev_instances.ID_period[j],
            }
            for i in range(len(instances))
            for j in range(num_prev)
        ]

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        When no previous instances are stored (first frame), the fields are
        additionally filled with initial values: IDs 0..N-1, period 1 and
        lost-frame count 0, and the ID counter is advanced by N.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        num_instances = len(instances)
        for field_name in ("ID", "ID_period", "lost_frame_count"):
            if not instances.has(field_name):
                instances.set(field_name, [None] * num_instances)
        if self._prev_instances is None:
            instances.ID = list(range(num_instances))
            self._id_count += num_instances
            instances.ID_period = [1] * num_instances
            instances.lost_frame_count = [0] * num_instances
        return instances

    def _reset_fields(self):
        """
        Before each update call, reset the per-frame matching bookkeeping:
        matched current indices, matched previous IDs, and the set of previous
        indices that have not been matched yet (initially all of them).
        """
        self._matched_idx = set()
        self._matched_ID = set()
        self._untracked_prev_idx = set(range(len(self._prev_instances)))

    def _assign_new_id(self, instances: Instances) -> Instances:
        """
        For each untracked instance, assign a new id

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with new ID assigned
        """
        untracked_idx = set(range(len(instances))).difference(self._matched_idx)
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _merge_untracked_instances(self, instances: Instances) -> Instances:
        """
        For untracked previous instances, under certain condition, still keep them
        in tracking and merge with the current instances.

        A previous instance is dropped (not carried over) when any of these
        holds: its box is relatively narrower or shorter than
        ``min_box_rel_dim``, its lost-frame count has reached
        ``max_lost_frame_count``, or its period does not exceed
        ``min_instance_period``. Carried-over instances get their lost-frame
        count incremented by one.

        NOTE(review): tensors are converted through ``.numpy()``, which
        presumably requires them to live on the CPU -- confirm with callers.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances merging current instances and instances from previous
            frame decided to keep tracking
        """
        prev_instances = self._prev_instances
        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=[],
            pred_classes=[],
            scores=[],
            ID=[],
            ID_period=[],
            lost_frame_count=[],
        )
        prev_bboxes = list(prev_instances.pred_boxes)
        prev_classes = list(prev_instances.pred_classes)
        prev_scores = list(prev_instances.scores)
        prev_ID_period = prev_instances.ID_period

        # Optional prediction fields are carried over only when the current
        # frame has them, so that both operands of Instances.cat line up.
        has_masks = instances.has("pred_masks")
        has_keypoints = instances.has("pred_keypoints")
        has_heatmaps = instances.has("pred_keypoint_heatmaps")
        if has_masks:
            untracked_instances.set("pred_masks", [])
            prev_masks = list(prev_instances.pred_masks)
        if has_keypoints:
            untracked_instances.set("pred_keypoints", [])
            prev_keypoints = list(prev_instances.pred_keypoints)
        if has_heatmaps:
            untracked_instances.set("pred_keypoint_heatmaps", [])
            prev_keypoint_heatmaps = list(prev_instances.pred_keypoint_heatmaps)

        for idx in self._untracked_prev_idx:
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            too_narrow = 1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim
            too_short = 1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim
            if (
                too_narrow
                or too_short
                or prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
            untracked_instances.pred_classes.append(int(prev_classes[idx]))
            untracked_instances.scores.append(float(prev_scores[idx]))
            untracked_instances.ID.append(prev_instances.ID[idx])
            untracked_instances.ID_period.append(prev_instances.ID_period[idx])
            untracked_instances.lost_frame_count.append(
                prev_instances.lost_frame_count[idx] + 1
            )
            if has_masks:
                untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))
            if has_keypoints:
                # Keep keypoints as floats: a uint8 cast wraps any value above
                # 255 and discards the fractional part.
                untracked_instances.pred_keypoints.append(
                    prev_keypoints[idx].numpy().astype(np.float32)
                )
            if has_heatmaps:
                untracked_instances.pred_keypoint_heatmaps.append(
                    prev_keypoint_heatmaps[idx].numpy().astype(np.float32)
                )

        # Turn the accumulated python lists into tensors so they can be
        # concatenated with the tensor fields of the current instances.
        untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
        untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
        untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
        if has_masks:
            untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
        if has_keypoints:
            untracked_instances.pred_keypoints = torch.FloatTensor(
                untracked_instances.pred_keypoints
            )
        if has_heatmaps:
            untracked_instances.pred_keypoint_heatmaps = torch.FloatTensor(
                untracked_instances.pred_keypoint_heatmaps
            )

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )
|
|
|