import torch
import torch.nn as nn
import torch.nn.functional as F


def emd(template: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """Return the normalized Earth Mover's Distance loss between two inputs.

    The distance itself is computed by the ``EMDLoss`` class of the external
    ``emd`` package. That package is imported lazily, inside this function, so
    this module can still be imported when the extension is not installed. The
    mean of the extension's output is divided by ``template.size(1)``.

    NOTE(review): presumably ``template`` and ``source`` are point clouds laid
    out as (batch, num_points, dims), which would make ``size(1)`` the number
    of points. Confirm against callers and the extension's expected layout.

    Args:
        template: First input tensor, passed to the extension as-is.
        source: Second input tensor, passed to the extension as-is.

    Returns:
        A tensor equal to ``mean(extension_output) / template.size(1)``.

    Raises:
        ImportError: If the external ``emd`` package cannot be imported.
    """
    # Alias the extension class so it is not confused with this module's own
    # ``EMDLoss`` wrapper defined below. NOTE(review): if this file is itself
    # importable as top-level ``emd``, this import would resolve to the wrapper
    # below and recurse. Confirm the package/module layout.
    from emd import EMDLoss as _ExtEMDLoss

    # This is a free function, so there is no ``self`` to dispatch through.
    # Instantiate the extension module and call it directly.
    distances = _ExtEMDLoss()(template, source)
    return torch.mean(distances) / template.size(1)


class EMDLoss(nn.Module):
    """``nn.Module`` wrapper that exposes :func:`emd` as a loss layer.

    The module holds no parameters or buffers. ``forward`` simply delegates to
    the module-level :func:`emd` function.
    """

    def __init__(self):
        super().__init__()

    def forward(self, template: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
        """Return ``emd(template, source)``."""
        return emd(template, source)