|
|
|
|
| from detectron2.config import configurable
|
| from detectron2.utils.registry import Registry
|
|
|
| from ..config.config import CfgNode as CfgNode_
|
| from ..structures import Instances
|
|
|
# Module-level registry of tracker classes; build_tracker_head() looks classes
# up here by the name stored in the config.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""
|
|
|
|
|
class BaseTracker:
    """
    Common base class for every tracker.

    The base class only initialises bookkeeping attributes. Both
    ``from_config`` and ``update`` raise ``NotImplementedError`` here, so
    concrete trackers are expected to override them.
    """

    @configurable
    def __init__(self, **kwargs):
        """
        Initialise the tracker's bookkeeping state.

        Args:
            **kwargs: accepted for signature compatibility; not read here
        """
        # Counter starts at zero (presumably used to hand out new track IDs;
        # confirm against the concrete trackers).
        self._id_count = 0
        # Nothing has been tracked yet, so there is no previous frame.
        self._prev_instances = None
        # Matching bookkeeping, all empty at construction time.
        # NOTE(review): the names suggest indices/IDs relative to
        # ``_prev_instances``; verify in subclasses.
        self._untracked_prev_idx = set()
        self._matched_ID = set()
        self._matched_idx = set()

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Config hook for the tracker; not implemented in the base class.

        Args:
            cfg: D2 CfgNode
        Raises:
            NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign IDs to the predictions of the current frame.

        Not implemented in the base class.

        Args:
            predictions: D2 Instances holding the current frame's predictions
        Return:
            D2 Instances for the current frame, with ID assigned

        Both _prev_instances and the returned instances carry these fields:
          .pred_boxes               (shape=[N, 4])
          .scores                   (shape=[N,])
          .pred_classes             (shape=[N,])
          .pred_keypoints           (shape=[N, M, 3], Optional)
          .pred_masks               (shape=List[2D_MASK], Optional)
                                    2D_MASK: shape=[H, W]
          .ID                       (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError("Calling BaseTracker::update")
|
|
|
|
|
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Create the tracker head selected by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The configured name is looked up in ``TRACKER_HEADS_REGISTRY`` and the
    class found there is called with ``cfg``.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    tracker_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return tracker_cls(cfg)
|
|
|