|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Box matcher implementation."""
|
|
|
| from typing import List, Tuple
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
|
|
class BoxMatcher:
    """Matcher based on highest value.

    This class computes matches from a similarity matrix of shape
    [num_rows, num_cols] (or [batch_size, num_rows, num_cols]). Each row is
    matched to the single column with the highest similarity, and that best
    similarity value decides the row's match type.

    To support object detection target assignment this class enables setting
    both positive_threshold (upper threshold) and negative_threshold (lower
    threshold), defining three categories of similarity which determine whether
    examples are positive, negative, or ignored, for example:
    (1) thresholds=[negative_threshold, positive_threshold], and
        indicators=[negative_value, ignore_value, positive_value]: similarity
        values below negative_threshold are assigned negative_value, values
        from negative_threshold up to (but excluding) positive_threshold are
        assigned ignore_value, and values at or above positive_threshold are
        assigned positive_value.
    (2) thresholds=[negative_threshold, positive_threshold], and
        indicators=[ignore_value, negative_value, positive_value]: similarity
        values below negative_threshold are assigned ignore_value, values from
        negative_threshold up to (but excluding) positive_threshold are
        assigned negative_value, and values at or above positive_threshold are
        assigned positive_value.

    Every interval is half-open, [low, high).
    """

    def __init__(self,
                 thresholds: List[float],
                 indicators: List[int],
                 force_match_for_each_col: bool = False):
        """Constructs BoxMatcher.

        Args:
          thresholds: A sequence (list or tuple) of thresholds used to classify
            matches into different types (e.g. positive, negative or ignored).
            It must be sorted in non-decreasing order. A padded copy with -Inf
            prepended and +Inf appended is stored; the argument itself is not
            modified.
          indicators: A sequence of values representing match types (e.g.
            positive, negative or ignored). len(`indicators`) must equal
            len(`thresholds`) + 1; `indicators[i]` is assigned to best-match
            values in the i-th half-open interval of the padded thresholds. A
            copy is stored.
          force_match_for_each_col: If True, ensures that each column is
            matched to at least one row (which is not guaranteed otherwise if
            the positive threshold is high). Defaults to False. If True, all
            force-matched rows are assigned `indicators[-1]`.

        Raises:
          ValueError: If `thresholds` is not sorted, or
            len(indicators) != len(thresholds) + 1.
        """
        if not all(lo <= hi for lo, hi in zip(thresholds[:-1], thresholds[1:])):
            raise ValueError('`threshold` must be sorted, got {}'.format(thresholds))
        if len(indicators) != len(thresholds) + 1:
            raise ValueError('len(`indicators`) must be len(`thresholds`) + 1, got '
                             'indicators {}, thresholds {}'.format(
                                 indicators, thresholds))
        # Store copies: the caller's sequences are neither mutated nor aliased,
        # and tuples work as well as lists.
        self.indicators = list(indicators)
        self.thresholds = [-float('inf')] + list(thresholds) + [float('inf')]
        self._force_match_for_each_col = force_match_for_each_col

    def __call__(self,
                 similarity_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Matches each row of the similarity matrix to its best column.

        Args:
          similarity_matrix: A float tensor of shape [num_rows, num_cols] or
            [batch_size, num_rows, num_cols] representing any similarity
            metric.

        Returns:
          matched_columns: An int32 tensor of shape [num_rows] or [batch_size,
            num_rows] storing the index of the matched column for each row.
          match_indicators: An int32 tensor of shape [num_rows] or [batch_size,
            num_rows] storing the match type indicator (e.g. positive or
            negative or ignored match).
        """
        # Promote an unbatched [num_rows, num_cols] input to a batch of one;
        # the batch axis is squeezed away again before returning.
        squeeze_result = False
        if len(similarity_matrix.shape) == 2:
            squeeze_result = True
            similarity_matrix = tf.expand_dims(similarity_matrix, axis=0)

        # Prefer static (Python int) sizes and fall back to dynamic ones when
        # a static size is unknown (None) or 0.
        static_shape = similarity_matrix.shape.as_list()
        num_rows = static_shape[1] or tf.shape(similarity_matrix)[1]
        batch_size = static_shape[0] or tf.shape(similarity_matrix)[0]

        def _match_when_cols_are_empty():
            """Performs matching when the similarity matrix has no columns.

            Every row gets matched column 0 and match indicator -1.
            NOTE(review): -1 is hard-coded rather than taken from
            `self.indicators`; presumably callers use -1 as their negative
            indicator -- confirm.

            Returns:
              matched_columns: An int32 tensor of shape [batch_size, num_rows]
                filled with 0.
              match_indicators: An int32 tensor of shape [batch_size, num_rows]
                filled with -1.
            """
            with tf.name_scope('empty_gt_boxes'):
                matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
                match_indicators = -tf.ones([batch_size, num_rows], dtype=tf.int32)
                return matched_columns, match_indicators

        def _match_when_cols_are_non_empty():
            """Performs matching when the similarity matrix has columns.

            Returns:
              matched_columns: An int32 tensor of shape [batch_size, num_rows]
                storing the index of the matched column for each row.
              match_indicators: An int32 tensor of shape [batch_size, num_rows]
                storing the match type indicator of each row.
            """
            with tf.name_scope('non_empty_gt_boxes'):
                # [batch_size, num_rows]: best column per row and its score.
                matched_columns = tf.argmax(
                    similarity_matrix, axis=-1, output_type=tf.int32)
                matched_vals = tf.reduce_max(similarity_matrix, axis=-1)

                # Bucket each best score into the half-open interval
                # [low, high) it falls in. A row whose score lands in no
                # interval (e.g. NaN or +Inf) keeps the initial indicator 0.
                match_indicators = tf.zeros([batch_size, num_rows], tf.int32)
                match_dtype = matched_vals.dtype
                for ind, low, high in zip(self.indicators, self.thresholds[:-1],
                                          self.thresholds[1:]):
                    low_threshold = tf.cast(low, match_dtype)
                    high_threshold = tf.cast(high, match_dtype)
                    mask = tf.logical_and(
                        tf.greater_equal(matched_vals, low_threshold),
                        tf.less(matched_vals, high_threshold))
                    match_indicators = self._set_values_using_indicator(
                        match_indicators, mask, ind)

                if self._force_match_for_each_col:
                    # [batch_size, num_cols]: best row for every column.
                    matching_rows = tf.argmax(
                        input=similarity_matrix, axis=1, output_type=tf.int32)
                    # [batch_size, num_cols, num_rows]: one-hot of each
                    # column's best row.
                    column_to_row_match_mapping = tf.one_hot(
                        matching_rows, depth=num_rows)
                    # [batch_size, num_rows]: for each row, a column that
                    # picked it as best (ties are resolved by tf.argmax).
                    force_matched_columns = tf.argmax(
                        input=column_to_row_match_mapping, axis=1,
                        output_type=tf.int32)
                    # [batch_size, num_rows]: True for rows picked by at least
                    # one column.
                    force_matched_column_mask = tf.cast(
                        tf.reduce_max(column_to_row_match_mapping, axis=1),
                        tf.bool)

                    # Force-matched rows override the threshold-based result
                    # and are labelled with the last indicator.
                    matched_columns = tf.where(force_matched_column_mask,
                                               force_matched_columns,
                                               matched_columns)
                    match_indicators = tf.where(
                        force_matched_column_mask,
                        self.indicators[-1] *
                        tf.ones([batch_size, num_rows], dtype=tf.int32),
                        match_indicators)

                return matched_columns, match_indicators

        # Columns are the ground-truth boxes; with none of them there is
        # nothing to match against.
        num_gt_boxes = static_shape[-1] or tf.shape(similarity_matrix)[-1]
        matched_columns, match_indicators = tf.cond(
            pred=tf.greater(num_gt_boxes, 0),
            true_fn=_match_when_cols_are_non_empty,
            false_fn=_match_when_cols_are_empty)

        if squeeze_result:
            matched_columns = tf.squeeze(matched_columns, axis=0)
            match_indicators = tf.squeeze(match_indicators, axis=0)

        return matched_columns, match_indicators

    def _set_values_using_indicator(self, x, indicator, val):
        """Sets the indicated fields of x to val.

        Computed arithmetically as x * (1 - indicator) + val * indicator after
        casting `indicator` to x's dtype.

        Args:
          x: A numeric tensor.
          indicator: A boolean tensor with the same shape as x.
          val: A scalar value to write where `indicator` is True.

        Returns:
          A tensor equal to `val` where `indicator` is True and to `x`
          elsewhere.
        """
        indicator = tf.cast(indicator, x.dtype)
        return tf.add(tf.multiply(x, 1 - indicator), val * indicator)
|
|
|