| | |
| | |
| |
|
| | """ |
| | @Author : Peike Li |
| | @Contact : peike.li@yahoo.com |
| | @File : kl_loss.py |
| | @Time : 7/23/19 4:02 PM |
| | @Desc : |
| | @License : This source code is licensed under the license found in the |
| | LICENSE file in the root directory of this source tree. |
| | """ |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| |
|
def flatten_probas(input, target, labels, ignore=255):
    """
    Flatten per-pixel predictions and targets, dropping ignored pixels.

    Args:
        input: 4-D tensor of shape (B, C, H, W).
        target: tensor with the same shape as ``input``.
        labels: tensor with B * H * W elements (e.g. shape (B, H, W)); pixels
            whose label equals ``ignore`` are removed.
        ignore: label value to drop, or None to keep every pixel.

    Returns:
        Tuple ``(vinput, vtarget)`` of 2-D tensors of shape (N, C), where N is
        the number of kept pixels. The result stays 2-D even when N is 0 or 1.
    """
    C = input.size(1)
    # Move channels last so that each row holds one pixel's C-class values.
    input = input.permute(0, 2, 3, 1).contiguous().view(-1, C)
    target = target.permute(0, 2, 3, 1).contiguous().view(-1, C)
    if ignore is None:
        return input, target
    # Boolean-mask indexing always yields (N, C). The previous
    # ``valid.nonzero().squeeze()`` produced a 0-d index when exactly one pixel
    # was valid, which collapsed the result to a 1-D (C,) row.
    valid = labels.reshape(-1) != ignore
    return input[valid], target[valid]
| |
|
| |
|
class KLDivergenceLoss(nn.Module):
    """
    Temperature-scaled KL divergence between two per-pixel class distributions.

    ``input`` logits become log-probabilities and ``target`` logits become
    probabilities, both after division by the temperature ``T``, using a
    softmax over dim 1. Pixels whose label equals ``ignore_index`` are
    excluded. The loss is ``T**2 * KL(target_prob || input_prob)``.
    """

    def __init__(self, ignore_index=255, T=1, reduction='mean'):
        """
        Args:
            ignore_index: label value whose pixels are dropped, or None to keep
                every pixel.
            T: softmax temperature applied to both logit tensors.
            reduction: forwarded to ``F.kl_div``. The default 'mean' keeps the
                original behaviour: it averages over kept pixels *and* classes,
                and PyTorch emits a warning for it. 'batchmean' averages over
                kept pixels only, matching the mathematical KL divergence.
        """
        super(KLDivergenceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.T = T
        self.reduction = reduction

    def forward(self, input, target, label):
        """
        Args:
            input: logits of shape (B, C, H, W) for the distribution being fit.
            target: logits with the same shape as ``input``.
            label: per-pixel labels with B * H * W elements, used only to mask
                out ignored pixels.

        Returns:
            ``T * T * F.kl_div(...)`` reduced according to ``self.reduction``.
            A zero scalar, still connected to the graph, is returned when every
            pixel is ignored.
        """
        log_input_prob = F.log_softmax(input / self.T, dim=1)
        target_prob = F.softmax(target / self.T, dim=1)
        vinput, vtarget = flatten_probas(
            log_input_prob, target_prob, label, ignore=self.ignore_index
        )
        if vinput.numel() == 0 and self.reduction != 'none':
            # Every pixel was ignored. A mean over zero elements is NaN and
            # would corrupt training, so return a zero that keeps the graph
            # connected to `input`.
            return vinput.sum() * 0.0
        loss = F.kl_div(vinput, vtarget, reduction=self.reduction)
        # T**2 scaling, as in standard distillation, keeps gradient magnitudes
        # comparable across temperatures.
        return self.T * self.T * loss
| |
|