# Source: R3PM-Net / thirdparty / learning3d / losses / chamfer_distance.py
# Commit 97aa5af ("initial commit") by YasiiKB, verified.
import torch
import torch.nn as nn
import torch.nn.functional as F
def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p: float = 2) -> torch.Tensor:
    """
    Compute the batched matrix of pairwise p-th-power distances between points of a and b.

    For each batch entry m, entry [m, i, j] of the result is
    ``sum_k |a[m, i, k] - b[m, j, k]| ** p``. This is the p-norm of the difference
    raised to the p-th power; no root is taken. With the default p=2 the result
    therefore holds *squared* Euclidean distances.

    Memory use is proportional to m * n_a * n_b * d, because the full
    difference tensor is materialised before the reduction.

    :param a: A tensor containing m batches of n_a points of dimension d, i.e. of size [m, n_a, d].
    :param b: A tensor containing m batches of n_b points of dimension d, i.e. of size [m, n_b, d].
        n_a and n_b may differ; the batch and point dimensions must broadcast against a's.
    :param p: Exponent of the norm used for the distances.
    :return: A tensor of size [m, n_a, n_b] with the pairwise p-th-power distances.
    :raises ValueError: If a or b is not a 3-dimensional tensor.
    """
    if len(a.shape) != 3:
        raise ValueError(f"Invalid shape for a. Must be [m, n, d] but got {tuple(a.shape)}")
    if len(b.shape) != 3:
        raise ValueError(f"Invalid shape for b. Must be [m, n, d] but got {tuple(b.shape)}")
    # [m, n_a, 1, d] - [m, 1, n_b, d] broadcasts to [m, n_a, n_b, d]; then reduce over d.
    diff = a.unsqueeze(2) - b.unsqueeze(1)
    return diff.abs().pow(p).sum(3)
def chamfer(a, b):
    """
    Compute the symmetric Chamfer distance between two batched point sets a and b.

    For every point of b, the Euclidean distance to its nearest neighbour in a
    is taken; the same is done for every point of a with respect to b. Each
    direction is averaged over all points *and* all batch entries, and the two
    averages are then averaged together.

    :param a: A m-sized minibatch of point sets in R^d, i.e. shape [m, n_a, d].
    :param b: A m-sized minibatch of point sets in R^d, i.e. shape [m, n_b, d].
    :return: A 0-dimensional (scalar) tensor holding the Chamfer distance averaged
        over the whole minibatch. It is not a per-entry [m] tensor.
    """
    # Squared Euclidean distances (p=2, no root), shape [m, n_a, n_b].
    sq_dists = pairwise_distances(a, b)
    # Min over dim 1: for each point of b, its nearest point of a -> [m, n_b].
    b_to_a = torch.mean(torch.sqrt(sq_dists.min(1)[0]))
    # Min over dim 2: for each point of a, its nearest point of b -> [m, n_a].
    a_to_b = torch.mean(torch.sqrt(sq_dists.min(2)[0]))
    # NOTE(review): torch.sqrt has an unbounded gradient at 0, so exactly
    # coincident points may produce NaN/inf gradients during backprop.
    return (b_to_a + a_to_b) / 2.0
def chamfer_distance(template: torch.Tensor, source: torch.Tensor):
    """
    Return the symmetric Chamfer distance between ``template`` and ``source``.

    The optional CUDA extension ``.cuda.chamfer_distance.ChamferDistance`` is tried
    first. If it cannot be imported or fails on the given inputs, the pure-PyTorch
    ``chamfer`` implementation is used instead. Both paths square-root the
    per-point costs, average each direction, and average the two directions.

    :param template: Batched point set; presumably [m, n_a, d], as ``chamfer`` expects. TODO confirm.
    :param source: Batched point set; presumably [m, n_b, d]. TODO confirm.
    :return: A scalar tensor with the Chamfer distance.
    """
    try:
        from .cuda.chamfer_distance import ChamferDistance
        # Assumed to return per-point squared nearest-neighbour distances in each
        # direction, since they are square-rooted below like the CPU path. TODO confirm.
        cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source)
    except Exception:
        # Deliberate best-effort: any failure of the optional CUDA path (missing or
        # unbuilt extension, unsupported device/inputs) falls back to the CPU version.
        # Unlike a bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        return chamfer(template, source)
    cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1))
    cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0))
    return (cost_p0_p1 + cost_p1_p0) / 2.0
class ChamferDistanceLoss(nn.Module):
    """Loss module that returns the Chamfer distance between its two inputs.

    It registers no parameters or buffers. ``forward(template, source)`` simply
    delegates to the module-level ``chamfer_distance`` function.
    """

    def __init__(self):
        super().__init__()

    def forward(self, template, source):
        """Return ``chamfer_distance(template, source)``."""
        loss = chamfer_distance(template, source)
        return loss