| import torch | |
| import torch.nn as nn | |
| import random | |
| import os | |
| import numpy as np | |
| import logging | |
def TET_loss(outputs, labels, criterion, means, lamb):
    """Combine a per-step criterion loss with an MSE regulariser on the outputs.

    The criterion is applied to every slice ``outputs[t]`` taken along the
    leading dimension, and the results are averaged (L_TET).  When ``lamb``
    is non-zero, an MSE term between ``outputs`` and a tensor of the same
    shape filled with ``means`` is added (L_mse).  The result is
    ``(1 - lamb) * L_TET + lamb * L_mse``.

    Args:
        outputs: Tensor sliced along dim 0, one slice per step.
            NOTE(review): presumably laid out as (T, batch, ...) — confirm
            against the caller.  If the model emits batch-first outputs,
            the slicing below must become ``outputs[:, t, ...]`` instead.
        labels: Targets passed unchanged to ``criterion`` for every step.
        criterion: Callable ``criterion(step_output, labels)`` returning the
            per-step loss.
        means: Scalar fill value for the MSE target.
        lamb: Weight of the MSE term; ``0`` skips its computation entirely.

    Returns:
        The weighted total loss.
    """
    # NOTE: this message is emitted on every call, i.e. once per loss
    # evaluation.
    print('using TET')

    # The number of steps must come from the same axis that is indexed in the
    # loop.  Reading it from dim 1 while slicing dim 0 either runs past the
    # end of the tensor or silently skips steps.
    T = outputs.size(0)

    Loss_es = 0
    for t in range(T):
        Loss_es += criterion(outputs[t, ...], labels)
    Loss_es = Loss_es / T  # L_TET

    if lamb != 0:
        MMDLoss = torch.nn.MSELoss()
        # Constant target with the same shape/dtype/device as the outputs.
        y = torch.zeros_like(outputs).fill_(means)
        Loss_mmd = MMDLoss(outputs, y)  # L_mse
    else:
        Loss_mmd = 0

    return (1 - lamb) * Loss_es + lamb * Loss_mmd  # L_Total