Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Sequence, Tuple | |
| import torch | |
| from torch import Tensor | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from ..utils import ext_loader | |
# Load the ``_ext`` ops module, exposing the two chamfer-distance entry
# points that ChamferDistanceFunction calls in its forward/backward passes.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'chamfer_distance_forward',
        'chamfer_distance_backward',
    ],
)
class ChamferDistanceFunction(Function):
    """This is an implementation of the 2D Chamfer Distance.

    It has been used in the paper `Oriented RepPoints for Aerial Object
    Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>`_.
    """

    @staticmethod
    def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
        """Compute nearest-neighbour distances between two point sets.

        Args:
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

            - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
              shape (B, N).
            - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
              shape (B, M).
            - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
              with shape (B, N), which is used to compute the gradient.
            - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
              with shape (B, M), which is used to compute the gradient.
        """
        batch_size, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        device = xyz1.device
        # The extension kernel reads raw buffers, so inputs must be
        # contiguous.
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Allocate outputs directly on the target device rather than building
        # them on the CPU and copying. The distance buffers follow the input
        # dtype. NOTE(review): this assumes the extension writes distances in
        # the input's scalar type; confirm against the C++/CUDA source.
        dist1 = torch.zeros(batch_size, n, dtype=xyz1.dtype, device=device)
        dist2 = torch.zeros(batch_size, m, dtype=xyz2.dtype, device=device)
        # int32 is the dtype of the original `torch.IntTensor` buffers.
        idx1 = torch.zeros(batch_size, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batch_size, m, dtype=torch.int32, device=device)

        ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
                                            idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        # Integer index outputs carry no gradient.
        ctx.mark_non_differentiable(idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @once_differentiable
    def backward(ctx,
                 grad_dist1: Tensor,
                 grad_dist2: Tensor,
                 grad_idx1=None,
                 grad_idx2=None) -> Tuple[Tensor, Tensor]:
        """Back-propagate distance gradients to both point sets.

        Args:
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1: Unused; the index outputs are non-differentiable.
            grad_idx2: Unused; the index outputs are non-differentiable.

        Returns:
            Tuple[Tensor, Tensor]:

            - grad_xyz1 (Tensor): Gradient of the point set with shape
              (B, N, 2).
            - grad_xyz2 (Tensor): Gradient of the point set with shape
              (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        device = grad_dist1.device
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()
        # The kernel accumulates into these buffers, so they start at zero.
        grad_xyz1 = torch.zeros(xyz1.size(), dtype=xyz1.dtype, device=device)
        grad_xyz2 = torch.zeros(xyz2.size(), dtype=xyz2.dtype, device=device)

        ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                             grad_dist1, grad_dist2, grad_xyz1,
                                             grad_xyz2)
        return grad_xyz1, grad_xyz2
# Functional entry point: ``chamfer_distance(xyz1, xyz2)`` runs the autograd
# function above and returns ``(dist1, dist2, idx1, idx2)``.
chamfer_distance = ChamferDistanceFunction.apply