# File size: 453 Bytes (revision 97aa5af)
import torch
import torch.nn as nn
import torch.nn.functional as F
def rmseOnFeatures(feature_difference):
    """Return the summed squared error of ``feature_difference`` against zero.

    ``feature_difference`` is treated as a residual whose ideal value is 0
    everywhere, so the loss is the sum over all elements of
    ``feature_difference ** 2`` (the squared L2 norm of the tensor).

    Note: despite the "RMSE" name, no mean and no square root are taken;
    this preserves the original ``size_average=False`` (sum) behaviour.

    Args:
        feature_difference: tensor of any shape.

    Returns:
        A 0-dim tensor holding the sum of squared elements.
    """
    # |feature_difference| should be 0, so compare against an all-zero target.
    gt = torch.zeros_like(feature_difference)
    # ``size_average=False`` is deprecated (and warns on every call);
    # ``reduction="sum"`` is the documented equivalent: sum, no averaging.
    return torch.nn.functional.mse_loss(feature_difference, gt, reduction="sum")
class RMSEFeaturesLoss(nn.Module):
    """``nn.Module`` wrapper around ``rmseOnFeatures``.

    Holds no parameters or buffers; ``forward`` simply delegates to
    ``rmseOnFeatures`` so the loss can be used like any other module.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feature_difference):
        """Return ``rmseOnFeatures(feature_difference)``.

        Args:
            feature_difference: tensor passed straight through to
                ``rmseOnFeatures``.
        """
        return rmseOnFeatures(feature_difference)