|
|
|
|
|
|
| """
|
| @Author : Peike Li
|
| @Contact : peike.li@yahoo.com
|
| @File : kl_loss.py
|
| @Time : 7/23/19 4:02 PM
|
| @Desc :
|
| @License : This source code is licensed under the license found in the
|
| LICENSE file in the root directory of this source tree.
|
| """
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
|
|
|
|
def flatten_probas(input, target, labels, ignore=255):
    """
    Flattens predictions in the batch.

    Moves the channel axis last and flattens ``input`` and ``target`` into
    per-pixel rows of length ``C``, then drops the rows whose entry in
    ``labels`` equals ``ignore``.

    Args:
        input: 4-D tensor laid out as ``(B, C, H, W)``.
        target: tensor with the same layout as ``input``; flattened into rows
            of the same width ``C``.
        labels: tensor holding one label per pixel (``B*H*W`` elements,
            e.g. shape ``(B, H, W)``).
        ignore: label value whose pixels are dropped; ``None`` keeps every
            pixel and ``labels`` is not inspected.

    Returns:
        ``(input_rows, target_rows)``, each of shape ``(N, C)``, where ``N`` is
        the number of kept pixels. The result is always 2-D, even when ``N``
        is 0 or 1.
    """
    # Unpacking keeps the original "input must be 4-D" check (ValueError).
    _, C, _, _ = input.size()
    input = input.permute(0, 2, 3, 1).reshape(-1, C)
    target = target.permute(0, 2, 3, 1).reshape(-1, C)
    if ignore is None:
        return input, target
    # Index with the boolean mask directly. The old
    # ``valid.nonzero().squeeze()`` became a 0-d index when exactly one pixel
    # was valid, which returned a 1-D (C,) row instead of a (1, C) matrix.
    valid = labels.reshape(-1) != ignore
    return input[valid], target[valid]
|
|
|
|
|
class KLDivergenceLoss(nn.Module):
    """Temperature-scaled KL divergence between two per-pixel distributions.

    Both ``input`` and ``target`` are treated as logits over dim 1. They are
    softened with temperature ``T`` and compared pixel-wise, and pixels whose
    ``label`` equals ``ignore_index`` are excluded. The result is
    ``T**2 * KL(softmax(target / T) || softmax(input / T))``.

    Args:
        ignore_index: label value to exclude; ``None`` keeps every pixel.
        T: softmax temperature.
        reduction: forwarded to ``F.kl_div``. The default ``'mean'`` keeps the
            original behaviour (average over pixels *and* classes).
            ``'batchmean'`` averages over pixels only and matches the
            mathematical KL definition.
    """

    def __init__(self, ignore_index=255, T=1, reduction='mean'):
        super(KLDivergenceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.T = T
        self.reduction = reduction

    def forward(self, input, target, label):
        """Compute the masked, temperature-scaled KL loss.

        Args:
            input: logits whose distribution is being fit; 4-D with classes
                on dim 1 (as required by ``flatten_probas``).
            target: reference logits with the same layout as ``input``.
            label: per-pixel labels used only to mask out ``ignore_index``.

        Returns:
            The loss scaled by ``T**2``. If every pixel is ignored and the
            reduction averages, a zero that is still attached to the autograd
            graph is returned.
        """
        log_input_prob = F.log_softmax(input / self.T, dim=1)
        target_prob = F.softmax(target / self.T, dim=1)
        valid_input, valid_target = flatten_probas(
            log_input_prob, target_prob, label, ignore=self.ignore_index)
        if valid_input.numel() == 0 and self.reduction in ('mean', 'batchmean'):
            # Averaging over zero kept pixels would yield NaN and poison the
            # total loss; return a zero that still supports backward().
            return log_input_prob.sum() * 0.0
        loss = F.kl_div(valid_input, valid_target, reduction=self.reduction)
        # T**2 keeps gradient magnitudes comparable across temperatures,
        # since softening the softmax by T shrinks the gradients by ~1/T**2.
        return self.T * self.T * loss
|
|
|