| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2):
    """
    Compute the matrix of p-th-power pairwise distances between two batched point sets.

    Entry [k, i, j] of the result is ``sum_c |a[k, i, c] - b[k, j, c]| ** p``. No p-th root
    is taken, so with the default ``p=2`` the entries are *squared* Euclidean distances;
    callers that need the actual norm must apply the root themselves.

    :param a: A tensor of shape [m, n_a, d], i.e. m batches of n_a points of dimension d.
    :param b: A tensor of shape [m, n_b, d], i.e. m batches of n_b points of dimension d.
    :param p: Exponent applied to the absolute coordinate differences before summing.
    :return: A tensor of shape [m, n_a, n_b] holding the p-th-power distance for every
             (point of a, point of b) pair within each batch entry.
    :raises ValueError: If a or b is not a 3-dimensional tensor.
    """
    if a.ndim != 3:
        raise ValueError(
            f"Invalid shape for a. Must be [m, n, d] but got {tuple(a.shape)}"
        )
    if b.ndim != 3:
        raise ValueError(
            f"Invalid shape for b. Must be [m, n, d] but got {tuple(b.shape)}"
        )

    # [m, n_a, 1, d] - [m, 1, n_b, d] broadcasts to [m, n_a, n_b, d]; the coordinate
    # axis is then collapsed by the sum.
    diff = a.unsqueeze(2) - b.unsqueeze(1)
    return diff.abs().pow(p).sum(3)
|
|
def chamfer(a, b):
    """
    Symmetric Chamfer distance between two batched point sets, reduced to one value.

    For each point in one set the cheapest match in the other set is located in the
    pairwise cost matrix; the square roots of those minima are averaged over every point
    and every batch entry, once per direction, and the two directional means are averaged.

    :param a: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_a, d]
    :param b: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d]
    :return: A scalar (0-dim) tensor: the mean is taken over the whole minibatch, so
             there is no per-entry [m] output.
    """
    cost = pairwise_distances(a, b)

    # Reducing over dim 1 (the points of a) leaves, for each point of b, its best match
    # in a; reducing over dim 2 does the same for each point of a against b.
    b_to_a = cost.min(dim=1).values.sqrt().mean()
    a_to_b = cost.min(dim=2).values.sqrt().mean()

    return (b_to_a + a_to_b) / 2.0
|
|
|
|
def chamfer_distance(template: torch.Tensor, source: torch.Tensor):
    """
    Compute the Chamfer distance between two point sets, preferring the CUDA extension.

    The compiled ``ChamferDistance`` op from the package's ``cuda`` subpackage is tried
    first. If it cannot be imported or fails while running, the pure-PyTorch ``chamfer``
    implementation is used instead.

    :param template: First batch of point sets (assumed [m, n_a, d] -- the fallback path
                     requires a 3-dimensional tensor).
    :param source: Second batch of point sets (assumed [m, n_b, d]).
    :return: The Chamfer loss as produced by whichever implementation ran.
    """
    try:
        from .cuda.chamfer_distance import ChamferDistance

        # NOTE(review): the extension presumably returns per-point squared
        # nearest-neighbour costs for each direction -- confirm against its source.
        cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source)
        cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1))
        cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0))
        chamfer_loss = (cost_p0_p1 + cost_p1_p0) / 2.0
    except Exception:
        # Deliberate best-effort fallback: the extension may be missing, fail to build,
        # or reject the inputs. `Exception` (rather than a bare `except`) keeps that
        # behaviour while letting KeyboardInterrupt / SystemExit propagate.
        chamfer_loss = chamfer(template, source)
    return chamfer_loss
|
|
|
|
class ChamferDistanceLoss(nn.Module):
    """Parameter-free ``nn.Module`` that evaluates ``chamfer_distance`` on its inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, template, source):
        """
        Evaluate the Chamfer loss for a pair of point-set batches.

        :param template: First batch of point sets, forwarded unchanged.
        :param source: Second batch of point sets, forwarded unchanged.
        :return: Whatever ``chamfer_distance(template, source)`` returns.
        """
        loss = chamfer_distance(template, source)
        return loss