R3PM-Net/thirdparty/learning3d/losses/frobenius_norm.py
YasiiKB's picture
initial commit
97aa5af verified
import torch
import torch.nn as nn
import torch.nn.functional as F
def frobeniusNormLoss(predicted, igt):
    """Return the batch-averaged squared Frobenius norm of ``predicted @ igt - I``.

    The product ``predicted @ igt`` is compared against the 4x4 identity, so
    the loss is zero exactly when every product in the batch is the identity.

    Args:
        predicted: Tensor of shape (B, 4, 4).
        igt: Tensor of shape (B, 4, 4). Presumably the inverse of the
            ground-truth transform -- confirm against the caller.

    Returns:
        Scalar tensor: the element-wise mean squared error between
        ``predicted @ igt`` and the identity, multiplied by 16. Because each
        matrix has 16 entries, this equals the mean over the batch of the
        squared Frobenius norm of ``predicted @ igt - I``.

    Raises:
        AssertionError: If the batch sizes differ or either input is not
            4x4 in its last two dimensions.
    """
    assert predicted.size(0) == igt.size(0)
    assert predicted.size(1) == igt.size(1) and predicted.size(1) == 4
    assert predicted.size(2) == igt.size(2) and predicted.size(2) == 4

    error = predicted.matmul(igt)
    # Build the identity directly with the dtype/device of the product; expand()
    # returns a broadcast view, so no per-batch copies are materialised.
    identity = torch.eye(4, dtype=error.dtype, device=error.device)
    identity = identity.view(1, 4, 4).expand(error.size(0), 4, 4)
    # reduction="mean" is the supported spelling of the deprecated
    # size_average=True; the factor 16 turns the per-element mean into a
    # per-matrix sum.
    return F.mse_loss(error, identity, reduction="mean") * 16
class FrobeniusNormLoss(nn.Module):
    """``nn.Module`` wrapper that evaluates ``frobeniusNormLoss`` in ``forward``.

    The module holds no parameters or buffers of its own.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicted, igt):
        """Return ``frobeniusNormLoss(predicted, igt)`` unchanged."""
        loss = frobeniusNormLoss(predicted, igt)
        return loss