import torch
import torch.nn as nn
import torch.nn.functional as F


def frobeniusNormLoss(predicted: torch.Tensor, igt: torch.Tensor) -> torch.Tensor:
    """Return the batch-mean squared Frobenius norm of ``predicted @ igt - I``.

    The loss is zero when ``predicted @ igt`` equals the 4x4 identity for
    every item in the batch, i.e. when ``predicted`` is the inverse of
    ``igt``.

    Args:
        predicted: Tensor whose first three dimensions are ``(B, 4, 4)``.
        igt: Tensor with the same first three dimensions as ``predicted``.

    Returns:
        A scalar tensor: the sum of squared entries of
        ``predicted @ igt - I``, averaged over the batch.

    Raises:
        AssertionError: If the batch sizes differ or the matrices are not 4x4.
    """
    assert predicted.size(0) == igt.size(0), "batch sizes of predicted and igt differ"
    assert predicted.size(1) == igt.size(1) == 4, "dimension 1 must have size 4"
    assert predicted.size(2) == igt.size(2) == 4, "dimension 2 must have size 4"

    error = predicted.matmul(igt)
    batch_size = error.size(0)

    # Build the identity directly with the dtype/device of `error` and expand
    # it (no copy) so both mse_loss operands have identical shapes.
    identity = torch.eye(4, dtype=error.dtype, device=error.device)
    identity = identity.view(1, 4, 4).expand(batch_size, 4, 4)

    # mse_loss with reduction="mean" averages over all B * 16 entries;
    # multiplying by 16 turns that into a per-matrix sum of squares averaged
    # over the batch. `reduction="mean"` replaces the deprecated
    # `size_average=True`, which computed the same value but warned per call.
    return F.mse_loss(error, identity, reduction="mean") * 16


class FrobeniusNormLoss(nn.Module):
    """Module wrapper around :func:`frobeniusNormLoss`; holds no parameters."""

    def forward(self, predicted: torch.Tensor, igt: torch.Tensor) -> torch.Tensor:
        """Return ``frobeniusNormLoss(predicted, igt)``."""
        return frobeniusNormLoss(predicted, igt)