|
|
|
|
|
|
| """
|
| @Author : Peike Li
|
| @Contact : peike.li@yahoo.com
|
| @File : kl_loss.py
|
| @Time : 7/23/19 4:02 PM
|
| @Desc : Consistency loss between the edge prediction and the edges generated from the parsing prediction.
|
| @License : This source code is licensed under the license found in the
|
| LICENSE file in the root directory of this source tree.
|
| """
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
| from datasets.target_generation import generate_edge_tensor
|
|
|
|
|
class ConsistencyLoss(nn.Module):
    """Penalize disagreement between the edge prediction and parsing-derived edges.

    The parsing prediction is reduced to a hard label map (argmax over ``dim=1``),
    an edge map is derived from it with ``generate_edge_tensor``, and that map is
    compared with the hard edge prediction (argmax over ``dim=1``) using a
    smooth-L1 loss on valid pixels where either map equals 1.

    NOTE(review): both compared maps come from ``torch.argmax``, which is not
    differentiable, so this loss produces no gradient for ``parsing`` or
    ``edge`` as written; it behaves like a monitoring metric. Confirm intent.

    Args:
        ignore_index: label value marking pixels excluded from the loss.
    """

    def __init__(self, ignore_index=255):
        super(ConsistencyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, parsing, edge, label):
        """Compute the edge-consistency loss.

        Args:
            parsing: parsing scores with classes along ``dim=1``
                (presumably ``(N, C, H, W)`` logits — TODO confirm).
            edge: edge scores with classes along ``dim=1``.
            label: ground-truth label map; only used to locate pixels equal to
                ``ignore_index``. Must be usable as a boolean mask on the
                argmax outputs (same shape).

        Returns:
            Scalar tensor: mean smooth-L1 loss over the selected pixels, or a
            zero scalar when no valid pixel is marked 1 in either map (a mean
            over an empty selection would otherwise be NaN).
        """
        # Single source of truth for the ignore value (was hard-coded 255 below).
        valid = label != self.ignore_index

        parsing_pre = torch.argmax(parsing, dim=1)
        # Mark ignored pixels in the hard prediction before deriving edges from it.
        parsing_pre[~valid] = self.ignore_index
        generated_edge = generate_edge_tensor(parsing_pre)
        # NOTE(review): guards against the external helper squeezing a size-1
        # batch dimension; a no-op when it already returns parsing_pre's shape.
        generated_edge = generated_edge.reshape(parsing_pre.shape)

        edge_pre = torch.argmax(edge, dim=1)

        v_generate_edge = generated_edge[valid].float()
        # Device/dtype-preserving cast; the old torch.cuda.FloatTensor cast broke on CPU.
        v_edge_pre = edge_pre[valid].to(device=v_generate_edge.device, dtype=torch.float32)

        # Union, not intersection: with `&` every selected element is 1 in both
        # tensors, which makes the loss identically zero.
        positive_union = (v_generate_edge == 1) | (v_edge_pre == 1)
        if not positive_union.any():
            return v_generate_edge.new_zeros(())

        return F.smooth_l1_loss(v_generate_edge[positive_union], v_edge_pre[positive_union])
|
|
|