# Listing header from the original source page: "File size: 453 Bytes", id 97aa5af.
import torch
import torch.nn as nn
import torch.nn.functional as F

def rmseOnFeatures(feature_difference):
	"""Return the sum of squared elements of ``feature_difference``.

	The loss compares ``feature_difference`` against an all-zeros target,
	i.e. it drives the feature difference towards 0. Despite the name, no
	mean and no square root are taken: the result equals
	``(feature_difference ** 2).sum()``.

	Args:
		feature_difference: tensor of any shape.

	Returns:
		0-dim tensor holding the summed squared error.
	"""
	# |feature_difference| should be 0, so the target is all zeros.
	gt = torch.zeros_like(feature_difference)
	# ``size_average=False`` is deprecated (warns on every call);
	# ``reduction='sum'`` is its exact replacement: sum, no averaging.
	return F.mse_loss(feature_difference, gt, reduction='sum')


class RMSEFeaturesLoss(nn.Module):
	"""``nn.Module`` wrapper around ``rmseOnFeatures``.

	The module registers no parameters or buffers of its own. Calling it
	returns the summed squared elements of the given feature difference
	(the target is all zeros).
	"""

	def __init__(self):
		super().__init__()

	def forward(self, feature_difference):
		"""Compute the loss for ``feature_difference`` via ``rmseOnFeatures``."""
		loss = rmseOnFeatures(feature_difference)
		return loss