Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| import itertools | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| import cv2 | |
| import torchvision.transforms as transforms | |
| # import torchsnooper ## for debug | |
class DBLoss(nn.Module):
    """Loss for Differentiable Binarization (DB) style text detection.

    The total loss combines three terms::

        loss_all = loss_prob + alpha * loss_bin + beta * loss_thres

    where ``loss_prob`` and ``loss_bin`` are dice losses on the probability
    map and on the approximate binary map, and ``loss_thres`` is a masked L1
    loss on the threshold map.
    """

    def __init__(self, alpha=1., beta=10., ohem_ratio=3):
        """
        Implement DB Loss.
        :param alpha: coefficient applied to the binary-map loss
        :param beta: coefficient applied to the threshold-map loss
        :param ohem_ratio: negative:positive ratio used by the OHEM helpers
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ohem_ratio = ohem_ratio

    def forward(self, outputs, labels, training_masks, G_d):
        """
        Compute the DB loss.
        :param outputs: N 2 H W; channel 0 is the probability map, channel 1
            the threshold map
        :param labels: N 2 H W; same channel layout as ``outputs``
        :param training_masks: multiplied element-wise into both dice terms;
            assumed to be N H W (or broadcastable to it) -- TODO confirm
        :param G_d: mask selecting the pixels that contribute to the
            threshold loss; assumed to be N H W (or broadcastable to it)
            -- TODO confirm against the data loader
        :return: tuple ``(loss_all, loss_prob, loss_bin, loss_thres)``
        """
        prob_map = outputs[:, 0, :, :]
        thres_map = outputs[:, 1, :, :]
        gt_prob = labels[:, 0, :, :]
        gt_thres = labels[:, 1, :, :]

        G_d = G_d.to(dtype=torch.float32)
        training_masks = training_masks.to(dtype=torch.float32)

        # TODO: OHEM is not wired in yet. `ohem_batch` can produce a selection
        # mask and `bce_loss` can replace the dice terms below once it is.

        # Probability-map loss.
        loss_prob = self.dice_loss(prob_map, gt_prob, training_masks)

        # Binary-map loss (on the differentiable binarization output).
        bin_map = self.DB(prob_map, thres_map)
        loss_bin = self.dice_loss(bin_map, gt_prob, training_masks)

        # Threshold-map loss. The L1 error must stay per-pixel until after the
        # mask is applied; reducing to a scalar first would make the mask act
        # as a mere global scale factor instead of selecting pixels.
        loss_thres = self._masked_l1_loss(thres_map, gt_thres, G_d)

        # The individual terms are already scalars; .mean() keeps the
        # reduction explicit and is a no-op on 0-dim tensors.
        loss_prob = loss_prob.mean()
        loss_bin = loss_bin.mean()
        loss_thres = loss_thres.mean()

        loss_all = loss_prob + self.alpha * loss_bin + self.beta * loss_thres
        return loss_all, loss_prob, loss_bin, loss_thres

    def _masked_l1_loss(self, pred, gt, mask, eps=1e-5):
        """Mean absolute error over the pixels weighted by ``mask``.

        The divisor is clamped to ``eps`` so an all-zero mask yields a loss
        of 0 instead of a division by zero.
        """
        weighted_error = torch.abs(pred - gt) * mask
        return weighted_error.sum() / mask.sum().clamp(min=eps)

    def DB(self, prob_map, thres_map, k=50):
        '''
        Differentiable binarization: sigmoid(k * (prob_map - thres_map)).

        Mathematically identical to 1 / (1 + exp(-k * (prob_map - thres_map)))
        but torch.sigmoid does not overflow for large negative margins, which
        would otherwise produce NaN gradients.
        '''
        return torch.sigmoid(k * (prob_map - thres_map))

    def dice_loss(self, pred_cls, gt_cls, training_mask):
        '''
        Dice loss computed over the whole batch at once.

        All three tensors are multiplied element-wise, so they must share a
        shape or be broadcastable to one another.
        :param pred_cls: predicted map
        :param gt_cls: ground-truth map
        :param training_mask: per-element weight applied to both maps
        :return: scalar tensor ``1 - 2 * intersection / union``
        '''
        eps = 1e-5  # keeps the denominator non-zero when both maps are empty
        masked_gt = gt_cls * training_mask
        masked_pred = pred_cls * training_mask
        intersection = torch.sum(masked_gt * pred_cls)
        union = torch.sum(masked_gt) + torch.sum(masked_pred) + eps
        return 1. - (2 * intersection / union)

    def bce_loss(self, input, target, mask):
        """
        Binary cross-entropy restricted to the elements selected by ``mask``.

        ``target`` is binarized at 0.5 on a copy; the caller's tensor is left
        untouched.
        :param input: predicted probabilities (values expected by nn.BCELoss)
        :param target: ground-truth map, binarized at 0.5 before use
        :param mask: non-zero entries mark the elements that contribute
        :return: scalar loss; a zero tensor when the mask selects nothing
        """
        if mask.sum() == 0:
            return torch.tensor(0.0, device=input.device, requires_grad=True)
        selected = mask.bool()
        # Build a fresh binary target rather than writing into `target`,
        # which may be a view of the caller's label tensor.
        binary_target = (target > 0.5).to(dtype=input.dtype)
        return nn.BCELoss(reduction='mean')(input[selected], binary_target[selected])

    def ohem_single(self, score, gt_text):
        """
        Build an online-hard-example-mining mask for one sample.

        Keeps every positive pixel (gt_text > 0.5) plus the highest-scoring
        negatives, limited to ``ohem_ratio`` negatives per positive.
        :param score: numpy array of predicted scores for one sample
        :param gt_text: numpy array of ground truth, same shape as ``score``
        :return: float32 numpy array of shape (1,) + score.shape
        """
        negatives = gt_text <= 0.5
        pos_num = int(np.sum(gt_text > 0.5))
        neg_num = int(min(pos_num * self.ohem_ratio, int(np.sum(negatives))))

        # Nothing to balance against: no positives, or no negatives to pick.
        if pos_num == 0 or neg_num == 0:
            return np.zeros((1,) + score.shape, dtype=np.float32)

        # Score of the neg_num-th hardest (highest-scoring) negative.
        neg_score_desc = -np.sort(-score[negatives])
        threshold = neg_score_desc[neg_num - 1]

        selected_mask = (score >= threshold) | (gt_text > 0.5)
        return selected_mask.reshape((1,) + selected_mask.shape).astype('float32')

    def ohem_batch(self, scores, gt_texts):
        """
        Apply ``ohem_single`` to every sample of a batch.
        :param scores: tensor indexed as [sample, row, col]
        :param gt_texts: tensor with the same layout as ``scores``
        :return: float CPU tensor of masks stacked along dim 0
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        selected_masks = [
            self.ohem_single(scores[i, :, :], gt_texts[i, :, :])
            for i in range(scores.shape[0])
        ]
        selected_masks = np.concatenate(selected_masks, 0)
        return torch.from_numpy(selected_masks).float()