| """ |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. |
| """ |
|
|
| import torch |
| from .box_ops import box_xyxy_to_cxcywh |
|
|
|
|
def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    The sequence is symmetric around W(reg_max / 2) = 0. Its outermost entries are
    ``±2 * |up[0]| * |reg_scale|``. The inner entries grow geometrically as
    ``±(step ** i - 1)`` with ``step = (|up[0]| * |reg_scale| + 1) ** (2 / (reg_max - 2))``.
    For even ``reg_max``, the second-outermost entries therefore equal
    ``±|up[0]| * |reg_scale|``.

    Args:
        reg_max (int): Max number of the discrete bins. Must be greater than 2 because
            the exponent divides by ``reg_max - 2``. An even value yields
            ``reg_max + 1`` entries.
        up (Tensor): Controls upper bounds of the sequence, where maximum offset is
            ±up * H / W. Only ``up[0]`` is used.
        reg_scale (float | Tensor): Controls the curvature of the Weighting Function.
            Larger values result in flatter weights near the central axis
            W(reg_max/2)=0 and steeper weights at both ends.
        deploy (bool): If True, the bounds are converted to Python scalars via
            ``.item()`` and the sequence is built with ``torch.tensor``. This does
            not propagate gradients to ``up`` / ``reg_scale``. Otherwise the
            sequence is concatenated from tensors with ``torch.cat``.

    Returns:
        Tensor: 1-D sequence of Weighting Function values. It is strictly
        increasing whenever ``up[0] * reg_scale`` is non-zero.
    """
    # Shared by both modes: |up| * |reg_scale| and twice that (outermost bound).
    upper_bound1 = abs(up[0]) * abs(reg_scale)
    upper_bound2 = upper_bound1 * 2
    if deploy:
        # Switch to Python floats so the rest of the sequence is built from scalars.
        upper_bound1 = upper_bound1.item()
        upper_bound2 = upper_bound2.item()

    step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
    left_values = [-(step**i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
    right_values = [step**i - 1 for i in range(1, reg_max // 2)]

    if deploy:
        values = [-upper_bound2, *left_values, 0.0, *right_values, upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)

    zero = torch.zeros_like(up[0][None])
    values = [-upper_bound2, *left_values, zero, *right_values, upper_bound2]
    # reshape(-1) lifts 0-d entries (e.g. when reg_scale is a Python float) to 1-D,
    # because torch.cat rejects zero-dimensional tensors. For 1-element 1-D
    # entries it is a no-op.
    return torch.cat([v.reshape(-1) for v in values], 0)
|
|
|
|
def translate_gt(gt, reg_max, reg_scale, up):
    """
    Encodes continuous regression targets as a two-bin soft label over W(n).

    Each target value is placed between two neighbouring entries of the Weighting
    Function sequence. Linear-interpolation weights are then assigned to that
    left/right pair according to how close the target lies to each entry.

    Args:
        gt (Tensor): Target values of any shape; flattened to (N, ) internally.
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor], each a float tensor of shape (N, ):
            - indices: (float) index of the left bin. Targets below the first bin
              get index 0 with all weight on the left bin. Targets whose left
              neighbour index is >= reg_max get ``reg_max - 0.1`` with all weight
              on the right bin.
            - weight_right: interpolation weight assigned to the right bin.
            - weight_left: interpolation weight assigned to the left bin.
    """
    targets = gt.reshape(-1)
    bins = weighting_function(reg_max, up, reg_scale)

    # Counting the bin values at or below each target, minus one, gives the left
    # neighbour: -1 below every bin, len(bins) - 1 at/above the last one.
    # NOTE(review): relies on `bins` being sorted ascending, as weighting_function
    # produces for non-zero up * reg_scale.
    at_or_below = (bins.unsqueeze(0) - targets.unsqueeze(1)) <= 0
    indices = (at_or_below.sum(dim=1) - 1).float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # Targets with a bin on both sides: interpolate linearly between them.
    inside = (indices >= 0) & (indices < reg_max)
    lo = indices[inside].long()
    picked = targets[inside]
    dist_to_left = torch.abs(picked - bins[lo])
    dist_to_right = torch.abs(bins[lo + 1] - picked)
    weight_right[inside] = dist_to_left / (dist_to_left + dist_to_right)
    weight_left[inside] = 1.0 - weight_right[inside]

    # Underflow: snap to the first bin, all weight on the left.
    underflow = indices < 0
    weight_right[underflow] = 0.0
    weight_left[underflow] = 1.0
    indices[underflow] = 0.0

    # Overflow: pin just below reg_max, all weight on the right.
    overflow = indices >= reg_max
    weight_right[overflow] = 1.0
    weight_left[overflow] = 0.0
    indices[overflow] = reg_max - 0.1

    return indices, weight_right, weight_left
|
|
|
|
def distance2bbox(points, distance, reg_scale):
    """
    Decodes edge-distances into bounding box coordinates.

    Each edge offset first has the half-extent ``0.5 * |reg_scale|`` added to it.
    It is then scaled by the reference width (for left/right) or height (for
    top/bottom) divided by ``|reg_scale|``. The result is measured outward from
    the reference center.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) reference boxes [x, y, w, h], where
            (x, y) is the center and (w, h) are width and height.
        distance (Tensor): (B, N, 4) or (N, 4) distances from the point to the
            left, top, right, and bottom boundaries.
        reg_scale (float): Controls the curvature of the Weighting Function.

    Returns:
        Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
    """
    scale = abs(reg_scale)
    half = 0.5 * scale

    cx = points[..., 0]
    cy = points[..., 1]
    step_x = points[..., 2] / scale
    step_y = points[..., 3] / scale

    x_min = cx - (half + distance[..., 0]) * step_x
    y_min = cy - (half + distance[..., 1]) * step_y
    x_max = cx + (half + distance[..., 2]) * step_x
    y_max = cy + (half + distance[..., 3]) * step_y

    corners = torch.stack((x_min, y_min, x_max, y_max), dim=-1)
    return box_xyxy_to_cxcywh(corners)
|
|
|
|
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Converts bounding box coordinates to distribution targets relative to a reference point.

    The distance from the reference center to each edge is expressed in units of
    the reference width/height divided by ``|reg_scale|``, minus the
    ``0.5 * |reg_scale|`` half-extent. The four offsets are then mapped onto the
    Weighting Function bins with ``translate_gt``.

    Args:
        points (Tensor): (..., 4) reference boxes [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (..., 4) bounding boxes in "xyxy" format.
        reg_max (int): Max number of the discrete bins.
        reg_scale (float | Tensor): Controls the curvature of W(n).
        up (Tensor): Controls the upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: detached, flattened tensors with one entry
        per (box, edge) pair, in left/top/right/bottom order per box:
            - fractional left-bin index, clamped to [0, reg_max - eps]
            - weight assigned to the right bin
            - weight assigned to the left bin
    """
    reg_scale = abs(reg_scale)
    half = 0.5 * reg_scale
    # 1e-16 keeps the division finite for zero-width / zero-height reference boxes.
    unit_w = points[..., 2] / reg_scale + 1e-16
    unit_h = points[..., 3] / reg_scale + 1e-16

    left = (points[..., 0] - bbox[..., 0]) / unit_w - half
    top = (points[..., 1] - bbox[..., 1]) / unit_h - half
    right = (bbox[..., 2] - points[..., 0]) / unit_w - half
    bottom = (bbox[..., 3] - points[..., 1]) / unit_h - half

    four_lens = torch.stack([left, top, right, bottom], -1)
    four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)
    # translate_gt already uses reg_max arithmetically, so it cannot be None here.
    four_lens = four_lens.clamp(min=0, max=reg_max - eps)
    return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
|
|