| |
| |
|
|
| """ |
| @Author : Peike Li |
| @Contact : peike.li@yahoo.com |
| @File : kl_loss.py |
| @Time : 7/23/19 4:02 PM |
| @Desc : |
| @License : This source code is licensed under the license found in the |
| LICENSE file in the root directory of this source tree. |
| """ |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
def flatten_probas(input, target, labels, ignore=255):
    """
    Flatten per-pixel class scores and drop ignored pixels.

    Args:
        input: 4-D tensor of shape (B, C, H, W).
        target: tensor with the same shape as ``input``.
        labels: tensor with B*H*W elements (presumably shaped (B, H, W) --
            confirm with callers); pixels whose label equals ``ignore`` are
            removed. Unused when ``ignore`` is None.
        ignore: label value to exclude, or None to keep every pixel.

    Returns:
        Tuple ``(input_flat, target_flat)``, each of shape (N, C), where N is
        the number of kept pixels (N may be 0 or 1; the result stays 2-D).
    """
    C = input.size(1)
    # Move the class axis last so each row holds one pixel's C scores.
    # reshape (unlike view) also works on non-contiguous tensors.
    input = input.permute(0, 2, 3, 1).reshape(-1, C)
    target = target.permute(0, 2, 3, 1).reshape(-1, C)
    if ignore is None:
        return input, target
    # A boolean mask always yields an (N, C) result. The previous
    # nonzero().squeeze() indexing collapsed a single valid pixel to a
    # 1-D (C,) row, and it needed two nonzero() passes.
    valid = labels.reshape(-1) != ignore
    return input[valid], target[valid]
|
|
|
|
class KLDivergenceLoss(nn.Module):
    """
    Temperature-scaled KL divergence between two sets of per-pixel logits.

    Computes ``T^2 * KL(softmax(target / T) || softmax(input / T))`` over the
    pixels whose ``label`` differs from ``ignore_index``. The T^2 factor
    rescales the loss to offset the softening applied by the temperature.

    Args:
        ignore_index: label value whose pixels are excluded from the loss.
        T: softmax temperature applied to both ``input`` and ``target``.
    """

    def __init__(self, ignore_index=255, T=1):
        super().__init__()
        self.ignore_index = ignore_index
        self.T = T

    def forward(self, input, target, label):
        """
        Args:
            input: logits of shape (B, C, H, W); class axis is dim 1.
            target: logits with the same shape as ``input``. They are not
                detached here, so gradients flow into them if they require it.
            label: per-pixel labels with B*H*W elements used to mask pixels.

        Returns:
            Scalar loss tensor. It is exactly 0 (still attached to the
            autograd graph) when no pixel is valid.
        """
        log_input_prob = F.log_softmax(input / self.T, dim=1)
        target_prob = F.softmax(target / self.T, dim=1)
        log_input_flat, target_flat = flatten_probas(
            log_input_prob, target_prob, label, ignore=self.ignore_index
        )
        if log_input_flat.numel() == 0:
            # Every pixel is ignored. F.kl_div's 'mean' reduction would divide
            # by zero and return NaN. Summing the empty tensor gives an exact
            # 0 that keeps a grad_fn, so .backward() still works.
            return log_input_flat.sum()
        # NOTE: the default reduction='mean' averages over all N*C elements,
        # not just the N pixels ('batchmean'). It is kept unchanged to
        # preserve the existing loss scale.
        loss = F.kl_div(log_input_flat, target_flat)
        return self.T * self.T * loss
|
|