R3PM-Net / thirdparty /learning3d /losses /rmse_features.py
YasiiKB's picture
initial commit
97aa5af verified
import torch
import torch.nn as nn
import torch.nn.functional as F
def rmseOnFeatures(feature_difference):
    """Return the summed squared error of ``feature_difference`` against zero.

    The target is an all-zero tensor shaped like the input, so the result
    is the sum of the squared entries of ``feature_difference``.

    NOTE: despite the "rmse" in the name, neither a mean nor a square root
    is applied; the historical sum-of-squares value is kept so existing
    callers see the same loss magnitude.

    Args:
        feature_difference: Tensor of residuals that should be driven to 0.

    Returns:
        A scalar (0-dim) tensor holding the sum of squared entries.
    """
    # |feature_difference| should be 0
    gt = torch.zeros_like(feature_difference)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; it yields the same value without the warning.
    return torch.nn.functional.mse_loss(feature_difference, gt, reduction='sum')
class RMSEFeaturesLoss(nn.Module):
    """``nn.Module`` wrapper that evaluates ``rmseOnFeatures`` on its input."""

    def __init__(self):
        super().__init__()

    def forward(self, feature_difference):
        """Return ``rmseOnFeatures(feature_difference)`` unchanged."""
        loss = rmseOnFeatures(feature_difference)
        return loss