| |
|
|
| import torch |
| from torch import nn |
|
|
| |
|
|
|
|
class AutomaticWeightedLoss(nn.Module):
    """Combine several losses with one weight per loss.

    The module holds a vector ``params`` of ``num`` weights, initialised to
    ones. For each loss ``L_i`` passed to :meth:`forward`, the term
    ``0.5 / p_i**2 * L_i + log(1 + p_i**2)`` is added to the total. The log
    term grows with ``|p_i|`` and therefore penalises large weights.

    Args:
        num: number of weights, i.e. the maximum number of losses that
            :meth:`forward` accepts.
        args: optional config object. When ``None``, the weights are
            trainable. Otherwise the truthiness of ``args.use_awl`` decides
            whether they are trainable; when it is falsy they stay fixed at 1.

    Example:
        awl = AutomaticWeightedLoss(2)
        total = awl(loss_a, loss_b)
    """

    def __init__(self, num=2, args=None):
        super().__init__()
        # Learnable by default; a config object may switch learning off.
        learnable = True if args is None else bool(args.use_awl)
        self.params = torch.nn.Parameter(torch.ones(num), requires_grad=learnable)

    def forward(self, *x):
        """Return the weighted sum of the losses given positionally.

        The i-th positional loss is paired with ``self.params[i]``. Passing
        more losses than ``num`` raises ``IndexError``. With no losses the
        plain integer ``0`` is returned. Losses are presumably scalar tensors
        or numbers (TODO confirm with callers).
        """
        total = 0
        for idx, task_loss in enumerate(x):
            squared = self.params[idx] ** 2
            # In-place accumulation (after the first int + tensor step)
            # matches the original dtype/broadcast behaviour.
            total += 0.5 / squared * task_loss + torch.log(1 + squared)
        return total
|
|