|
|
| from typing import List
|
| import torch
|
|
|
| from detectron2.layers import nonzero_tuple
|
|
|
|
|
|
|
class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, that characterizes
    how well each (ground-truth, prediction)-pair match each other. For example,
    if the elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches to prediction n in [0, N).
    (b) a vector of length N containing the labels for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels. Must be non-empty, sorted in non-decreasing order,
                and its first value must be > 0.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
                Must contain exactly ``len(thresholds) + 1`` entries.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.

        Raises:
            AssertionError: if the thresholds are not positive / sorted, if a
                label is outside {-1, 0, 1}, or if the number of labels does
                not equal the number of thresholds plus one.
        """
        # Work on a copy so the caller's list is not mutated, then pad with
        # -inf / +inf so that consecutive pairs form half-open intervals
        # [low, high) covering the whole real line.
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))

        # NOTE(review): a list (not a generator) is passed to all(); this looks
        # deliberate (e.g. for TorchScript compatibility) -- confirm before
        # changing it to a generator expression.
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
        assert all([label in [-1, 0, 1] for label in labels])
        # One label per interval between consecutive (padded) thresholds.
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M). When the input matrix is empty,
                every entry is 0.
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            # Empty matrix (no ground truth and/or no predictions): there is
            # nothing to take a max over, so return placeholder matches of 0.
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            # Every prediction receives the label of the lowest-quality level,
            # `self.labels[0]`, i.e. it is treated as having quality below the
            # first threshold.
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (ground truth) x N (predictions).
        # Reduce over the ground-truth axis (dim 0) to get, for each prediction,
        # its best quality value and the index of the ground truth achieving it.
        matched_vals, matches = match_quality_matrix.max(dim=0)

        # The initial fill value is overwritten below: the padded thresholds
        # partition the real line, so every prediction falls in one interval.
        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        for (label, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth G.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.

        Args:
            match_labels (Tensor[int8]): length-N label vector; modified in place.
            match_quality_matrix (Tensor[float]): the MxN quality matrix.

        Note:
            Only ``match_labels`` is updated. The ``matches`` vector computed in
            :meth:`__call__` is not touched, so a prediction promoted here keeps
            the ground-truth index of its per-prediction best match.
        """
        # For each ground truth (row), the best quality over all predictions.
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)

        # Column indices of every entry equal to its row maximum, so ties are
        # included. The row indices are not needed.
        # NOTE(review): a ground-truth row whose maximum is 0 selects every
        # prediction with quality 0 for that row -- confirm this is intended.
        _, pred_inds_with_highest_quality = nonzero_tuple(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )

        # Mark those predictions as positive regardless of the thresholds.
        match_labels[pred_inds_with_highest_quality] = 1
|
|
|