| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class SGVLB(nn.Module): | |
| def __init__(self, net, train_size, loss_type='cross_entropy', device='cuda'): | |
| super(SGVLB, self).__init__() | |
| self.train_size = train_size | |
| self.net = net | |
| self.loss_type = loss_type | |
| self.device = device | |
| def forward(self, input, target, kl_weight=1.0): | |
| assert not target.requires_grad | |
| kl = torch.FloatTensor([0.0]).to(self.device) | |
| for module in self.net.children(): | |
| if hasattr(module, 'kl_reg'): | |
| kl = kl + module.kl_reg() | |
| if self.loss_type == 'cross_entropy': | |
| SGVLB = F.cross_entropy(input, target) * self.train_size + kl_weight * kl | |
| elif self.loss_type in ['l2', 'L2']: | |
| SGVLB = ((input - target) ** 2).mean() * self.train_size + kl_weight * kl | |
| else: | |
| raise NotImplementedError | |
| return SGVLB | |