| | |
| | |
| | from detectron2.config import configurable |
| | from detectron2.utils.registry import Registry |
| |
|
| | from ..config.config import CfgNode as CfgNode_ |
| | from ..structures import Instances |
| |
|
# Registry of tracker classes. build_tracker_head() looks a class up here by
# the name stored in cfg.TRACKER_HEADS.TRACKER_NAME and instantiates it.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = "\nRegistry for tracking classes.\n"
| |
|
| |
|
class BaseTracker:
    """
    Common base class shared by every tracker head.

    Both ``from_config`` and ``update`` raise ``NotImplementedError`` here, so
    concrete trackers are expected to override them.
    """

    @configurable
    def __init__(self, **kwargs):
        # Keyword arguments are accepted (so subclasses / @configurable can
        # forward them) but this base class does not use any of them.

        # NOTE(review): presumably the Instances seen on the previous frame;
        # starts empty until a subclass's update() stores something here.
        self._prev_instances = None
        # Bookkeeping sets, all initially empty; their exact meaning is
        # defined by the subclass that populates them.
        self._matched_idx, self._matched_ID, self._untracked_prev_idx = (
            set(),
            set(),
            set(),
        )
        # NOTE(review): looks like a counter used to hand out new track IDs.
        self._id_count = 0

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Translate a D2 config into constructor arguments.

        Must be overridden by subclasses; the base version always raises.
        NOTE(review): presumably invoked by ``@configurable`` when the tracker
        is constructed from a cfg (as ``build_tracker_head`` does).
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign track IDs to the detections of the current frame.

        Must be overridden by subclasses; the base version always raises.

        Args:
            predictions: D2 Instances holding the predictions of the current frame
        Returns:
            D2 Instances for the current frame with an ``ID`` field assigned

        Both ``_prev_instances`` and the returned instances carry these fields:
            .pred_boxes     (shape=[N, 4])
            .scores         (shape=[N,])
            .pred_classes   (shape=[N,])
            .pred_keypoints (shape=[N, M, 3], Optional)
            .pred_masks     (shape=List[2D_MASK], Optional) 2D_MASK: shape=[H, W]
            .ID             (shape=[N,])

        N: number of detected bboxes
        H, W: height and width of each 2D mask
        """
        raise NotImplementedError("Calling BaseTracker::update")
| |
|
| |
|
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker class named by ``cfg.TRACKER_HEADS.TRACKER_NAME``.

    The class is fetched from ``TRACKER_HEADS_REGISTRY`` and called with the
    whole config as its only positional argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Returns:
        the constructed tracker object
    """
    return TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)(cfg)
| |
|