| | |
| | import torch |
| |
|
| | from detectron2.layers import nonzero_tuple |
| |
|
# Public names exported by this module (what `from <module> import *` binds).
__all__ = ["subsample_labels"]
| |
|
| |
|
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of indices to sample (positives plus
            negatives). A negative value is treated as 0.
        positive_fraction (float): The number of sampled positives (labels that are
            neither -1 nor `bg_label`) is
            `min(num_positives, int(positive_fraction * num_samples))`, clamped to
            the range `[0, num_samples]`. The number of negatives sampled is
            `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled
            with negatives. If there are also not enough negatives, then as many
            elements are sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vectors of indices into `labels`. The total length of both is
            `num_samples` or fewer. `labels` itself is not modified.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    # A negative budget means "sample nothing". Without this clamp the counts
    # below could go negative and be used as negative slice bounds
    # (e.g. perm[:-2]), which would select almost everything instead.
    num_samples = max(num_samples, 0)

    num_pos = int(num_samples * positive_fraction)
    # Keep the positive count inside [0, num_samples] even if positive_fraction
    # lies outside [0, 1]; this also guarantees num_neg below is never negative.
    num_pos = min(max(num_pos, 0), num_samples)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
| |
|