Spaces:
Runtime error
Runtime error
| """ | |
| Author: Siyuan Li | |
| Licensed: Apache-2.0 License | |
| """ | |
| from typing import List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from mmdet.models.trackers.base_tracker import BaseTracker | |
| from mmdet.registry import MODELS | |
| from mmdet.structures import TrackDataSample | |
| from mmdet.structures.bbox import bbox_overlaps | |
| from mmengine.structures import InstanceData | |
| from torch import Tensor | |
class MasaTaoTracker(BaseTracker):
    """Tracker for MASA on TAO benchmark.

    Detections are associated with remembered tracklets through appearance
    embeddings. The match score is the mean of a bi-directional softmax
    similarity and a cosine similarity, optionally damped by a soft spatial
    gate on the distance between box centres (see ``max_distance``).

    Args:
        init_score_thr (float): The cls_score threshold to
            initialize a new tracklet. Defaults to 0.8.
        obj_score_thr (float): The cls_score threshold to
            update a tracked tracklet. Defaults to 0.5.
        match_score_thr (float): The match threshold. Defaults to 0.5.
        memo_tracklet_frames (int): The most frames in a tracklet memory.
            Defaults to 10.
        memo_momentum (float): The momentum value for embeds updating.
            Defaults to 0.8.
        distractor_score_thr (float): The score threshold to consider an
            object as a distractor. Defaults to 0.5.
        distractor_nms_thr (float): The NMS threshold for filtering out
            distractors. Defaults to 0.3.
        with_cats (bool): Whether to track with the same category.
            Stored on the instance; the matching code in this class does
            not consult it. Defaults to True.
        max_distance (float): Maximum centre distance for considering
            matches. ``-1`` disables the spatial gate. Defaults to -1.
        fps (int): Frames per second of the input video. Used to derive
            ``growth_factor`` (``fps / 6``) and ``distance_smoothing_factor``
            (``100 / fps``). Defaults to 1.
        **kwargs: Forwarded unchanged to ``BaseTracker.__init__``.
    """

    def __init__(
        self,
        init_score_thr: float = 0.8,
        obj_score_thr: float = 0.5,
        match_score_thr: float = 0.5,
        memo_tracklet_frames: int = 10,
        memo_momentum: float = 0.8,
        distractor_score_thr: float = 0.5,
        distractor_nms_thr: float = 0.3,
        with_cats: bool = True,
        max_distance: float = -1,
        fps: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)
        assert 0 <= memo_momentum <= 1.0
        assert memo_tracklet_frames >= 0

        self.init_score_thr = init_score_thr
        self.obj_score_thr = obj_score_thr
        self.match_score_thr = match_score_thr
        self.memo_tracklet_frames = memo_tracklet_frames
        self.memo_momentum = memo_momentum
        self.distractor_score_thr = distractor_score_thr
        self.distractor_nms_thr = distractor_nms_thr
        self.with_cats = with_cats

        self.num_tracks = 0
        # track id (int) -> dict(bbox, embed, label, score, last_frame)
        self.tracks = dict()
        # Kept for interface parity; nothing in this class reads it.
        self.backdrops = []

        # Maximum distance for considering matches (-1 turns the gate off).
        self.max_distance = max_distance
        self.fps = fps
        # Growth factor for the distance mask: the allowed distance grows
        # exponentially with the frame gap divided by this value.
        self.growth_factor = self.fps / 6
        # Decay length of the soft penalty applied beyond the allowed distance.
        self.distance_smoothing_factor = 100 / self.fps

    def reset(self):
        """Reset the buffer of the tracker."""
        self.num_tracks = 0
        self.tracks = dict()
        self.backdrops = []

    def update(
        self,
        ids: Tensor,
        bboxes: Tensor,
        embeds: Tensor,
        labels: Tensor,
        scores: Tensor,
        frame_id: int,
    ) -> None:
        """Write the current detections into the track memory.

        Entries whose id is ``-1`` are ignored. Known ids get their box,
        label and score overwritten and their embedding blended with the
        stored one using ``memo_momentum``; unknown ids create a new track.
        Afterwards, tracks not seen for ``memo_tracklet_frames`` frames or
        more are dropped.

        Args:
            ids (Tensor): of shape (N, ). ``-1`` marks an unassigned row.
            bboxes (Tensor): of shape (N, 4).
            embeds (Tensor): of shape (N, C).
            labels (Tensor): of shape (N, ).
            scores (Tensor): of shape (N, ).
            frame_id (int): The id of current frame, 0-index.
        """
        tracklet_inds = ids > -1
        for raw_id, bbox, embed, label, score in zip(
            ids[tracklet_inds],
            bboxes[tracklet_inds],
            embeds[tracklet_inds],
            labels[tracklet_inds],
            scores[tracklet_inds],
        ):
            track_id = int(raw_id)
            # update the tracked ones and initialize new tracks
            if track_id in self.tracks:
                track = self.tracks[track_id]
                track["bbox"] = bbox
                # Exponential moving average of the appearance embedding.
                track["embed"] = (1 - self.memo_momentum) * track[
                    "embed"
                ] + self.memo_momentum * embed
                track["last_frame"] = frame_id
                track["label"] = label
                track["score"] = score
            else:
                self.tracks[track_id] = dict(
                    bbox=bbox,
                    embed=embed,
                    label=label,
                    score=score,
                    last_frame=frame_id,
                )

        # pop memo: collect first, the dict cannot shrink while iterated.
        expired_ids = [
            track_id
            for track_id, track in self.tracks.items()
            if frame_id - track["last_frame"] >= self.memo_tracklet_frames
        ]
        for track_id in expired_ids:
            self.tracks.pop(track_id)

    @property
    def memo(self) -> Tuple[Tensor, ...]:
        """Get tracks memory.

        Must only be read while at least one track is stored; with an empty
        memory the tensor concatenation below raises.

        Returns:
            tuple[Tensor, ...]: ``(bboxes, labels, embeds, ids, frame_ids)``
            with one row/entry per stored track. ``ids`` and ``frame_ids``
            are 1-D long tensors created on the CPU.
        """
        memo_embeds = []
        memo_ids = []
        memo_bboxes = []
        memo_labels = []
        memo_frame_ids = []

        # get tracks
        for track_id, track in self.tracks.items():
            memo_bboxes.append(track["bbox"][None, :])
            memo_embeds.append(track["embed"][None, :])
            memo_ids.append(track_id)
            memo_labels.append(track["label"].view(1, 1))
            memo_frame_ids.append(track["last_frame"])

        memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)
        memo_bboxes = torch.cat(memo_bboxes, dim=0)
        memo_embeds = torch.cat(memo_embeds, dim=0)
        memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
        memo_frame_ids = torch.tensor(memo_frame_ids, dtype=torch.long).view(1, -1)

        return (
            memo_bboxes,
            memo_labels,
            memo_embeds,
            memo_ids.squeeze(0),
            memo_frame_ids.squeeze(0),
        )

    def compute_distance_mask(self, bboxes1, bboxes2, frame_ids1, frame_ids2):
        """Compute a soft gate from pairwise centre distances and frame gaps.

        Pairs whose centre distance is within the allowed radius get weight
        1; beyond it the weight decays exponentially with the excess
        distance. The allowed radius is ``max_distance`` scaled by
        ``exp(frame_gap / growth_factor)``.

        Args:
            bboxes1 (Tensor): of shape (N, 4); columns 0-1 and 2-3 are
                averaged to obtain the box centre.
            bboxes2 (Tensor): of shape (M, 4), same layout as ``bboxes1``.
            frame_ids1 (Tensor): of shape (N, ), frame id per row of
                ``bboxes1``.
            frame_ids2 (Tensor): of shape (M, ), frame id per row of
                ``bboxes2``.

        Returns:
            Tensor: of shape (N, M) with values in (0, 1].
        """
        centers1 = (bboxes1[:, :2] + bboxes1[:, 2:]) / 2.0
        centers2 = (bboxes2[:, :2] + bboxes2[:, 2:]) / 2.0
        distances = torch.cdist(centers1, centers2)

        # Frame ids may live on another device than the boxes.
        frame_id_diff = torch.abs(frame_ids1[:, None] - frame_ids2[None, :]).to(
            distances.device
        )

        # Define a scaling factor for the distance based on frame difference
        # (exponential growth)
        scaling_factor = torch.exp(frame_id_diff.float() / self.growth_factor)

        # Apply the scaling factor to max_distance
        adaptive_max_distance = self.max_distance * scaling_factor

        # Create a piecewise function for soft gating
        soft_distance_mask = torch.where(
            distances <= adaptive_max_distance,
            torch.ones_like(distances),
            torch.exp(
                -(distances - adaptive_max_distance) / self.distance_smoothing_factor
            ),
        )
        return soft_distance_mask

    def track(
        self,
        model: torch.nn.Module,
        img: torch.Tensor,
        feats: List[torch.Tensor],
        data_sample: TrackDataSample,
        rescale: bool = True,
        with_segm: bool = False,
        **kwargs
    ) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model. Its ``track_head.predict`` is
                called to embed the detections.
            img (Tensor): of shape (T, C, H, W) encoding input image.
                Not read by this method.
            feats (list[Tensor]): Multi level feature maps of `img`.
            data_sample (:obj:`TrackDataSample`): The data sample.
                ``pred_instances`` supplies ``bboxes``, ``labels`` and
                ``scores``; ``metainfo`` supplies ``frame_id`` (``-1`` when
                absent) and, when ``rescale`` is True, ``scale_factor``.
            rescale (bool, optional): If True, the boxes handed to the track
                head are multiplied by ``metainfo['scale_factor']``. The
                boxes in the returned result are left as received.
                Defaults to True.
            with_segm (bool, optional): If True, the result also carries
                ``mask_inds``, the positions of the kept detections in the
                score-sorted order. Defaults to False.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Contains ``bboxes``, ``labels``, ``scores`` and
            ``instances_id`` (plus ``mask_inds`` when ``with_segm``). With
            no detections, a clone of ``pred_instances`` with all-zero
            ``instances_id`` and ``mask_inds`` is returned instead.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get("frame_id", -1)

        # create pred_track_instances
        pred_track_instances = InstanceData()

        # return zero bboxes if there is no track targets
        if bboxes.shape[0] == 0:
            ids = torch.zeros_like(labels)
            pred_track_instances = data_sample.pred_instances.clone()
            pred_track_instances.instances_id = ids
            pred_track_instances.mask_inds = torch.zeros_like(labels)
            return pred_track_instances

        # get track feats
        rescaled_bboxes = bboxes.clone()
        if rescale:
            # Tile the scale factor so it lines up with the 4 box columns.
            scale_factor = rescaled_bboxes.new_tensor(metainfo["scale_factor"]).repeat(
                (1, 2)
            )
            rescaled_bboxes = rescaled_bboxes * scale_factor
        track_feats = model.track_head.predict(feats, [rescaled_bboxes])

        # sort according to the object_score
        _, inds = scores.sort(descending=True)
        bboxes = bboxes[inds]
        scores = scores[inds]
        labels = labels[inds]
        embeds = track_feats[inds, :]
        if with_segm:
            mask_inds = torch.arange(bboxes.size(0)).to(embeds.device)
            mask_inds = mask_inds[inds]
        else:
            mask_inds = []

        bboxes, labels, scores, embeds, mask_inds = self.remove_distractor(
            bboxes,
            labels,
            scores,
            track_feats=embeds,
            mask_inds=mask_inds,
            nms="inter",
            distractor_score_thr=self.distractor_score_thr,
            distractor_nms_thr=self.distractor_nms_thr,
        )

        # init ids container (kept on the CPU, like the memo ids)
        ids = torch.full((bboxes.size(0),), -1, dtype=torch.long)

        # match if buffer is not empty
        if bboxes.size(0) > 0 and not self.empty:
            (
                memo_bboxes,
                memo_labels,
                memo_embeds,
                memo_ids,
                memo_frame_ids,
            ) = self.memo

            similarity = torch.mm(embeds, memo_embeds.t())
            d2t_scores = similarity.softmax(dim=1)
            t2d_scores = similarity.softmax(dim=0)
            match_scores_bisoftmax = (d2t_scores + t2d_scores) / 2

            match_scores_cosine = torch.mm(
                F.normalize(embeds, p=2, dim=1),
                F.normalize(memo_embeds, p=2, dim=1).t(),
            )

            match_scores = (match_scores_bisoftmax + match_scores_cosine) / 2

            if self.max_distance != -1:
                # Compute the mask based on spatial proximity
                current_frame_ids = torch.full(
                    (bboxes.size(0),), frame_id, dtype=torch.long
                )
                distance_mask = self.compute_distance_mask(
                    bboxes, memo_bboxes, current_frame_ids, memo_frame_ids
                )

                # Apply the mask to the match scores
                match_scores = match_scores * distance_mask

            # track according to match_scores; detections are visited in
            # descending score order, so stronger ones claim tracks first.
            for i in range(bboxes.size(0)):
                conf, memo_ind = torch.max(match_scores[i, :], dim=0)
                matched_id = memo_ids[memo_ind]
                if conf > self.match_score_thr:
                    if matched_id > -1:
                        # keep bboxes with high object score
                        # and remove background bboxes
                        if scores[i] > self.obj_score_thr:
                            ids[i] = matched_id
                            # A claimed track is unavailable to every
                            # other detection.
                            match_scores[:i, memo_ind] = 0
                            match_scores[i + 1 :, memo_ind] = 0

        # initialize new tracks
        new_inds = (ids == -1) & (scores > self.init_score_thr).cpu()
        # Plain int so that ``num_tracks`` never turns into a tensor.
        num_news = int(new_inds.sum())
        ids[new_inds] = torch.arange(
            self.num_tracks, self.num_tracks + num_news, dtype=torch.long
        )
        self.num_tracks += num_news

        self.update(ids, bboxes, embeds, labels, scores, frame_id)

        tracklet_inds = ids > -1

        # update pred_track_instances
        pred_track_instances.bboxes = bboxes[tracklet_inds]
        pred_track_instances.labels = labels[tracklet_inds]
        pred_track_instances.scores = scores[tracklet_inds]
        pred_track_instances.instances_id = ids[tracklet_inds]
        if with_segm:
            pred_track_instances.mask_inds = mask_inds[tracklet_inds]

        return pred_track_instances

    def remove_distractor(
        self,
        bboxes,
        labels,
        scores,
        track_feats,
        mask_inds=None,
        distractor_score_thr=0.5,
        distractor_nms_thr=0.3,
        nms="inter",
    ):
        """Drop low-score detections that overlap an earlier detection.

        A detection scoring below ``distractor_score_thr`` is removed when
        its IoU with any detection at a smaller index exceeds
        ``distractor_nms_thr``. The caller in this class passes inputs
        sorted by descending score, so "smaller index" means "higher score".

        Args:
            bboxes (Tensor): of shape (N, 4).
            labels (Tensor): of shape (N, ).
            scores (Tensor): of shape (N, ).
            track_feats (Tensor | None): of shape (N, C); filtered alongside
                the boxes when given.
            mask_inds (Tensor | list | None): Optional per-detection
                indices, filtered alongside the boxes when non-empty.
                ``None`` is treated as an empty list. Defaults to None.
            distractor_score_thr (float): Scores below this mark a
                removal candidate. Defaults to 0.5.
            distractor_nms_thr (float): IoU above which a candidate is
                removed. Defaults to 0.3.
            nms (str): ``'inter'`` compares against all categories,
                ``'intra'`` only against detections with the same label.
                Defaults to 'inter'.

        Returns:
            tuple: ``(bboxes, labels, scores, track_feats, mask_inds)``
            restricted to the kept detections.

        Raises:
            NotImplementedError: If ``nms`` is neither ``'inter'`` nor
                ``'intra'``.
        """
        if mask_inds is None:
            mask_inds = []

        # all objects is valid here
        valid_inds = labels > -1

        # nms
        low_inds = torch.nonzero(scores < distractor_score_thr, as_tuple=False).squeeze(
            1
        )
        if nms == "inter":
            ious = bbox_overlaps(bboxes[low_inds, :], bboxes[:, :])
        elif nms == "intra":
            cat_same = labels[low_inds].view(-1, 1) == labels.view(1, -1)
            ious = bbox_overlaps(bboxes[low_inds, :], bboxes)
            # Zero out overlaps between different categories.
            ious *= cat_same.to(ious.device)
        else:
            raise NotImplementedError

        for row, det_ind in enumerate(low_inds):
            # Only detections placed before this one can suppress it.
            if (ious[row, :det_ind] > distractor_nms_thr).any():
                valid_inds[det_ind] = False

        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]
        if track_feats is not None:
            track_feats = track_feats[valid_inds]

        if len(mask_inds) > 0:
            mask_inds = mask_inds[valid_inds]

        return bboxes, labels, scores, track_feats, mask_inds