import os
import sys  # noqa: F401  (unused here; kept because other code may rely on it)
from collections import OrderedDict

import numpy as np  # noqa: F401  (no longer used here; kept for importers)
import torch
import torch.nn.functional
from termcolor import colored
from torch import nn

from lib.config import cfg  # noqa: F401  (kept so `cfg` stays importable from here)


# Loss function definitions ---------------------------------------------------

def sigmoid(x):
    """Return ``x.sigmoid()`` clamped to the range ``[1e-4, 1 - 1e-4]``.

    The clamp keeps the result strictly inside (0, 1); presumably this is so
    that ``torch.log`` in the losses below never sees exactly 0 or 1.
    """
    y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4)
    return y


def _neg_loss(pred, gt):
    """Modified focal loss, as used by CornerNet.

    Runs faster and costs a little bit more memory.

    Arguments:
        pred: predicted heat map, (batch x c x h x w).
        gt: ground-truth heat map, (batch x c x h x w).

    Returns:
        A scalar tensor. The summed loss is divided by the number of positive
        locations, unless there are none, in which case only the (unnormalised)
        negative term is returned.
    """
    # 1.0 where gt == 1 (positive locations), 0.0 elsewhere.
    pos_inds = gt.eq(1).float()
    # 1.0 where gt < 1 (negative locations), 0.0 elsewhere.
    neg_inds = gt.lt(1).float()

    # (1 - gt) ** 4: negatives whose gt value is close to 1 get a small weight.
    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    # Positives: the (1 - pred) ** 2 factor shrinks the contribution of
    # locations that are already predicted close to 1 (easy positives).
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    # Negatives: the pred ** 2 factor shrinks the contribution of locations
    # that are already predicted close to 0 (easy negatives).
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    # Together the two terms concentrate the loss on hard examples.

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        # No positives: pos_loss is zero and dividing by num_pos is undefined.
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss


class FocalLoss(nn.Module):
    """nn.Module wrapper around the modified focal loss ``_neg_loss``."""

    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Return ``_neg_loss(out, target)``."""
        return self.neg_loss(out, target)


def smooth_l1_loss(vertex_pred, vertex_targets, vertex_weights, sigma=1.0,
                   normalize=True, reduce=True):
    """Weighted smooth-L1 loss between predicted and target vertex maps.

    :param vertex_pred: [b, vn*2, h, w]; vn is the number of vertices, so
        vn*2 is presumably the x and y coordinate of every vertex
    :param vertex_targets: [b, vn*2, h, w]
    :param vertex_weights: [b, 1, h, w]
    :param sigma: controls the switch point between the quadratic and the
        linear branch (the switch happens at ``|diff| == 1 / sigma**2``)
    :param normalize: if True, sum the loss per sample and divide by
        ``ver_dim * sum(vertex_weights) + 1e-3``
    :param reduce: if True, return the mean of the (possibly normalised) loss
    :return: the loss tensor; a scalar when ``reduce`` is True
    """
    b, ver_dim, _, _ = vertex_pred.shape
    sigma_2 = sigma ** 2
    vertex_diff = vertex_pred - vertex_targets
    diff = vertex_weights * vertex_diff
    abs_diff = torch.abs(diff)

    # 1.0 where the quadratic (L2) branch applies, 0.0 where the linear (L1)
    # branch applies. detach() keeps this selector out of the autograd graph:
    # it only picks a branch and must not contribute gradients itself.
    smoothL1_sign = (abs_diff < 1. / sigma_2).detach().float()
    in_loss = torch.pow(diff, 2) * (sigma_2 / 2.) * smoothL1_sign \
        + (abs_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)

    if normalize:
        in_loss = torch.sum(in_loss.view(b, -1), 1) \
            / (ver_dim * torch.sum(vertex_weights.view(b, -1), 1) + 1e-3)

    if reduce:
        in_loss = torch.mean(in_loss)

    return in_loss


class SmoothL1Loss(nn.Module):
    """nn.Module wrapper around ``smooth_l1_loss``."""

    def __init__(self):
        super(SmoothL1Loss, self).__init__()
        self.smooth_l1_loss = smooth_l1_loss

    def forward(self, preds, targets, weights, sigma=1.0, normalize=True,
                reduce=True):
        """Forward all arguments unchanged to ``smooth_l1_loss``."""
        return self.smooth_l1_loss(preds, targets, weights, sigma, normalize,
                                   reduce)


# NOTE: this loss does not appear to be used.
class AELoss(nn.Module):
    """Associative-embedding style pull/push loss on a one-channel tag map."""

    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """Compute the pull and push terms.

        ae: [b, 1, h, w]
        ind: [b, max_objs, max_parts]
        ind_mask: [b, max_objs, max_parts]
        obj_mask (computed here): [b, max_objs]

        Returns:
            ``(pull, push)``, both averaged over the batch.
        """
        # first index
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape
        # An object is present when at least one of its parts is unmasked.
        obj_mask = torch.sum(ind_mask, dim=2) != 0

        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)

        # compute the mean tag of every object over its valid parts
        tag_mean = tag * ind_mask
        tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)

        # pull the tags of the same object towards their mean
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
        pull /= b

        # push away the means of different objects
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        push_dist = 1 - push_dist
        push_dist = nn.functional.relu(push_dist, inplace=True)
        # Keep only pairs where BOTH objects are present. `&` is used instead
        # of `(a + b) == 2`: on bool tensors `+` is a logical OR, so the sum
        # can never equal 2 and the mask would be all zeros.
        pair_mask = obj_mask.unsqueeze(1) & obj_mask.unsqueeze(2)
        push_dist = push_dist * pair_mask.float()
        # Subtract obj_num to drop the diagonal: each present object has
        # distance 0 to itself and therefore contributes exactly 1.
        push = ((push_dist.sum(dim=(1, 2)) - obj_num)
                / (obj_num * (obj_num - 1) + 1e-4)).sum()
        push /= b
        return pull, push


# NOTE: this loss does not appear to be used.
class PolyMatchingLoss(nn.Module):
    """Polygon loss that is minimised over all cyclic shifts of the target."""

    def __init__(self, pnum):
        """Precompute the gather indices for every cyclic shift.

        Args:
            pnum: number of polygon points.
        """
        super(PolyMatchingLoss, self).__init__()

        self.pnum = pnum

        # shifted[i, j] == (i + j) % pnum: row i lists the point indices of
        # the polygon rotated by i positions.
        point_ids = torch.arange(pnum)
        shifted = (point_ids.unsqueeze(0) + point_ids.unsqueeze(1)) % pnum

        # Prefer CUDA as before, but do not crash on CPU-only machines;
        # forward() moves the table to the device of its inputs anyway.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        flat = shifted.reshape(1, pnum * pnum, 1).long().to(device)
        # Last dimension of size 2 so the same index gathers both coordinates.
        self.feature_id = flat.expand(1, pnum * pnum, 2).detach()

    def forward(self, pred, gt, loss_type="L2"):
        """Return the mean over the batch of the best-shift distance.

        Args:
            pred: predicted polygons; the view below requires the layout
                [batch, pnum, 2].
            gt: target polygons with the same layout.
            loss_type: ``"L2"`` (sum of Euclidean point distances) or
                ``"L1"`` (sum of absolute coordinate differences).

        Raises:
            ValueError: if ``loss_type`` is neither ``"L2"`` nor ``"L1"``.
        """
        pnum = self.pnum
        batch_size = pred.size()[0]
        feature_id = self.feature_id.to(gt.device).expand(
            batch_size, self.feature_id.size(1), 2)

        # gt_expand[b, i] is the target polygon rotated by i positions.
        gt_expand = torch.gather(gt, 1, feature_id).view(
            batch_size, pnum, pnum, 2)
        pred_expand = pred.unsqueeze(1)

        dis = pred_expand - gt_expand

        if loss_type == "L2":
            dis = (dis ** 2).sum(3).sqrt().sum(2)
        elif loss_type == "L1":
            dis = torch.abs(dis).sum(3).sum(2)
        else:
            raise ValueError(
                "loss_type must be 'L2' or 'L1', got {!r}".format(loss_type))

        # Best (smallest) distance over all cyclic shifts of the target.
        min_dis, _ = torch.min(dis, dim=1, keepdim=True)

        return torch.mean(min_dis)


# NOTE: this loss does not appear to be used.
class AttentionLoss(nn.Module):
    """Class-balanced cross-entropy with prediction-dependent ``beta`` weights."""

    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()

        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        """Return the mean weighted loss between ``pred`` and ``gt``.

        ``gt`` is summed to count positives and ``1 - gt`` to count negatives,
        so it is presumably a 0/1 map; ``pred`` goes through ``torch.log`` and
        is therefore presumably in (0, 1).
        """
        num_pos = torch.sum(gt)
        num_neg = torch.sum(1 - gt)
        # Fraction of negatives: weights the positive term, so the rarer
        # class receives the larger weight.
        alpha = num_neg / (num_pos + num_neg)
        edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))

        loss = 0
        loss = loss - alpha * edge_beta * torch.log(pred) * gt
        loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
        return torch.mean(loss)


def _gather_feat(feat, ind, mask=None):
    """Gather rows of ``feat`` along dim 1 at the positions given by ``ind``.

    Args:
        feat: 3-D tensor; the last dimension is the feature dimension.
        ind: 2-D index tensor (first dim matches ``feat``).
        mask: optional 2-D boolean mask over the gathered rows.

    Returns:
        The gathered tensor with ``ind``'s two dims plus the feature dim, or,
        when ``mask`` is given, the selected rows flattened to
        ``[num_selected, dim]``.
    """
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _tranpose_and_gather_feat(feat, ind):
    """Move channels last, flatten the spatial dims, then gather at ``ind``.

    ``feat`` is permuted with ``(0, 2, 3, 1)``, so it must be 4-D with the
    channel dimension second; ``ind`` indexes the flattened spatial positions.
    """
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def _regression_loss_fn(kind):
    """Map ``'l1'`` / ``'smooth_l1'`` to the matching functional loss.

    Raises:
        ValueError: for any other ``kind``. Previously the loss attribute was
            silently left unset and the failure only surfaced in forward().
    """
    losses = {
        'l1': torch.nn.functional.l1_loss,
        'smooth_l1': torch.nn.functional.smooth_l1_loss,
    }
    try:
        return losses[kind]
    except KeyError:
        raise ValueError(
            "type must be 'l1' or 'smooth_l1', got {!r}".format(kind))


# NOTE: this loss is not used either.
class Ind2dRegL1Loss(nn.Module):
    """Masked regression loss on features gathered at per-part indices."""

    def __init__(self, type='l1'):
        """Select the element-wise loss: ``'l1'`` or ``'smooth_l1'``."""
        super(Ind2dRegL1Loss, self).__init__()
        self.loss = _regression_loss_fn(type)

    def forward(self, output, target, ind, ind_mask):
        """Return the summed loss divided by the number of unmasked elements.

        ind: [b, max_objs, max_parts]
        """
        b, max_objs, max_parts = ind.shape
        ind = ind.view(b, max_objs * max_parts)
        pred = _tranpose_and_gather_feat(output, ind).view(
            b, max_objs, max_parts, output.size(1))
        mask = ind_mask.unsqueeze(3).expand_as(pred)
        loss = self.loss(pred * mask, target * mask, reduction='sum')
        # 1e-4 guards against division by zero when everything is masked.
        loss = loss / (mask.sum() + 1e-4)
        return loss


class IndL1Loss1d(nn.Module):
    """Weighted regression loss on features gathered at flat indices."""

    def __init__(self, type='l1'):
        """Select the element-wise loss: ``'l1'`` or ``'smooth_l1'``."""
        super(IndL1Loss1d, self).__init__()
        self.loss = _regression_loss_fn(type)

    def forward(self, output, target, ind, weight):
        """Return the summed loss normalised by ``weight.sum() * channels``.

        ind: [b, n]
        """
        output = _tranpose_and_gather_feat(output, ind)
        weight = weight.unsqueeze(2)
        loss = self.loss(output * weight, target * weight, reduction='sum')
        # 1e-4 guards against division by zero when all weights are zero.
        loss = loss / (weight.sum() * output.size(2) + 1e-4)
        return loss


class GeoCrossEntropyLoss(nn.Module):
    """Cross-entropy softened by a Gaussian kernel over polygon point distance."""

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        """Return the kernel-weighted negative log-likelihood.

        ``output`` is soft-maxed over dim 1. ``poly`` is reshaped to
        ``[b, 4, n // 4, 2]`` (four groups of 2-D points) and ``target``
        selects one point per group.
        NOTE(review): the exact expected shapes of ``output`` and ``target``
        are not visible here -- confirm against the caller.
        """
        output = torch.nn.functional.softmax(output, dim=1)
        # Clamp before the log so the result stays finite.
        output = torch.log(torch.clamp(output, min=1e-4))
        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        target = target[..., None, None].expand(
            poly.size(0), poly.size(1), 1, poly.size(3))
        target_poly = torch.gather(poly, 2, target)
        # Squared distance between the first two points of every group sets
        # the kernel bandwidth.
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss


# Helpers for loading and saving networks -------------------------------------

# save_model() deletes the oldest checkpoint once more than this many exist.
_MAX_CHECKPOINTS = 200


def _checkpoint_epochs(model_dir):
    """Return the epoch numbers of all ``<epoch>.pth`` files in ``model_dir``.

    Entries that are not named ``<integer>.pth`` (logs, ``latest.pth``,
    editor backups, ...) are ignored instead of crashing ``int()``.
    """
    epochs = []
    for name in os.listdir(model_dir):
        stem, ext = os.path.splitext(name)
        if ext != '.pth' or not stem.isdigit():
            continue
        try:
            epochs.append(int(stem))
        except ValueError:
            continue
    return epochs


def load_model(net, optim, scheduler, recorder, model_dir, resume=True,
               epoch=-1):
    """Restore the full training state from ``<model_dir>/<epoch>.pth``.

    Args:
        net, optim, scheduler, recorder: objects providing
            ``load_state_dict``; each receives its entry of the checkpoint.
        model_dir: directory holding the ``<epoch>.pth`` checkpoints.
        resume: if False, nothing is loaded.
        epoch: checkpoint to load; ``-1`` selects the highest epoch found.

    Returns:
        The epoch to continue from (stored epoch + 1), or 0 when nothing was
        loaded.
    """
    if not resume:
        # Not resuming: start from epoch 0 and leave model_dir untouched.
        return 0

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!!', 'red'))
        return 0

    pths = _checkpoint_epochs(model_dir)
    if len(pths) == 0:
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0
    pth = max(pths) if epoch == -1 else epoch

    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    pretrained_model = torch.load(model_path)
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1


def save_model(net, optim, scheduler, recorder, epoch, model_dir):
    """Write the training state to ``<model_dir>/<epoch>.pth``.

    After saving, if more than 200 checkpoints exist, the one with the
    smallest epoch number is deleted.
    """
    # os.makedirs / os.remove instead of shelling out: portable, and safe for
    # paths that contain spaces or shell metacharacters.
    os.makedirs(model_dir, exist_ok=True)
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(epoch)))

    # remove the oldest checkpoint if the number of checkpoints is too big
    pths = _checkpoint_epochs(model_dir)
    if len(pths) <= _MAX_CHECKPOINTS:
        return
    os.remove(os.path.join(model_dir, '{}.pth'.format(min(pths))))


def load_network(net, model_dir, resume=True, epoch=-1, strict=False):
    """Load only the network weights from ``<model_dir>/<epoch>.pth``.

    Args:
        net: object providing ``load_state_dict``.
        model_dir: directory holding the ``<epoch>.pth`` checkpoints.
        resume: if False, nothing is loaded.
        epoch: checkpoint to load; ``-1`` selects the highest epoch found.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        The stored epoch + 1, or 0 when nothing was loaded.
    """
    if not resume:
        return 0

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!@!', 'red'))
        return 0

    # List the directory that was passed in (previously cfg.model_dir was
    # listed here, which could disagree with the path checked and loaded).
    pths = _checkpoint_epochs(model_dir)
    if len(pths) == 0:
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0
    pth = max(pths) if epoch == -1 else epoch

    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    pretrained_model = torch.load(model_path)
    # Checkpoints store the weights under either 'state_dict' or 'net'.
    if 'state_dict' in pretrained_model:
        weights = pretrained_model['state_dict']
    else:
        weights = pretrained_model['net']
    net.load_state_dict(weights, strict=strict)
    return pretrained_model['epoch'] + 1


# State-dict key helpers (modified locally; probably unused) ------------------

def remove_net_prefix(net, prefix):
    """Return a copy of ``net`` with ``prefix`` stripped from keys that have it."""
    net_ = OrderedDict()
    for k, v in net.items():
        if k.startswith(prefix):
            net_[k[len(prefix):]] = v
        else:
            net_[k] = v
    return net_


def add_net_prefix(net, prefix):
    """Return a copy of ``net`` with ``prefix`` prepended to every key."""
    net_ = OrderedDict()
    for k, v in net.items():
        net_[prefix + k] = v
    return net_


def replace_net_prefix(net, orig_prefix, prefix):
    """Return a copy of ``net`` with a leading ``orig_prefix`` replaced by ``prefix``.

    Keys that do not start with ``orig_prefix`` are copied unchanged.
    """
    net_ = OrderedDict()
    for k, v in net.items():
        if k.startswith(orig_prefix):
            net_[prefix + k[len(orig_prefix):]] = v
        else:
            net_[k] = v
    return net_


def remove_net_layer(net, layers):
    """Delete, in place, every key of ``net`` that starts with one of ``layers``.

    Returns the same (mutated) mapping. Each key is deleted at most once, so
    overlapping prefixes such as ``['head', 'head.conv']`` no longer raise
    ``KeyError`` on the second match.
    """
    prefixes = tuple(layers)
    # Snapshot the keys first: the mapping is mutated while we walk it.
    for k in list(net.keys()):
        if k.startswith(prefixes):
            del net[k]
    return net