Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
def select_candidates_in_gts(priors_points: Tensor,
                             gt_bboxes: Tensor,
                             eps: float = 1e-9) -> Tensor:
    """Select the priors whose center lies strictly inside each gt box.

    A prior counts as inside a box when its distances to all four box sides
    (left, top, right, bottom) are greater than ``eps``.

    Args:
        priors_points (Tensor): Model priors points,
            shape(num_priors, 2)
        gt_bboxes (Tensor): Ground truth bboxes in (x1, y1, x2, y2) format,
            shape(batch_size, num_gt, 4)
        eps (float): Minimum side distance for a prior to count as inside.
            Default to 1e-9.

    Return:
        (Tensor): 0/1 mask with the dtype of ``gt_bboxes``,
            shape(batch_size, num_gt, num_priors). Works for ``num_gt == 0``
            or ``num_priors == 0`` and then returns an empty mask.
    """
    batch_size, num_gt, _ = gt_bboxes.size()
    priors_number = priors_points.size(0)

    # (batch_size * num_gt, 1, 4) against (1, num_priors, 2): broadcasting
    # pairs every gt with every prior without materialising repeated copies.
    gt_flat = gt_bboxes.reshape(-1, 1, 4)
    points = priors_points.unsqueeze(0)

    # Left/top and right/bottom distances between prior center and gt sides.
    bbox_deltas = torch.cat(
        [points - gt_flat[..., 0:2], gt_flat[..., 2:4] - points], dim=-1)

    # Explicit last dim: a -1 cannot be inferred (and torch raises) when the
    # tensor has zero elements, i.e. when num_gt or num_priors is 0.
    bbox_deltas = bbox_deltas.reshape(batch_size, num_gt, priors_number, 4)
    return (bbox_deltas.min(dim=-1)[0] > eps).to(gt_bboxes.dtype)
def select_highest_overlaps(pos_mask: Tensor, overlaps: Tensor,
                            num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
    """If an anchor box is assigned to multiple gts, the one with the highest
    iou will be selected.

    For every prior that ``pos_mask`` marks as positive for more than one gt,
    its column in ``pos_mask`` is replaced by a one-hot vector selecting the
    gt with the largest value in ``overlaps``.

    Args:
        pos_mask (Tensor): The assigned positive sample mask,
            shape(batch_size, num_gt, num_priors)
        overlaps (Tensor): IoU between all bbox and ground truth,
            shape(batch_size, num_gt, num_priors)
        num_gt (int): Number of ground truth.

    Return:
        gt_idx_pre_prior (Tensor): Index of the gt selected for each prior
            (argmax of the de-duplicated mask; 0 for priors with no gt),
            shape(batch_size, num_priors)
        fg_mask_pre_prior (Tensor): Number of gts each prior is assigned to
            after de-duplication, shape(batch_size, num_priors)
        pos_mask (Tensor): The de-duplicated positive sample mask,
            shape(batch_size, num_gt, num_priors)
    """
    fg_mask_pre_prior = pos_mask.sum(axis=-2)

    # With no gts, argmax over the (empty) gt dimension would raise; no prior
    # can be matched, so every index is 0 and the masks stay empty/zero.
    if pos_mask.size(-2) == 0:
        gt_idx_pre_prior = torch.zeros_like(
            fg_mask_pre_prior, dtype=torch.long)
        return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask

    # Make sure the positive sample matches the only one and is the largest
    # IoU. The numel() check avoids calling max() on an empty tensor.
    if fg_mask_pre_prior.numel() > 0 and fg_mask_pre_prior.max() > 1:
        # (batch_size, 1, num_priors); torch.where broadcasts it over gts.
        mask_multi_gts = fg_mask_pre_prior.unsqueeze(1) > 1
        index = overlaps.argmax(axis=1)
        is_max_overlaps = F.one_hot(index, num_gt).permute(0, 2, 1).to(
            overlaps.dtype)
        pos_mask = torch.where(mask_multi_gts, is_max_overlaps, pos_mask)
        fg_mask_pre_prior = pos_mask.sum(axis=-2)

    gt_idx_pre_prior = pos_mask.argmax(axis=-2)
    return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask
# TODO: 'mmdet.BboxOverlaps2D' causes gradient inconsistency, so the IoU
# calculator below is used instead until the cause is found and fixed.
def yolov6_iou_calculator(bbox1: Tensor,
                          bbox2: Tensor,
                          eps: float = 1e-9) -> Tensor:
    """Compute the pairwise IoU between two batched sets of boxes.

    Boxes are read as ``(x1, y1, x2, y2)`` from the first four entries of
    the last dimension. Negative widths/heights are clipped to zero.

    Args:
        bbox1 (Tensor): shape(batch_size, num_gt, 4)
        bbox2 (Tensor): shape(batch_size, num_priors, 4)
        eps (float): Added to the union to avoid division by zero.
            Default to 1e-9.

    Return:
        (Tensor): IoU, shape(batch_size, num_gt, num_priors)
    """
    # Insert singleton axes so each box of bbox1 meets each box of bbox2:
    # [N, M1, 1, 4] against [N, 1, M2, 4].
    boxes_a = bbox1[:, :, None, :]
    boxes_b = bbox2[:, None, :, :]

    a_min, a_max = boxes_a[:, :, :, 0:2], boxes_a[:, :, :, 2:4]
    b_min, b_max = boxes_b[:, :, :, 0:2], boxes_b[:, :, :, 2:4]

    # Intersection: clipped width/height of the overlapping rectangle.
    inter_lt = torch.maximum(a_min, b_min)
    inter_rb = torch.minimum(a_max, b_max)
    intersection = (inter_rb - inter_lt).clip(0).prod(-1)

    area_a = (a_max - a_min).clip(0).prod(-1)
    area_b = (b_max - b_min).clip(0).prod(-1)

    return intersection / (area_a + area_b - intersection + eps)