dianecy committed on
Commit 3dcfb26 · verified · 1 Parent(s): 8866e31

Upload folder using huggingface_hub

ASDA/utils/__init__.py ADDED
@@ -0,0 +1 @@
ASDA/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (137 Bytes).
 
ASDA/utils/__pycache__/checkpoint.cpython-39.pyc ADDED
Binary file (3.44 kB).
 
ASDA/utils/__pycache__/logger.cpython-39.pyc ADDED
Binary file (2.75 kB).
 
ASDA/utils/__pycache__/losses.cpython-39.pyc ADDED
Binary file (5.24 kB).
 
ASDA/utils/__pycache__/parsing_metrics.cpython-39.pyc ADDED
Binary file (3.5 kB).
 
ASDA/utils/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (9.34 kB).
 
ASDA/utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.29 kB).
 
ASDA/utils/checkpoint.py ADDED
@@ -0,0 +1,102 @@
import os
import shutil
import torch

def save_checkpoint(state, is_best, args, filename='default'):
    if filename=='default':
        filename = 'mcn_%s_batch%d'%(args.dataset,args.samples_per_gpu)
    print("=> saving checkpoint '{}'".format(filename))
    if not os.path.exists('./saved_models'):
        os.makedirs('./saved_models')
    checkpoint_name = './saved_models/%s_checkpoint.pth.tar'%(filename)
    best_name = './saved_models/%s_model_best.pth.tar'%(filename)
    torch.save(state, checkpoint_name)
    if is_best:
        print("=> saving best model '{}'".format(best_name))
        shutil.copyfile(checkpoint_name, best_name)

def load_pretrain(model, args, logging, rank):
    if os.path.isfile(args.pretrain):
        checkpoint = torch.load(args.pretrain)
        pretrained_dict = checkpoint['state_dict']
        if hasattr(model, 'module'):
            model_dict = model.module.state_dict()
        else:
            model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        assert (len([k for k, v in pretrained_dict.items()])!=0)
        model_dict.update(pretrained_dict)
        if hasattr(model, 'module'):
            model.module.load_state_dict(model_dict)
        else:
            model.load_state_dict(model_dict)
        print("=> loaded pretrain model at {}"
              .format(args.pretrain))
        if rank == 0:
            logging.info("=> loaded pretrain model at {}"
                         .format(args.pretrain))
        del checkpoint  # dereference seems crucial
        torch.cuda.empty_cache()
    else:
        print(("=> no pretrained file found at '{}'".format(args.pretrain)))
        if rank == 0:
            logging.info("=> no pretrained file found at '{}'".format(args.pretrain))
    return model

def load_pretrain_ddp(model, args):
    if os.path.isfile(args.pretrain):
        checkpoint = torch.load(args.pretrain)
        pretrained_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        assert (len([k for k, v in pretrained_dict.items()])!=0)
        model_dict.update(pretrained_dict)
        if hasattr(model, 'module'):
            state_dict = model.module.state_dict()
            model.module.load_state_dict(model_dict)
        else:
            state_dict = model.state_dict()
            model.load_state_dict(model_dict)
        print("load ")
        print("=> loaded pretrain model at {}"
              .format(args.pretrain))
        del checkpoint  # dereference seems crucial
        torch.cuda.empty_cache()
    else:
        print(("=> no pretrained file found at '{}'".format(args.pretrain)))
    return model


def load_resume(model, optimizer, args, logging, rank):
    if os.path.isfile(args.resume):
        print(("=> loading checkpoint '{}'".format(args.resume)))
        if rank == 0:
            logging.info("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        print("epoch: ", args.start_epoch)
        args.best_iou = checkpoint['best_iou']
        print("best iou: ", args.best_iou)
        state_dict = checkpoint['state_dict']

        if hasattr(model, 'module'):
            model_dict = model.module.state_dict()
        else:
            model_dict = model.state_dict()
        new_state_dict = {k:v for k,v in state_dict.items() if k in model_dict}
        model_dict.update(new_state_dict)

        if hasattr(model, 'module'):
            model.module.load_state_dict(model_dict)
        else:
            model.load_state_dict(model_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint  # dereference seems crucial
        torch.cuda.empty_cache()
        print("load successfully!")
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume)))
        if rank == 0:
            logging.info(("=> no checkpoint found at '{}'".format(args.resume)))
    return model
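
For reference, a minimal usage sketch of the checkpoint helpers above (not part of the commit); the `args` fields and the stand-in model/optimizer are assumptions that mirror what `save_checkpoint` and `load_resume` read:

    import argparse
    import torch

    # Hypothetical args namespace; field names follow what the helpers above access.
    args = argparse.Namespace(dataset='refcoco', samples_per_gpu=8,
                              pretrain='', resume='', start_epoch=0, best_iou=0.0)
    model = torch.nn.Linear(4, 2)                     # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    state = {'epoch': 1, 'best_iou': 0.0,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()}
    save_checkpoint(state, is_best=True, args=args)   # writes ./saved_models/mcn_refcoco_batch8_*.pth.tar
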
ASDA/utils/logger.py ADDED
@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
from termcolor import colored

class _ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None
):
    """
    Initialize the detectron2 logger and set its verbosity level to "INFO".

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain_formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d]: %(message)s',
        datefmt='%m/%d %H:%M:%S'
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return open(filename, "a")
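
A short sketch of how `setup_logger` might be called from a training script (the output path and logger name are placeholders, not part of the commit):

    logger = setup_logger(output='./logs', distributed_rank=0, name='asda')
    logger.info('training started')   # rank 0: colored stdout line, plus an append to ./logs/log.txt
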
ASDA/utils/losses.py ADDED
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-

"""
Custom loss function definitions.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils.utils import *

class IoULoss(nn.Module):
    """
    Creates a criterion that computes the Intersection over Union (IoU)
    between a segmentation mask and its ground truth.

    Rahman, M.A. and Wang, Y:
    Optimizing Intersection-Over-Union in Deep Neural Networks for
    Image Segmentation. International Symposium on Visual Computing (2016)
    http://www.cs.umanitoba.ca/~ywang/papers/isvc16.pdf
    """

    def __init__(self, size_average=True):
        super().__init__()
        self.size_average = size_average

    def forward(self, input, target):
        input = F.sigmoid(input)
        intersection = (input * target).sum()
        union = ((input + target) - (input * target)).sum()
        iou = intersection / union
        iou_dual = input.size(0) - iou
        if self.size_average:
            iou_dual = iou_dual / input.size(0)
        return iou_dual


def yolo_loss(input, target, gi, gj, best_n_list, w_coord=5.):
    mseloss = torch.nn.MSELoss(size_average=True)
    celoss = torch.nn.CrossEntropyLoss(size_average=True)
    batch = input.size(0)

    pred_bbox = Variable(torch.zeros(batch,4).cuda())
    gt_bbox = Variable(torch.zeros(batch,4).cuda())
    for ii in range(batch):
        pred_bbox[ii, 0:2] = F.sigmoid(input[ii,best_n_list[ii],0:2,gj[ii],gi[ii]])
        pred_bbox[ii, 2:4] = input[ii,best_n_list[ii],2:4,gj[ii],gi[ii]]
        gt_bbox[ii, :] = target[ii,best_n_list[ii],:4,gj[ii],gi[ii]]
    loss_x = mseloss(pred_bbox[:,0], gt_bbox[:,0])
    loss_y = mseloss(pred_bbox[:,1], gt_bbox[:,1])
    loss_w = mseloss(pred_bbox[:,2], gt_bbox[:,2])
    loss_h = mseloss(pred_bbox[:,3], gt_bbox[:,3])

    pred_conf_list, gt_conf_list = [], []
    pred_conf_list.append(input[:,:,4,:,:].contiguous().view(batch,-1))
    gt_conf_list.append(target[:,:,4,:,:].contiguous().view(batch,-1))
    pred_conf = torch.cat(pred_conf_list, dim=1)
    gt_conf = torch.cat(gt_conf_list, dim=1)
    loss_conf = celoss(pred_conf, gt_conf.max(1)[1])
    return (loss_x+loss_y+loss_w+loss_h)*w_coord + loss_conf

def build_target(raw_coord, anchors, args):
    coord = Variable(torch.zeros(raw_coord.size(0), raw_coord.size(1)).cuda())
    batch, grid = raw_coord.size(0), args.size//args.gsize
    coord[:,0] = (raw_coord[:,0] + raw_coord[:,2])/(2*args.size)  # x, normalized to the original image size
    coord[:,1] = (raw_coord[:,1] + raw_coord[:,3])/(2*args.size)  # y
    coord[:,2] = (raw_coord[:,2] - raw_coord[:,0])/(args.size)  # w
    coord[:,3] = (raw_coord[:,3] - raw_coord[:,1])/(args.size)  # h
    coord = coord * grid
    bbox=torch.zeros(coord.size(0),len(anchors),5,grid,grid)

    best_n_list, best_gi, best_gj = [],[],[]

    for ii in range(batch):
        gi = coord[ii,0].long()
        gj = coord[ii,1].long()
        tx = coord[ii,0] - gi.float()
        ty = coord[ii,1] - gj.float()
        gw = coord[ii,2]
        gh = coord[ii,3]

        scaled_anchors = [ (x[0] / (args.anchor_imsize/grid), \
            x[1] / (args.anchor_imsize/grid)) for x in anchors]

        ## Get shape of gt box
        gt_box = torch.FloatTensor(np.array([0, 0, gw, gh],dtype=np.float32)).unsqueeze(0)  #[1,4]
        ## Get shape of anchor box
        anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((len(scaled_anchors), 2)), np.array(scaled_anchors)), 1))
        ## Calculate iou between gt and anchor shapes
        anch_ious = list(bbox_iou(gt_box, anchor_shapes, x1y1x2y2=False))
        ## Find the best matching anchor box
        best_n = np.argmax(np.array(anch_ious))

        tw = torch.log(gw / scaled_anchors[best_n][0] + 1e-16)
        th = torch.log(gh / scaled_anchors[best_n][1] + 1e-16)

        bbox[ii, best_n, :, gj, gi] = torch.stack([tx, ty, tw, th, torch.ones(1).cuda().squeeze()])
        best_n_list.append(int(best_n))
        best_gi.append(gi)
        best_gj.append(gj)
    bbox = Variable(bbox.cuda())
    return bbox, best_gi, best_gj, best_n_list

def adjust_learning_rate(args, optimizer, i_iter):
    # print(optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'])
    if i_iter in args.steps:
        #lr = args.lr * args.power
        lr = args.lr * args.power ** (args.steps.index(i_iter) + 1)
        optimizer.param_groups[0]['lr'] = lr
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr / 10
        if len(optimizer.param_groups) > 2:
            optimizer.param_groups[2]['lr'] = lr / 10

def cem_loss(co_energy):
    loss = -1.0 * torch.log(co_energy+1e-6).sum()
    return loss

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, logits=True, reduce=False):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return torch.sum(F_loss)
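
A quick sanity-check sketch of `FocalLoss` on random logits (shapes are illustrative, not taken from the training code):

    import torch
    criterion = FocalLoss(alpha=0.25, gamma=2, logits=True, reduce=True)  # reduce=True returns the mean
    logits  = torch.randn(4, 1, 32, 32)
    targets = torch.randint(0, 2, (4, 1, 32, 32)).float()
    print(criterion(logits, targets).item())
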
ASDA/utils/misc_utils.py ADDED
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-

"""
Misc download and visualization helper functions and class wrappers.
"""

import sys
import time
import torch
from visdom import Visdom


def reporthook(count, block_size, total_size):
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    speed = int(progress_size / (1024 * duration))
    percent = min(int(count * block_size * 100 / total_size), 100)
    sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()


class VisdomWrapper(Visdom):
    def __init__(self, *args, env=None, **kwargs):
        Visdom.__init__(self, *args, **kwargs)
        self.env = env
        self.plots = {}

    def init_line_plot(self, name,
                       X=torch.zeros((1,)).cpu(),
                       Y=torch.zeros((1,)).cpu(), **opts):
        self.plots[name] = self.line(X=X, Y=Y, env=self.env, opts=opts)

    def plot_line(self, name, **kwargs):
        self.line(win=self.plots[name], env=self.env, **kwargs)
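
`reporthook` follows the urllib download-callback signature; a hedged sketch of how it might be wired up (the URL and filename are placeholders):

    from urllib.request import urlretrieve
    urlretrieve('https://example.com/weights.pth', 'weights.pth', reporthook=reporthook)
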
ASDA/utils/parsing_metrics.py ADDED
@@ -0,0 +1,106 @@
import torch
import numpy as np

def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist

def label_accuracy_score(label_trues, label_preds, n_class, bg_thre=200):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
      - fwavacc
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
        hist += _fast_hist(lt[lt<bg_thre].flatten(), lp[lt<bg_thre].flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc

def label_confusion_matrix(label_trues, label_preds, n_class, bg_thre=200):
    # eps=1e-20
    hist=np.zeros((n_class,n_class),dtype=float)
    """ (8,256,256), (256,256) """
    for lt,lp in zip(label_trues, label_preds):
        # hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
        hist += _fast_hist(lt[lt<bg_thre].flatten(), lp[lt<bg_thre].flatten(), n_class)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    # for i in range(n_class):
    #     hist[i,:]=(hist[i,:]+eps)/sum(hist[i,:]+eps)
    return hist, iu

def body_region_confusion_matrix(label_trues, label_preds, n_class, boxes, counter):
    ## pred: [bb,region_index,c,h,w] (pred score)
    ## gt: [bb,region_index,h,w] (0-nclass score)
    label_trues = label_trues.data.cpu().numpy()
    label_preds = label_preds.data.cpu().numpy()
    hist=np.zeros((label_trues.shape[1],n_class,n_class),dtype=float)
    for body_i in range(label_trues.shape[1]):
        for bb in range(label_trues.shape[0]):
            if body_i != label_trues.shape[1]-1 and \
                torch.equal(boxes[bb,body_i,:], torch.Tensor([0.,0.,1.,1.])):
                counter+=1
                continue
            else:
                hist[body_i,:,:] += label_confusion_matrix(label_trues[bb,body_i,:,:], \
                    np.argmax(label_preds[bb,body_i,:,:,:], axis=0), n_class)[0]
    return hist

def hist_based_accu_cal(hist):
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc, iu

def cal_seg_iou_loss(gt,pred,trsh=0.5):
    t=np.array(pred>trsh)
    p=np.array(gt>0.)
    intersection = np.logical_and(t, p)
    union = np.logical_or(t, p)
    iou = (np.sum(intersection > 0 , axis=(2,3)) + 1e-10 )/ (np.sum(union > 0, axis=(2,3)) + 1e-10)
    return iou

def cal_seg_iou(gt,pred,trsh=0.5):
    #(gt.shape) [1 428 640]
    #(pred.shape) [428 640]
    t=np.array(pred>trsh)
    p=np.array(gt>0.)
    intersection = np.logical_and(t, p)
    union = np.logical_or(t, p)
    iou = (np.sum(intersection > 0) + 1e-10 )/ (np.sum(union > 0) + 1e-10)

    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)
    for thresh in thresholds:
        prec[thresh]= float(iou > thresh)
    return iou,prec

def cal_seg_iou2(gt,pred,trsh=0.5):
    #(gt.shape) [1 428 640]
    #(pred.shape) [428 640]
    t=np.array(pred>trsh)
    p=np.array(gt>0.)
    intersection = np.logical_and(t, p)
    union = np.logical_or(t, p)
    iou = (np.sum(intersection > 0) + 1e-10 )/ (np.sum(union > 0) + 1e-10)

    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)
    for thresh in thresholds:
        prec[thresh]= float(iou > thresh)
    return iou, prec, np.sum(intersection > 0), np.sum(union > 0)
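
A tiny numeric check of `cal_seg_iou` on a toy mask pair (the arrays are illustrative only):

    import numpy as np
    gt   = np.zeros((1, 8, 8)); gt[:, 2:6, 2:6] = 1.0   # 4x4 ground-truth square
    pred = np.zeros((8, 8));    pred[3:7, 3:7] = 0.9    # overlapping 4x4 prediction
    iou, prec = cal_seg_iou(gt, pred, trsh=0.5)
    print(round(float(iou), 3), prec[0.5])              # ~0.391, and 0.0 since it fails the 0.5 threshold
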
ASDA/utils/transforms.py ADDED
@@ -0,0 +1,374 @@
# -*- coding: utf-8 -*-

"""
Generic Image Transform utilities.
"""

import cv2
import random, math
import numpy as np
from collections.abc import Iterable
from torch import rand

import torch.nn.functional as F
from torch.autograd import Variable


class ResizePad:
    """
    Resize and pad an image to given size.
    """

    def __init__(self, size):
        if not isinstance(size, (int, Iterable)):
            raise TypeError('Got inappropriate size arg: {}'.format(size))

        self.h, self.w = size

    def __call__(self, img):
        h, w = img.shape[:2]
        scale = min(self.h / h, self.w / w)
        resized_h = int(np.round(h * scale))
        resized_w = int(np.round(w * scale))
        pad_h = int(np.floor(self.h - resized_h) / 2)
        pad_w = int(np.floor(self.w - resized_w) / 2)

        resized_img = cv2.resize(img, (resized_w, resized_h))

        # if img.ndim > 2:
        if img.ndim > 2:
            new_img = np.zeros(
                (self.h, self.w, img.shape[-1]), dtype=resized_img.dtype)
        else:
            resized_img = np.expand_dims(resized_img, -1)
            new_img = np.zeros((self.h, self.w, 1), dtype=resized_img.dtype)
        new_img[pad_h: pad_h + resized_h,
                pad_w: pad_w + resized_w, ...] = resized_img
        return new_img


class CropResize:
    """Remove padding and resize image to its original size."""

    def __call__(self, img, size):
        if not isinstance(size, (int, Iterable)):
            raise TypeError('Got inappropriate size arg: {}'.format(size))
        im_h, im_w = img.data.shape[:2]
        input_h, input_w = size
        scale = max(input_h / im_h, input_w / im_w)
        # scale = torch.Tensor([[input_h / im_h, input_w / im_w]]).max()
        resized_h = int(np.round(im_h * scale))
        # resized_h = torch.round(im_h * scale)
        resized_w = int(np.round(im_w * scale))
        # resized_w = torch.round(im_w * scale)
        crop_h = int(np.floor(resized_h - input_h) / 2)
        # crop_h = torch.floor(resized_h - input_h) // 2
        crop_w = int(np.floor(resized_w - input_w) / 2)
        # crop_w = torch.floor(resized_w - input_w) // 2
        # resized_img = cv2.resize(img, (resized_w, resized_h))
        resized_img = F.upsample(
            img.unsqueeze(0).unsqueeze(0), size=(resized_h, resized_w),
            mode='bilinear')

        resized_img = resized_img.squeeze().unsqueeze(0)

        return resized_img[0, crop_h: crop_h + input_h,
                           crop_w: crop_w + input_w]


class ResizeImage:
    """Resize the largest of the sides of the image to a given size"""
    def __init__(self, size):
        if not isinstance(size, (int, Iterable)):
            raise TypeError('Got inappropriate size arg: {}'.format(size))

        self.size = size

    def __call__(self, img):
        im_h, im_w = img.shape[-2:]
        scale = min(self.size / im_h, self.size / im_w)
        resized_h = int(np.round(im_h * scale))
        resized_w = int(np.round(im_w * scale))
        out = F.upsample(
            Variable(img).unsqueeze(0), size=(resized_h, resized_w),
            mode='bilinear').squeeze().data
        return out


class ResizeAnnotation:
    """Resize the largest of the sides of the annotation to a given size"""
    def __init__(self, size):
        if not isinstance(size, (int, Iterable)):
            raise TypeError('Got inappropriate size arg: {}'.format(size))

        self.size = size

    def __call__(self, img):
        im_h, im_w = img.shape[-2:]
        scale = min(self.size / im_h, self.size / im_w)
        resized_h = int(np.round(im_h * scale))
        resized_w = int(np.round(im_w * scale))
        out = F.upsample(
            Variable(img).unsqueeze(0).unsqueeze(0),
            size=(resized_h, resized_w),
            mode='bilinear').squeeze().data
        return out


class ToNumpy:
    """Transform an torch.*Tensor to an numpy ndarray."""

    def __call__(self, x):
        return x.numpy()

def letterbox(img, mask, height, color=(123.7, 116.3, 103.5)):  # resize a rectangular image to a padded square
    shape = img.shape[:2]  # shape = [height, width]
    ratio = float(height) / max(shape)  # ratio = old / new
    new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))
    dw = (height - new_shape[0]) / 2  # width padding
    dh = (height - new_shape[1]) / 2  # height padding
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # padded square
    if mask is not None:
        mask = cv2.resize(mask, new_shape, interpolation=cv2.INTER_NEAREST)  # resized, no border
        mask = cv2.copyMakeBorder(mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)  # padded square
    return img, mask, ratio, dw, dh


def random_affine(img, mask, targets, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
                  borderValue=(123.7, 116.3, 103.5), all_bbox=None):
    border = 0  # width of added border (optional)
    height = max(img.shape[0], img.shape[1]) + border * 2

    # Rotation and Scale
    R = np.eye(3)
    a = random.random() * (degrees[1] - degrees[0]) + degrees[0]
    # a += random.choice([-180, -90, 0, 90])  # 90deg rotations added to small rotations
    s = random.random() * (scale[1] - scale[0]) + scale[0]
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)

    # Translation
    T = np.eye(3)
    T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border  # x translation (pixels)
    T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border  # y translation (pixels)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)  # y shear (deg)

    M = S @ T @ R  # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
    imw = cv2.warpPerspective(img, M, dsize=(height, height), flags=cv2.INTER_LINEAR,
                              borderValue=borderValue)  # BGR order borderValue
    if mask is not None:
        maskw = cv2.warpPerspective(mask, M, dsize=(height, height), flags=cv2.INTER_NEAREST,
                                    borderValue=0)  # BGR order borderValue
    else:
        maskw = None

    # Return warped points also
    if type(targets)==type([1]):
        targetlist=[]
        for bbox in targets:
            targetlist.append(wrap_points(bbox, M, height, a))
        return imw, maskw, targetlist, M
    elif all_bbox is not None:
        targets = wrap_points(targets, M, height, a)
        for ii in range(all_bbox.shape[0]):
            all_bbox[ii,:] = wrap_points(all_bbox[ii,:], M, height, a)
        return imw, maskw, targets, all_bbox, M
    elif targets is not None:  ## previous main
        targets = wrap_points(targets, M, height, a)
        return imw, maskw, targets, M
    else:
        return imw

def wrap_points(targets, M, height, a):
    # n = targets.shape[0]
    # points = targets[:, 1:5].copy()
    points = targets.copy()
    # area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
    area0 = (points[2] - points[0]) * (points[3] - points[1])

    # warp points
    xy = np.ones((4, 3))
    xy[:, :2] = points[[0, 1, 2, 3, 0, 3, 2, 1]].reshape(4, 2)  # x1y1, x2y2, x1y2, x2y1
    xy = (xy @ M.T)[:, :2].reshape(1, 8)

    # create new boxes
    x = xy[:, [0, 2, 4, 6]]
    y = xy[:, [1, 3, 5, 7]]
    xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, 1).T

    # apply angle-based reduction
    radians = a * math.pi / 180
    reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
    x = (xy[:, 2] + xy[:, 0]) / 2
    y = (xy[:, 3] + xy[:, 1]) / 2
    w = (xy[:, 2] - xy[:, 0]) * reduction
    h = (xy[:, 3] - xy[:, 1]) * reduction
    xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, 1).T

    # reject warped points outside of image
    np.clip(xy, 0, height, out=xy)
    w = xy[:, 2] - xy[:, 0]
    h = xy[:, 3] - xy[:, 1]
    area = w * h
    ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
    i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)

    ## print(targets, xy)
    ## [ 56 36 108 210] [[ 47.80464857 15.6096533 106.30993434 196.71267693]]
    # targets = targets[i]
    # targets[:, 1:5] = xy[i]
    targets = xy[0]
    return targets


def random_crop(img, seg, pad, h, w):
    if random.random() < 0.5:
        return img, seg

    img = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(123.7, 116.3, 103.5))
    seg = cv2.copyMakeBorder(seg, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(0, 0, 0))

    Left = random.randint(0, pad * 2)
    Top = random.randint(0, pad * 2)

    seg_pixel = seg.sum()

    for _ in range(100):
        if seg[Top: Top + h, Left: Left + w].sum() / seg_pixel > 0.95 and seg[Top: Top + h, Left: Left + w].sum() > 0:
            img = img[Top: Top + h, Left: Left + w, :]
            seg = seg[Top: Top + h, Left: Left + w]

            return img, seg

        Left = random.randint(0, pad * 2)
        Top = random.randint(0, pad * 2)

    return img, seg


def random_copy(img, seg, phrase, bbox):
    if 'left' in phrase or 'right' in phrase or \
        'center' in phrase or 'middle' in phrase or \
        'front' in phrase or 'back' in phrase:
        return img, seg, phrase, bbox

    if random.random() < 0.75:
        return img, seg, phrase, bbox

    h, w = img.shape[0], img.shape[1]

    # x1, y1, x2, y2 = w, h, 0, 0
    # for j in range(h):
    #     for i in range(w):
    #         if seg[j, i] > 0:
    #             if i < x1: x1 = i
    #             if j < y1: y1 = j
    #             if i > x2: x2 = i
    #             if j > y2: y2 = j
    # x2 = x2 + 1
    # y2 = y2 + 1

    # contours, hierarchy = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # c = max(contours, key = cv2.contourArea)
    x, y, bboxw, bboxh = cv2.boundingRect(seg.astype(np.uint8))
    x1 = x
    y1 = y
    x2 = x + bboxw
    y2 = y + bboxh

    if x1 - (x2 - x1) < 0 or w - (x2 - x1) < x2:
        return img, seg, phrase, bbox

    # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # color_mask = np.array([0, 255, 0], dtype=np.uint8)
    # mask = seg.astype(np.bool)
    # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5
    # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp)

    if random.random() < 0.5:
        new_x1 = random.randint(0, x1 - (x2 - x1))
        phrase += ' on left'
    else:
        new_x1 = random.randint(x2, w - (x2 - x1))
        phrase += ' on right'

    new_x2 = new_x1 + (x2 - x1)

    delta_y = random.randint((y1 - y2), y2 - y1)

    while y2 + delta_y > h or y1 + delta_y < 0:
        delta_y = random.randint((y1 - y2), y2 - y1)

    new_y1 = y1 + delta_y
    new_y2 = y2 + delta_y

    new_seg = np.zeros_like(seg)
    new_seg[new_y1: new_y2, new_x1: new_x2] = seg[y1: y2, x1: x2]

    # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # color_mask = np.array([0, 255, 0], dtype=np.uint8)
    # mask = new_seg.astype(np.bool)
    # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5
    # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp)

    img[new_seg.astype(np.bool)] = img[seg.astype(np.bool)]
    # bbox = [new_x1, new_y1, new_x2 - 1, new_y2 - 1]
    seg = new_seg

    # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # color_mask = np.array([0, 255, 0], dtype=np.uint8)
    # mask = seg.astype(np.bool)
    # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5
    # cv2.imwrite('./{}.png'.format(phrase.replace(' ', '_')), tmp)

    # exit()

    return img, seg, phrase, bbox


def random_erase(img, seg):
    if random.random() < 0.5:
        return img, seg

    x, y, bboxw, bboxh = cv2.boundingRect(seg.astype(np.uint8))

    area = bboxw * bboxh * 0.5

    for attempt in range(100):
        target_area = random.uniform(0.02, 0.4)
        aspect_ratio = random.uniform(0.3, 1/0.3)

        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))

        if w < bboxw and h < bboxh:
            x1 = random.randint(0, bboxw - w)
            y1 = random.randint(0, bboxh - h)

            new_seg = seg.copy()
            new_seg[y+y1: y+y1+h, x+x1: x+x1+w] = 0

            if new_seg.sum() / seg.sum() > 0.75:
                continue

            seg[y+y1: y+y1+h, x+x1: x+x1+w] = 0

            img[y+y1: y+y1+h, x+x1: x+x1+w, 0] = 123.7
            img[y+y1: y+y1+h, x+x1: x+x1+w, 1] = 116.3
            img[y+y1: y+y1+h, x+x1: x+x1+w, 2] = 103.5

            # tmp = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # color_mask = np.array([0, 255, 0], dtype=np.uint8)
            # mask = seg.astype(np.bool)
            # tmp[mask] = tmp[mask] * 0.5 + color_mask * 0.5
            # cv2.imwrite('./erase.png', tmp)

            return img, seg

    return img, seg
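
A short sketch of `letterbox` on a dummy image (the input size and target height are illustrative, not taken from the training config):

    import numpy as np
    img  = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    mask = np.zeros((480, 640), dtype=np.uint8)
    img_lb, mask_lb, ratio, dw, dh = letterbox(img, mask, height=416)
    print(img_lb.shape, mask_lb.shape)   # both padded to 416x416 with the gray fill value
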
ASDA/utils/utils.py ADDED
@@ -0,0 +1,262 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.optim import Optimizer

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def xyxy2xywh(x):  # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
    y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2
    y[:, 2] = x[:, 2] - x[:, 0]
    y[:, 3] = x[:, 3] - x[:, 1]
    return y


def xywh2xyxy(x):  # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]
    y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape)
    y[:, 0] = (x[:, 0] - x[:, 2] / 2)
    y[:, 1] = (x[:, 1] - x[:, 3] / 2)
    y[:, 2] = (x[:, 0] + x[:, 2] / 2)
    y[:, 3] = (x[:, 1] + x[:, 3] / 2)
    return y

def bbox_iou_numpy(box1, box2):
    """Computes IoU between bounding boxes.
    Parameters
    ----------
    box1 : ndarray
        (N, 4) shaped array with bboxes
    box2 : ndarray
        (M, 4) shaped array with bboxes
    Returns
    -------
    : ndarray
        (N, M) shaped array with IoUs
    """
    area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    iw = np.minimum(np.expand_dims(box1[:, 2], axis=1), box2[:, 2]) - np.maximum(
        np.expand_dims(box1[:, 0], 1), box2[:, 0]
    )
    ih = np.minimum(np.expand_dims(box1[:, 3], axis=1), box2[:, 3]) - np.maximum(
        np.expand_dims(box1[:, 1], 1), box2[:, 1]
    )

    iw = np.maximum(iw, 0)
    ih = np.maximum(ih, 0)

    ua = np.expand_dims((box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]), axis=1) + area - iw * ih

    ua = np.maximum(ua, np.finfo(float).eps)

    intersection = iw * ih

    return intersection / ua


def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    Returns the IoU of two bounding boxes
    """
    if x1y1x2y2:
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    else:
        # Transform from center and width to exact coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2

    # get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)
    # Intersection area
    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, 0) * torch.clamp(inter_rect_y2 - inter_rect_y1, 0)
    # Union Area
    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

    # print(box1, box1.shape)
    # print(box2, box2.shape)
    return inter_area / (b1_area + b2_area - inter_area + 1e-16)

def multiclass_metrics(pred, gt):
    """
    check precision and recall for predictions.
    Output: overall = {precision, recall, f1}
    """
    eps=1e-6
    overall = {'precision': -1, 'recall': -1, 'f1': -1}
    NP, NR, NC = 0, 0, 0  # num of pred, num of recall, num of correct
    for ii in range(pred.shape[0]):
        pred_ind = np.array(pred[ii]>0.5, dtype=int)
        gt_ind = np.array(gt[ii]>0.5, dtype=int)
        inter = pred_ind * gt_ind
        # add to overall
        NC += np.sum(inter)
        NP += np.sum(pred_ind)
        NR += np.sum(gt_ind)
    if NP > 0:
        overall['precision'] = float(NC)/NP
    if NR > 0:
        overall['recall'] = float(NC)/NR
    if NP > 0 and NR > 0:
        overall['f1'] = 2*overall['precision']*overall['recall']/(overall['precision']+overall['recall']+eps)
    return overall

def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Code originally from https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall: The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))

    # compute the precision envelope
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

def concat_coord(x):
    ins_feat = x  # [bt, c, h, w] [512, 26, 26]
    batch_size, c, h, w = x.size()

    float_h = float(h)
    float_w = float(w)

    y_range = torch.arange(0., float_h, dtype=torch.float32)  # [h, ]
    y_range = 2.0 * y_range / (float_h - 1.0) - 1.0
    x_range = torch.arange(0., float_w, dtype=torch.float32)  # [w, ]
    x_range = 2.0 * x_range / (float_w - 1.0) - 1.0
    x_range = x_range[None, :]  # [1, w]
    y_range = y_range[:, None]  # [h, 1]
    x = x_range.repeat(h, 1)  # [h, w]
    y = y_range.repeat(1, w)  # [h, w]

    x = x[None, None, :, :]  # [1, 1, h, w]
    y = y[None, None, :, :]  # [1, 1, h, w]
    x = x.repeat(batch_size, 1, 1, 1)  # [N, 1, h, w]
    y = y.repeat(batch_size, 1, 1, 1)  # [N, 1, h, w]
    x = x.cuda()
    y = y.cuda()

    ins_feat_out = torch.cat((ins_feat, x, x, x, y, y, y), 1)  # [N, c+6, h, w]

    return ins_feat_out


def get_cosine_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
                                    num_cycles: float = 0.5, last_epoch: int = -1):
    """
    Implementation by Huggingface:
    https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py

    Create a schedule with a learning rate that decreases following the values
    of the cosine function between the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the defaults is to just
            decrease from the max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return max(1e-6, float(current_step) / float(max(1, num_warmup_steps)))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

def dice_loss(inputs, targets):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """

    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.mean()

def sigmoid_focal_loss(inputs, targets, alpha: float = -1, gamma: float = 0):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()
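
A small sketch of `get_cosine_schedule_with_warmup` on a throwaway optimizer (the step counts are illustrative):

    import torch
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=100)
    for _ in range(100):
        opt.step()
        sched.step()
    print(opt.param_groups[0]['lr'])   # decays to 0.0 after the full cosine schedule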