| import torch |
| import os |
| from torch import nn |
| import numpy as np |
| import torch.nn.functional |
| from collections import OrderedDict |
| from termcolor import colored |
| import sys |
| from lib.config import cfg |
|
|
|
|
| |
|
|
|
|
def sigmoid(x):
    """Apply the logistic sigmoid to ``x`` and clamp the result.

    The output is limited to the closed interval ``[1e-4, 1 - 1e-4]`` so that
    a later ``log(y)`` or ``log(1 - y)`` never sees exactly 0 or 1.
    """
    lower = 1e-4
    upper = 1 - 1e-4
    return torch.clamp(torch.sigmoid(x), min=lower, max=upper)
|
|
|
|
def _neg_loss(pred, gt):
    """Modified focal loss (the original author notes it matches CornerNet).

    Arguments:
        pred: predicted heatmap, documented by the author as
            (batch x c x h x w); values are passed to ``log`` so they must
            lie strictly between 0 and 1.
        gt: ground-truth heatmap with the same shape. Entries equal to 1 are
            positives; entries below 1 are negatives, down-weighted by
            ``(1 - gt) ** 4``.

    Returns:
        Scalar tensor. The summed loss is divided by the number of positives,
        unless there are none, in which case only the negative term is
        returned un-normalized.
    """
    is_pos = gt.eq(1).float()
    is_neg = gt.lt(1).float()

    # Negatives close to a positive peak (gt near 1) contribute less.
    neg_weights = torch.pow(1 - gt, 4)

    pos_term = (torch.log(pred) * torch.pow(1 - pred, 2) * is_pos).sum()
    neg_term = (torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * is_neg).sum()

    num_pos = is_pos.sum()
    if num_pos == 0:
        return 0 - neg_term
    return 0 - (pos_term + neg_term) / num_pos
|
|
|
|
class FocalLoss(nn.Module):
    """nn.Module wrapper around the module-level ``_neg_loss`` focal loss."""

    def __init__(self):
        super(FocalLoss, self).__init__()
        # Stored as an attribute so the loss function can be swapped per instance.
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Return ``self.neg_loss(out, target)``."""
        loss_fn = self.neg_loss
        return loss_fn(out, target)
|
|
|
|
def smooth_l1_loss(vertex_pred, vertex_targets, vertex_weights, sigma=1.0, normalize=True, reduce=True):
    """Weighted smooth-L1 loss between predicted and target vertex maps.

    :param vertex_pred: [b, vn*2, h, w]; per the original note, vn is the
        number of vertices, so vn*2 holds their x and y coordinates
    :param vertex_targets: [b, vn*2, h, w]
    :param vertex_weights: [b, 1, h, w], multiplied into the difference
    :param sigma: the quadratic branch is used where |diff| < 1 / sigma**2
    :param normalize: if True, sum per sample and divide by
        ``ver_dim * sum(weights) + 1e-3``
    :param reduce: if True, return the mean of the (possibly normalized) loss
    :return: a scalar when ``reduce`` is True, otherwise a per-sample vector
        (``normalize=True``) or the element-wise loss map
    """
    batch, ver_dim, _h, _w = vertex_pred.shape
    sigma_sq = sigma ** 2

    weighted = vertex_weights * (vertex_pred - vertex_targets)
    magnitude = weighted.abs()

    # 1.0 where the quadratic branch applies, 0.0 for the linear branch;
    # detached so the branch selection itself carries no gradient.
    is_quadratic = (magnitude < 1. / sigma_sq).detach().float()

    quadratic = torch.pow(weighted, 2) * (sigma_sq / 2.) * is_quadratic
    linear = (magnitude - (0.5 / sigma_sq)) * (1. - is_quadratic)
    loss = quadratic + linear

    if normalize:
        per_sample = torch.sum(loss.view(batch, -1), 1)
        weight_total = torch.sum(vertex_weights.view(batch, -1), 1)
        loss = per_sample / (ver_dim * weight_total + 1e-3)

    if reduce:
        loss = torch.mean(loss)

    return loss
|
|
|
|
class SmoothL1Loss(nn.Module):
    """nn.Module wrapper around the module-level ``smooth_l1_loss`` function."""

    def __init__(self):
        super(SmoothL1Loss, self).__init__()
        self.smooth_l1_loss = smooth_l1_loss

    def forward(self, preds, targets, weights, sigma=1.0, normalize=True, reduce=True):
        """Forward all arguments, in order, to ``self.smooth_l1_loss``."""
        loss_fn = self.smooth_l1_loss
        return loss_fn(preds, targets, weights, sigma, normalize, reduce)
|
|
| |
class AELoss(nn.Module):
    """Pull/push loss on per-part embedding tags grouped by object.

    The pull term draws the tags of one object's parts towards their mean;
    the push term penalizes pairs of object means closer than 1.
    """

    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """
        ae: [b, 1, h, w] embedding map
        ind: [b, max_objs, max_parts] flat (h * w) indices into ``ae``
            (used as a gather index, so it must be an int64 tensor)
        ind_mask: [b, max_objs, max_parts] weights, non-zero for valid parts
            (multiplied with the tags, so presumably float -- confirm with caller)

        Returns ``(pull, push)``, both scalars averaged over the batch.
        """
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape

        # An object is valid when at least one of its parts is.
        obj_mask = torch.sum(ind_mask, dim=2) != 0

        # Look up the tag of every (object, part) location.
        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)

        # Masked mean tag per object.
        tag_mean = (tag * ind_mask).sum(2) / (ind_mask.sum(2) + 1e-4)

        # Pull: squared distance of each valid part tag to its object mean.
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
        pull = pull / b

        # Push: hinge on the pairwise distance between object means.
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        push_dist = 1 - push_dist
        push_dist = nn.functional.relu(push_dist, inplace=True)

        # Keep a pair only when both objects are valid. Use a logical AND:
        # adding two bool masks yields a bool tensor, so the former
        # ``(a + b) == 2`` test was never true and zeroed the whole push term.
        pair_mask = obj_mask.unsqueeze(1) & obj_mask.unsqueeze(2)
        push_dist = push_dist * pair_mask.float()

        # Each valid object contributes 1 on the diagonal (distance to
        # itself is 0); subtracting obj_num removes those self-pairs.
        push = ((push_dist.sum(dim=(1, 2)) - obj_num) / (obj_num * (obj_num - 1) + 1e-4)).sum()
        push = push / b
        return pull, push
|
|
| |
class PolyMatchingLoss(nn.Module):
    """Polygon loss that is invariant to the starting vertex of the ground truth.

    The prediction is compared against every cyclic rotation of the
    ground-truth vertex order and the smallest total distance is kept.
    """

    def __init__(self, pnum):
        """Precompute the gather indices for all ``pnum`` cyclic rotations.

        :param pnum: number of vertices per polygon
        """
        super(PolyMatchingLoss, self).__init__()

        self.pnum = pnum

        # rotations[i, j] = (i + j) % pnum, i.e. row i is the vertex order
        # shifted by i positions.
        offsets = np.arange(pnum)
        rotations = (offsets[None, :] + offsets[:, None]) % pnum
        flat = torch.from_numpy(np.reshape(rotations, newshape=(1, -1))).long()

        # Prefer CUDA as before, but fall back to CPU so the module can be
        # built on hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        flat = flat.to(device)

        # Shape (1, pnum * pnum, 2): the same index for the x and y columns.
        self.feature_id = flat.unsqueeze(2).expand(flat.size(0), flat.size(1), 2).detach()

    def forward(self, pred, gt, loss_type="L2"):
        """Return the batch mean of the best-rotation distance.

        :param pred: predicted vertices, (batch, pnum, 2)
        :param gt: ground-truth vertices, (batch, pnum, 2)
        :param loss_type: "L2" (sum of Euclidean vertex distances) or
            "L1" (sum of absolute coordinate differences)
        :raises ValueError: if ``loss_type`` is neither "L1" nor "L2"
        """
        if loss_type not in ("L1", "L2"):
            raise ValueError(
                'loss_type must be "L1" or "L2", got {!r}'.format(loss_type))

        pnum = self.pnum
        batch_size = pred.size(0)

        # Follow the input's device instead of assuming CUDA.
        feature_id = self.feature_id.to(gt.device)
        feature_id = feature_id.expand(batch_size, feature_id.size(1), 2)

        # gt_expand[b, i, j] = gt[b, (i + j) % pnum]
        gt_expand = torch.gather(gt, 1, feature_id).view(batch_size, pnum, pnum, 2)

        # Broadcast the prediction against every rotation.
        dis = pred.unsqueeze(1) - gt_expand

        if loss_type == "L2":
            dis = (dis ** 2).sum(3).sqrt().sum(2)
        else:
            dis = torch.abs(dis).sum(3).sum(2)

        # Best rotation per sample.
        min_dis, _ = torch.min(dis, dim=1, keepdim=True)

        return torch.mean(min_dis)
|
|
| |
class AttentionLoss(nn.Module):
    """Class-balanced cross-entropy with a prediction-dependent weight.

    Each term is scaled by ``beta ** (error ** gamma)``, so confidently
    wrong predictions are weighted more heavily.
    """

    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()

        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        """Return the mean loss; ``pred`` must lie strictly inside (0, 1).

        ``gt`` is used as a 0/1 mask: ``sum(gt)`` counts the positives and
        ``sum(1 - gt)`` the negatives.
        """
        pos_count = torch.sum(gt)
        neg_count = torch.sum(1 - gt)
        # The rarer class gets the larger weight.
        alpha = neg_count / (pos_count + neg_count)

        edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))

        edge_term = alpha * edge_beta * torch.log(pred) * gt
        bg_term = (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
        return torch.mean(0 - edge_term - bg_term)
|
|
|
|
def _gather_feat(feat, ind, mask=None):
    """Pick rows of ``feat`` along dim 1 using the indices in ``ind``.

    ``feat`` is (b, n, c) and ``ind`` is (b, k); the result is (b, k, c).
    When ``mask`` (b, k) is given, only the selected rows are kept and the
    result is flattened to (num_selected, c).
    """
    channels = feat.size(2)
    index = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), channels)
    gathered = feat.gather(1, index)
    if mask is None:
        return gathered
    keep = mask.unsqueeze(2).expand_as(gathered)
    return gathered[keep].view(-1, channels)
|
|
|
|
def _tranpose_and_gather_feat(feat, ind):
    """Gather per-location feature vectors from a (b, c, h, w) map.

    The map is rearranged to (b, h * w, c) and then indexed with the flat
    locations in ``ind`` through ``_gather_feat``.
    """
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    batch, channels = channels_last.size(0), channels_last.size(3)
    flat = channels_last.view(batch, -1, channels)
    return _gather_feat(flat, ind)
|
|
|
|
| |
class Ind2dRegL1Loss(nn.Module):
    """Masked L1 / smooth-L1 regression loss at indexed (object, part) locations."""

    def __init__(self, type='l1'):
        """
        :param type: 'l1' or 'smooth_l1'
        :raises ValueError: for any other value. Previously ``self.loss`` was
            silently left unset and the failure only surfaced in ``forward``.
        """
        super(Ind2dRegL1Loss, self).__init__()
        loss_fns = {
            'l1': torch.nn.functional.l1_loss,
            'smooth_l1': torch.nn.functional.smooth_l1_loss,
        }
        if type not in loss_fns:
            raise ValueError(
                "type must be 'l1' or 'smooth_l1', got {!r}".format(type))
        self.loss = loss_fns[type]

    def forward(self, output, target, ind, ind_mask):
        """ind: [b, max_objs, max_parts]

        ``output`` is a (b, c, h, w) map sampled at the flat locations in
        ``ind``. ``ind_mask`` is [b, max_objs, max_parts]; ``target`` is
        presumably [b, max_objs, max_parts, c] -- confirm with the caller.
        Returns the summed loss divided by the number of masked elements.
        """
        b, max_objs, max_parts = ind.shape
        ind = ind.view(b, max_objs * max_parts)
        pred = _tranpose_and_gather_feat(output, ind).view(b, max_objs, max_parts, output.size(1))
        mask = ind_mask.unsqueeze(3).expand_as(pred)
        # Multiplying both sides by the mask zeroes the invalid entries.
        loss = self.loss(pred * mask, target * mask, reduction='sum')
        loss = loss / (mask.sum() + 1e-4)
        return loss
|
|
|
|
class IndL1Loss1d(nn.Module):
    """Weighted L1 / smooth-L1 regression loss at indexed locations."""

    def __init__(self, type='l1'):
        """
        :param type: 'l1' or 'smooth_l1'
        :raises ValueError: for any other value. Previously ``self.loss`` was
            silently left unset and the failure only surfaced in ``forward``.
        """
        super(IndL1Loss1d, self).__init__()
        loss_fns = {
            'l1': torch.nn.functional.l1_loss,
            'smooth_l1': torch.nn.functional.smooth_l1_loss,
        }
        if type not in loss_fns:
            raise ValueError(
                "type must be 'l1' or 'smooth_l1', got {!r}".format(type))
        self.loss = loss_fns[type]

    def forward(self, output, target, ind, weight):
        """ind: [b, n]

        ``output`` is a (b, c, h, w) map sampled at the flat locations in
        ``ind``; ``weight`` is [b, n] and is broadcast over the channels.
        ``target`` is presumably [b, n, c] -- confirm with the caller.
        Returns the summed loss divided by ``weight.sum() * c``.
        """
        output = _tranpose_and_gather_feat(output, ind)
        weight = weight.unsqueeze(2)
        loss = self.loss(output * weight, target * weight, reduction='sum')
        loss = loss / (weight.sum() * output.size(2) + 1e-4)
        return loss
|
|
|
|
class GeoCrossEntropyLoss(nn.Module):
    """Cross-entropy with soft targets that decay with distance to the target point."""

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        """Return the kernel-weighted negative log-likelihood.

        ``poly`` (b, n, 2) is split into 4 groups of n // 4 points and
        ``target`` (b, 4) selects one point per group. ``output`` holds
        logits over the n // 4 candidates along dim 1 -- presumably
        (b, n // 4, 4); confirm with the caller.
        """
        prob = torch.nn.functional.softmax(output, dim=1)
        log_prob = torch.log(torch.clamp(prob, min=1e-4))

        batch = poly.size(0)
        poly = poly.view(batch, 4, poly.size(1) // 4, 2)

        # Coordinates of the selected point in each group.
        index = target[..., None, None].expand(batch, poly.size(1), 1, poly.size(3))
        anchor = torch.gather(poly, 2, index)

        # Kernel bandwidth: squared spacing between the first two points of a group.
        first_gap = poly[:, :, 0] - poly[:, :, 1]
        sigma = first_gap.pow(2).sum(2, keepdim=True)

        sq_dist = (poly - anchor).pow(2).sum(3)
        kernel = torch.exp(-sq_dist / (sigma / 3))

        return -(log_prob * kernel.transpose(2, 1)).sum(1).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(net, optim, scheduler, recorder, model_dir, resume=True, epoch=-1):
    """Restore training state from ``<model_dir>/<epoch>.pth``.

    :param net, optim, scheduler, recorder: objects exposing
        ``load_state_dict``; each is restored from the checkpoint key of the
        same name
    :param model_dir: directory holding checkpoints named ``<epoch>.pth``
    :param resume: if False, nothing is loaded
    :param epoch: checkpoint to load; -1 selects the highest epoch found
    :return: the epoch to continue from (stored epoch + 1), or 0 when
        nothing was loaded
    """
    if not resume:
        return 0

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!!', 'red'))
        return 0

    # Only consider "<integer>.pth" files so unrelated files in the directory
    # (logs, temp files, ...) cannot break the integer parsing.
    epochs = []
    for fname in os.listdir(model_dir):
        stem, ext = os.path.splitext(fname)
        if ext == '.pth' and stem.isdigit():
            epochs.append(int(stem))
    if not epochs:
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0

    pth = max(epochs) if epoch == -1 else epoch
    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    pretrained_model = torch.load(model_path)
    net.load_state_dict(pretrained_model['net'])
    optim.load_state_dict(pretrained_model['optim'])
    scheduler.load_state_dict(pretrained_model['scheduler'])
    recorder.load_state_dict(pretrained_model['recorder'])
    return pretrained_model['epoch'] + 1
|
|
|
|
def save_model(net, optim, scheduler, recorder, epoch, model_dir):
    """Write a checkpoint to ``<model_dir>/<epoch>.pth`` and prune old ones.

    The checkpoint stores the ``state_dict()`` of ``net``, ``optim``,
    ``scheduler`` and ``recorder`` plus ``epoch``. When more than 200
    numbered checkpoints exist afterwards, the one with the lowest epoch is
    deleted.
    """
    # Use the os API rather than shell strings so paths containing spaces or
    # shell metacharacters work, and so this also runs without a POSIX shell.
    os.makedirs(model_dir, exist_ok=True)
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(epoch)))

    # Collect (epoch, filename) for "<integer>.pth" files only; other files in
    # the directory are ignored instead of crashing the integer parsing.
    checkpoints = []
    for fname in os.listdir(model_dir):
        stem, ext = os.path.splitext(fname)
        if ext == '.pth' and stem.isdigit():
            checkpoints.append((int(stem), fname))
    if len(checkpoints) <= 200:
        return

    # Remove the oldest checkpoint by its actual file name.
    _, oldest = min(checkpoints)
    os.remove(os.path.join(model_dir, oldest))
|
|
|
|
def load_network(net, model_dir, resume=True, epoch=-1, strict=False):
    """Load only the network weights from ``<model_dir>/<epoch>.pth``.

    :param net: object exposing ``load_state_dict(state, strict=...)``
    :param model_dir: directory holding checkpoints named ``<epoch>.pth``
    :param resume: if False, nothing is loaded
    :param epoch: checkpoint to load; -1 selects the highest epoch found
    :param strict: forwarded to ``net.load_state_dict``
    :return: stored epoch + 1, or 0 when nothing was loaded
    """
    if not resume:
        return 0

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!@!', 'red'))
        return 0

    # Scan the directory that was passed in. The previous code listed
    # cfg.model_dir here, ignoring the ``model_dir`` argument it had just
    # checked for existence.
    epochs = []
    for fname in os.listdir(model_dir):
        stem, ext = os.path.splitext(fname)
        if ext == '.pth' and stem.isdigit():
            epochs.append(int(stem))
    if not epochs:
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0

    pth = max(epochs) if epoch == -1 else epoch
    model_path = os.path.join(model_dir, '{}.pth'.format(pth))
    print('load model: {}'.format(model_path))
    pretrained_model = torch.load(model_path)

    # Checkpoints store the weights under 'state_dict' or under 'net'.
    # Choose by key lookup so a KeyError raised inside load_state_dict is not
    # mistaken for a missing 'state_dict' entry.
    weights_key = 'state_dict' if 'state_dict' in pretrained_model else 'net'
    net.load_state_dict(pretrained_model[weights_key], strict=strict)
    return pretrained_model['epoch'] + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_net_prefix(net, prefix):
    """Return a new OrderedDict with ``prefix`` stripped from keys that start with it.

    Keys without the prefix are copied unchanged; ``net`` is not modified.
    """
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in net.items()
    )
|
|
|
|
def add_net_prefix(net, prefix):
    """Return a new OrderedDict in which every key of ``net`` is prepended with ``prefix``."""
    return OrderedDict((prefix + key, value) for key, value in net.items())
|
|
|
|
def replace_net_prefix(net, orig_prefix, prefix):
    """Return a new OrderedDict with a leading ``orig_prefix`` swapped for ``prefix``.

    Keys that do not start with ``orig_prefix`` are copied unchanged.
    """
    cut = len(orig_prefix)
    renamed = OrderedDict()
    for key, value in net.items():
        new_key = prefix + key[cut:] if key.startswith(orig_prefix) else key
        renamed[new_key] = value
    return renamed
|
|
|
|
def remove_net_layer(net, layers):
    """Delete, in place, every entry of ``net`` whose key starts with one of ``layers``.

    :param net: mutable mapping with string keys (e.g. a state dict)
    :param layers: iterable of key prefixes to drop
    :return: the same ``net`` object, after deletion
    """
    # str.startswith accepts a tuple, so each key is tested once and deleted
    # at most once. The previous nested loop raised KeyError when a key
    # matched more than one prefix (e.g. 'a' and 'ab' against 'abc').
    prefixes = tuple(layers)
    # Snapshot the keys first: the mapping is modified during the loop.
    for key in list(net.keys()):
        if key.startswith(prefixes):
            del net[key]
    return net
|
|