P-rateek commited on
Commit
8d15a8b
·
verified ·
1 Parent(s): ff760d2

Upload 58 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. modelsforCIML/Auto_Annotate_SDG.py +113 -0
  3. modelsforCIML/Auto_Annotate_SPG.py +248 -0
  4. modelsforCIML/CAAA_OK.png +3 -0
  5. modelsforCIML/Readme.md +50 -0
  6. modelsforCIML/classify_convxl.py +168 -0
  7. modelsforCIML/convbuper.py +209 -0
  8. modelsforCIML/convnext.py +165 -0
  9. modelsforCIML/dass.py +289 -0
  10. modelsforCIML/mmseg/__init__.py +62 -0
  11. modelsforCIML/mmseg/core/__init__.py +12 -0
  12. modelsforCIML/mmseg/core/builder.py +33 -0
  13. modelsforCIML/mmseg/core/evaluation/__init__.py +11 -0
  14. modelsforCIML/mmseg/core/evaluation/class_names.py +327 -0
  15. modelsforCIML/mmseg/core/evaluation/eval_hooks.py +132 -0
  16. modelsforCIML/mmseg/core/evaluation/metrics.py +396 -0
  17. modelsforCIML/mmseg/core/hook/__init__.py +4 -0
  18. modelsforCIML/mmseg/core/hook/wandblogger_hook.py +370 -0
  19. modelsforCIML/mmseg/core/optimizers/__init__.py +7 -0
  20. modelsforCIML/mmseg/core/optimizers/layer_decay_optimizer_constructor.py +211 -0
  21. modelsforCIML/mmseg/core/seg/__init__.py +5 -0
  22. modelsforCIML/mmseg/core/seg/builder.py +9 -0
  23. modelsforCIML/mmseg/core/seg/sampler/__init__.py +5 -0
  24. modelsforCIML/mmseg/core/seg/sampler/base_pixel_sampler.py +13 -0
  25. modelsforCIML/mmseg/core/seg/sampler/ohem_pixel_sampler.py +85 -0
  26. modelsforCIML/mmseg/core/utils/__init__.py +5 -0
  27. modelsforCIML/mmseg/core/utils/dist_util.py +46 -0
  28. modelsforCIML/mmseg/core/utils/misc.py +18 -0
  29. modelsforCIML/mmseg/models/__init__.py +10 -0
  30. modelsforCIML/mmseg/models/builder.py +49 -0
  31. modelsforCIML/mmseg/models/decode_heads/__init__.py +9 -0
  32. modelsforCIML/mmseg/models/decode_heads/aspp_head.py +122 -0
  33. modelsforCIML/mmseg/models/decode_heads/decode_head.py +295 -0
  34. modelsforCIML/mmseg/models/decode_heads/fcn_head.py +88 -0
  35. modelsforCIML/mmseg/models/decode_heads/psp_head.py +117 -0
  36. modelsforCIML/mmseg/models/decode_heads/sep_aspp_head.py +105 -0
  37. modelsforCIML/mmseg/models/decode_heads/uper_head.py +128 -0
  38. modelsforCIML/mmseg/models/decode_heads/uper_lab.py +120 -0
  39. modelsforCIML/mmseg/models/losses/__init__.py +16 -0
  40. modelsforCIML/mmseg/models/losses/accuracy.py +92 -0
  41. modelsforCIML/mmseg/models/losses/cross_entropy_loss.py +296 -0
  42. modelsforCIML/mmseg/models/losses/dice_loss.py +137 -0
  43. modelsforCIML/mmseg/models/losses/focal_loss.py +327 -0
  44. modelsforCIML/mmseg/models/losses/lovasz_loss.py +323 -0
  45. modelsforCIML/mmseg/models/losses/tversky_loss.py +137 -0
  46. modelsforCIML/mmseg/models/losses/utils.py +126 -0
  47. modelsforCIML/mmseg/ops/__init__.py +5 -0
  48. modelsforCIML/mmseg/ops/encoding.py +75 -0
  49. modelsforCIML/mmseg/ops/wrappers.py +51 -0
  50. modelsforCIML/mmseg/utils/__init__.py +11 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models[[:space:]]for[[:space:]]CIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models[[:space:]]for[[:space:]]CIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
37
+ modelsforCIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
modelsforCIML/Auto_Annotate_SDG.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python2
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @author: liuyaqi
5
+ """
6
+ import os
7
+ import cv2
8
+ import random
9
+ import torch
10
+ import torchvision
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from torch.nn import functional as F
14
+ import numpy as np
15
+ import time
16
+ import logging
17
+ import argparse
18
+ from PIL import Image
19
+ from tqdm import tqdm
20
+ import albumentations as A
21
+ import torch.distributed as dist
22
+ from albumentations.pytorch import ToTensorV2
23
+ from torch.utils.data import Dataset, DataLoader
24
+ import safm_convb as safm
25
# Command-line options for SDG auto-annotation.
# NOTE(review): several flags (--nm, --epoch, --thres, --numw, --batch_size,
# --input_scale, --local_rank) are parsed but not referenced later in this
# inference-only script — presumably kept from the training version; confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--nm', type=str, default='ori')
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--pth', type=str, default='SAFM.pth')  # checkpoint to load
parser.add_argument('--thres', type=float, default=0.5)
parser.add_argument('--numw', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--input_scale', type=int, default=512)
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
args = parser.parse_args()
35
+
36
class CVPR24EvalDataset(Dataset):
    """Paired-image evaluation dataset for SDG data.

    Each sample directory holds an authentic image ('0.jpg') and its
    manipulated counterpart ('1.jpg'); both are returned as normalized
    tensors together with the sample directory's name.
    """

    def __init__(self, roots, img_dir, sz=512, fan=False):
        self.fan = fan  # stored but never read in this class
        self.roots = os.path.join(roots, img_dir)
        '''
        Dir structure in self.roots:
        |
        self.roots
        |
        |---dir1
        |     |----0.jpg (SDG authentic image)
        |     |----1.jpg (SDG manipulated image)
        |
        |---dir2
        |     |----0.jpg (SDG authentic image)
        |     |----1.jpg (SDG manipulated image)
        |
        .........
        '''
        # One sample per subdirectory; sorted for a deterministic order.
        self.indexs = [os.path.join(self.roots, x) for x in os.listdir(self.roots)]
        self.indexs.sort()
        self.lens = len(self.indexs)
        self.tsr = ToTensorV2()
        self.lbl = torch.FloatTensor([1])  # constant label; not returned by __getitem__
        self.rsz = torchvision.transforms.Compose([torchvision.transforms.Resize((sz,sz))])  # unused
        # Resize then normalize. NOTE(review): mean 0.455 differs from the usual
        # ImageNet 0.456 — possibly a typo, but it must match how the model was
        # trained; confirm against the training code before changing.
        self.toctsr = torchvision.transforms.Compose([torchvision.transforms.Resize((sz, sz)), torchvision.transforms.Normalize(mean=((0.485, 0.455, 0.406)), std=((0.229, 0.224, 0.225)))])

    def __len__(self):
        # Number of sample directories.
        return self.lens

    def __getitem__(self, idx):
        this_r = self.indexs[idx]
        # OpenCV reads BGR; convert to RGB, scale to [0, 1], then resize+normalize.
        img1 = self.toctsr(self.tsr(image=cv2.cvtColor(cv2.imread(os.path.join(this_r, '0.jpg')), cv2.COLOR_BGR2RGB))['image'].float()/255.0)
        img2 = self.toctsr(self.tsr(image=cv2.cvtColor(cv2.imread(os.path.join(this_r, '1.jpg')), cv2.COLOR_BGR2RGB))['image'].float()/255.0)
        # (authentic tensor, manipulated tensor, sample directory name)
        return (img1, img2, this_r.split('/')[-1])
71
+
72
# Build the SDG evaluation loader (batch size 1 keeps filename handling simple).
test_data = CVPR24EvalDataset('./', 'SDG')
test_loader = DataLoader(dataset=test_data, batch_size=1, num_workers=4)

# SAFM model restored from the checkpoint given by --pth, run on GPU in eval mode.
model = safm.SAFM(2, 512)
model = model.cuda()
model = nn.DataParallel(model)
loader = torch.load(args.pth, map_location='cpu')
model.load_state_dict(loader)
model.eval()

# Output directory for the predicted masks.
if not os.path.exists('SDG_preds'):
    os.makedirs('SDG_preds')

with torch.no_grad():
    # NOTE(review): these accumulators are never filled or read — dead code.
    ious = []
    ps = []
    rs = []
    fs = []
    for (im1, im2, fnm) in tqdm(test_loader):
        im1 = im1.cuda()
        im2 = im2.cuda()
        # Test-time augmentation on the manipulated image: original, flip along
        # dim 2 (H) and flip along dim 3 (W); flipped predictions are flipped
        # back before fusing.
        _, pred, _, _ = model(im1, im2)
        _, pred2, _, _ = model(im1, torch.flip(im2, [2]))
        pred2 = torch.flip(pred2, [2])

        _, pred3, _, _ = model(im1, torch.flip(im2, [3]))
        pred3 = torch.flip(pred3, [3])

        # Sum the three logit maps, softmax, keep the "manipulated" channel.
        preds = F.softmax((pred+pred2+pred3), dim=1)[:,1:2].squeeze().cpu().numpy()
        # Confidence filter: keep the mask only if at least half of the weakly
        # positive pixels (>1/16) are also strongly positive (>15/16).
        s1 = (preds>(1/16)).sum()
        s2 = (preds>(15/16)).sum()
        if (s2/(s1+1e-6)>0.5):
            cv2.imwrite('SDG_preds/'+fnm[0]+'.png', (preds*255).astype(np.uint8))
107
+
108
+
109
+
110
+
111
+
112
+
113
+
modelsforCIML/Auto_Annotate_SPG.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import math
4
+ import torch#用户ID:7fb702cd-1293-4470-a3b2-4ba88c3b3d4a
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import logging
8
+ import torch.optim as optim
9
+ import torch.distributed as dist
10
+ import random
11
+ import pickle
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ from torch.autograd import Variable
15
+ from torch.cuda.amp import autocast
16
+ import segmentation_models_pytorch as smp
17
+ from torch.utils.data import Dataset, DataLoader
18
+ import albumentations as A
19
+ from albumentations.pytorch import ToTensorV2
20
+ import torchvision
21
+ import argparse
22
# Command-line options for SPG auto-annotation. Many flags are inherited from
# the training script and unused here.
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='../../')
parser.add_argument('--train_name', type=str, default='CHDOC_JPEG0')
parser.add_argument('--model_name', type=str, default='exp')  # output directory name
parser.add_argument('--att', type=str, default='None')
parser.add_argument('--num', type=str, default='1')
parser.add_argument('--n_class', type=int, default=2)
parser.add_argument('--bs', type=int, default=1)    # batch size
parser.add_argument('--es', type=int, default=0)
parser.add_argument('--ep', type=int, default=1)    # epochs
parser.add_argument('--xk', type=int, default=0)
parser.add_argument('--numw', type=int, default=8)  # DataLoader workers
parser.add_argument('--load', type=int, default=0)
parser.add_argument('--pilt', type=int, default=0)
parser.add_argument('--base', type=int, default=1)
parser.add_argument('--lr_base', type=float, default=3e-4)
parser.add_argument('--cp', type=float, default=1.0)
parser.add_argument('--mode', type=str, default='0123')
parser.add_argument('--adds', type=str, default='123')
# NOTE(review): the trailing '-' makes argparse expose this as args.loss_; the
# sibling script classify_convxl.py names the same flag '--lossw' — likely a
# typo. Also no '--local_rank' option is declared here, although
# args.local_rank is referenced in a (dead, `if False`) distributed branch below.
parser.add_argument('--loss-', type=str, default='1,2,3,4')
args = parser.parse_args()
43
+
44
def getdir(path):
    """Create *path* (including parents) if it does not already exist.

    Uses ``exist_ok=True`` so concurrent callers cannot race between the
    existence check and the creation; the original check-then-makedirs
    pattern could raise FileExistsError under that race.
    """
    os.makedirs(path, exist_ok=True)
47
+
48
+
49
class CVPR24REDataset(Dataset):
    """Paired-image dataset for SPG auto-annotation.

    Each sample directory holds an authentic image ('0.jpg') and a
    manipulated one ('1.jpg'). __getitem__ returns a 9-channel tensor
    (authentic, manipulated, |difference|), an all-zero placeholder mask,
    the directory name, and the original image size.
    """

    def __init__(self, roots, img_dir, times=3, repeats=1):
        # `repeats` is accepted but never used.
        self.roots = os.path.join(roots, img_dir)
        self.indexs = [os.path.join(self.roots, x) for x in os.listdir(self.roots)]
        self.lens = len(self.indexs)
        self.roots = os.path.join(roots, img_dir)  # duplicate of the assignment above; harmless
        '''
        Dir structure in self.roots:
        |
        self.roots
        |
        |---dir1
        |     |----0.jpg (authentic image)
        |     |----1.jpg (manipulated image)
        |
        |---dir2
        |     |----0.jpg (authentic image)
        |     |----1.jpg (manipulated image)
        |
        .........
        '''
        self.rsz = A.Compose([A.Resize(1024,1024)])
        self.transforms = A.Compose([ToTensorV2()])
        # Normalization stats repeated `times` (=3) to cover the 9-channel
        # stacked input. NOTE(review): mean 0.455 vs the usual ImageNet 0.456 —
        # must match how DASS was trained; confirm before changing.
        self.toctsr = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=((0.485, 0.455, 0.406)*times), std=((0.229, 0.224, 0.225)*times))])

    def __len__(self):
        return self.lens

    def __getitem__(self, idx):
        this_r = self.indexs[idx]
        print(this_r)  # progress/debug trace
        this_r = (os.path.join(this_r, '1.jpg'), os.path.join(this_r, '0.jpg'))
        img1 = cv2.cvtColor(cv2.imread(this_r[1]), cv2.COLOR_BGR2RGB)  # authentic
        img2 = cv2.cvtColor(cv2.imread(this_r[0]), cv2.COLOR_BGR2RGB)  # manipulated
        h,w = img2.shape[:2]  # original size; used later to resize the prediction back
        mask = np.zeros((h,w),dtype=np.uint8)  # placeholder: no ground truth at annotation time
        img1 = self.rsz(image=img1)['image']
        rsts = self.rsz(image=img2, mask=mask)
        img2 = rsts['image']
        mask = rsts['mask']
        # Stack to H x W x 6, then append |img1 - img2| after ToTensor -> 9 channels.
        imgs = np.concatenate((img1,img2),2)
        rsts = self.transforms(image=imgs,mask=mask)
        imgs = rsts['image']
        imgs = (torch.cat((imgs,torch.abs(imgs[:3]-imgs[3:])), 0).float()/255.0)
        imgs = self.toctsr(imgs)
        mask = rsts['mask'].long()
        # (9ch tensor, zero mask, sample directory name, original h, original w)
        return (imgs, mask, this_r[0].split('/')[-2], h, w)
96
+
97
ngpu = torch.cuda.device_count()
ngpub = ngpu * args.base  # effective GPU multiplier, used for the T0 scheduler parameter below

# Distributed branch is hard-disabled; the plain single-device path is used.
# NOTE(review): args.local_rank is not declared in this script's parser, so the
# disabled branch would crash if ever re-enabled — re-add the argument first.
if False:
    gpus = True
    device = torch.device("cuda",args.local_rank)
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl')
else:
    gpus = False
    device = torch.device("cuda")

roots1 = './'

# Both datasets point at the same SPG directory; only test_data2 is actually
# iterated inside train_net_qyl. 'your_data_dir/' is a placeholder to edit.
test_data1 = CVPR24REDataset('your_data_dir/', 'SPG')
test_data2 = CVPR24REDataset('your_data_dir/', 'SPG')
112
+
113
class AverageMeter(object):
    """Track the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
126
+
127
def second2time(second):
    """Format a duration in seconds as a compact clock string.

    < 1 minute -> seconds rounded to 4 decimals, e.g. '42.5'
    < 1 hour   -> 'M:S' with seconds rounded to 1 decimal
    otherwise  -> 'H:M:S' with integer fields

    Fix: the original guarded the last branch with `second < 60*60*60`
    (60 hours) and implicitly returned None for anything longer; the final
    branch is now an unconditional `else`, so a string is always returned.
    """
    if second < 60:
        return str('{}'.format(round(second, 4)))
    elif second < 60*60:
        m = second//60
        s = second % 60
        return str('{}:{}'.format(int(m), round(s, 1)))
    else:
        h = second//(60*60)
        m = second % (60*60)//60
        s = second % (60*60) % 60
        return str('{}:{}:{}'.format(int(h), int(m), int(s)))
139
+
140
def inial_logger(file):
    """Create the shared 'log' logger: INFO+ to *file*, DEBUG+ to the console.

    NOTE(review): repeated calls keep appending handlers to the same named
    logger, duplicating every line of output — guard with
    `if not logger.handlers:` if this is ever called more than once.
    """
    logger = logging.getLogger('log')
    logger.setLevel(level=logging.DEBUG)
    formatter = logging.Formatter('%(message)s')
    file_handler = logging.FileHandler(file)
    file_handler.setLevel(level=logging.INFO)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
153
+
154
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from mmseg.utils import get_root_logger

from dass import DASS

# DASS segmentation model on the 9-channel paired input, restored from DASS.pth.
model=DASS(in_chans=9).to(device)

model = nn.DataParallel(model)
loader = torch.load('DASS.pth',map_location='cpu')['state_dict']
model.load_state_dict(loader)

# Best-effort creation of checkpoint/log output directories.
model_name = args.model_name
save_ckpt_dir = os.path.join('./outputs/', model_name, 'ckpt')
save_log_dir = os.path.join('./outputs/', model_name)
try:
    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)
except:
    pass
try:
    if not os.path.exists(save_log_dir):
        os.makedirs(save_log_dir)
except:
    pass
import gc
param = {}
param['batch_size'] = args.bs  # batch size
param['epochs'] = args.ep  # number of training epochs; must match the scheduler policy to reproduce results — with a T0=3, T_mult=2 scheduler the optimum is reached around epoch 44
param['disp_inter'] = 1  # display interval (epochs)
param['save_inter'] = 4  # checkpoint save interval (epochs)
param['iter_inter'] = 64  # iteration logging interval (batches)
param['min_inter'] = 10
param['model_name'] = model_name  # model/run name
param['save_log_dir'] = save_log_dir  # log save path
param['save_ckpt_dir'] = save_ckpt_dir  # checkpoint save path
param['T0']=int(24/ngpub)  # cosine warmup scheduler parameter
param['load_ckpt_dir'] = None
import time
196
+
197
def collate_batch(batch_list):
    """Collate a list of (data, label) samples into flattened batch tensors.

    Each sample's tensors are concatenated and reshaped to (batch_size, -1),
    i.e. all trailing dimensions are flattened per sample.

    Args:
        batch_list: list of (data_tensor, label_tensor) tuples.

    Returns:
        (data, labels) tensors, each of shape (len(batch_list), -1).
    """
    # isinstance instead of the original `type(batch_list) == list` comparison
    # (accepts list subclasses; idiomatic type check).
    assert isinstance(batch_list, list), f"Error"
    batch_size = len(batch_list)
    data = torch.cat([item[0] for item in batch_list]).reshape(batch_size, -1)
    labels = torch.cat([item[1] for item in batch_list]).reshape(batch_size, -1)
    return data, labels
203
+
204
def train_net_qyl(param, model, test_data1, test_data2, plot=False,device='cuda'):
    """Run TTA inference over *test_data2* and write per-sample masks to SPG_preds/.

    Despite the name, this only evaluates: the optimizer is constructed but
    never stepped, and *test_data1* is unused (its loader is commented out).
    """
    # Unpack run configuration; several entries are unused in this eval-only path.
    global gpus
    model_name = param['model_name']
    epochs = param['epochs']
    batch_size = param['batch_size']
    iter_inter = param['iter_inter']
    save_log_dir = param['save_log_dir']
    save_ckpt_dir = param['save_ckpt_dir']
    load_ckpt_dir = param['load_ckpt_dir']
    T0=param['T0']
    lr_base = args.lr_base
    # Both branches are identical; kept for parity with the (disabled) distributed setup.
    if gpus:
        # valid_loader1 = DataLoader(dataset=test_data1, batch_size=batch_size, num_workers=args.numw, shuffle=False)
        valid_loader2 = DataLoader(dataset=test_data2, batch_size=batch_size, num_workers=args.numw, shuffle=False)
    else:
        # valid_loader1 = DataLoader(dataset=test_data1, batch_size=batch_size, num_workers=args.numw, shuffle=False)
        valid_loader2 = DataLoader(dataset=test_data2, batch_size=batch_size, num_workers=args.numw, shuffle=False)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4 ,weight_decay=5e-2)  # unused: no training happens here
    if True:
        model.eval()
        with torch.no_grad():
            for batch_idx, batch_samples in enumerate(tqdm(valid_loader2)):
                data, target, fnms, h, w = batch_samples
                h = h.item()  # batch size is 1, so h/w are single-element tensors
                w = w.item()
                data, target = Variable(data.to(device)), Variable(target.to(device))
                if True:
                    # TTA: original + flips along dim 2 (H) and dim 3 (W), run as one batch of 3.
                    d2 = torch.flip(data,dims=[2])
                    d3 = torch.flip(data,dims=[3])
                    data = torch.cat((data,d2,d3),0)
                    pred = model(data)
                    # Undo the flips on the corresponding predictions, then average.
                    pred[1:2] = torch.flip(pred[1:2], dims=[2])
                    pred[2:3] = torch.flip(pred[2:3], dims=[3])
                    pred = pred.mean(0,keepdim=True)
                    # Probability of the "manipulated" class as an 8-bit mask.
                    pred= (F.softmax(pred,dim=1)[:,1:2].cpu().numpy()*255).astype(np.uint8)
                    for (p, fnm) in zip(pred, fnms):
                        ds = 'SPG_preds/'
                        getdir(ds)
                        # Resize back to the source image's original resolution.
                        p = cv2.resize(p.squeeze(),(w,h))
                        cv2.imwrite(ds+'/'+fnm+'.png', p)

train_net_qyl(param, model, test_data1, test_data2, device=device)
247
+
248
+
modelsforCIML/CAAA_OK.png ADDED

Git LFS Details

  • SHA256: 5fdd67c5b595a2a7fe12a6b8b16012e51a6892c82522837d54de855cf111c296
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
modelsforCIML/Readme.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### This is the official implement of Category-Aware Auto-Annotation (CAAA)
2
+
3
+
4
+ ![CAAA](https://github.com/qcf-568/MIML/blob/main/models%20for%20CIML/CAAA_OK.png)
5
+
6
+
7
+ The classifiers are available at [Google Drive](https://drive.google.com/file/d/1OMGtuzqhjwcvDaP3OO1njPfAS_2s0vg8/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1-NidYwgVZUA0Pi0KE3ngGw?pwd=conv).
8
+
9
+ The DASS model is available at [Google Drive](https://drive.google.com/file/d/1PXL9e8XiRGlSIcGhhppLXJtVG2rdQh5a/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1lmksoTe2b2xObGkhUbd5-A?pwd=DASS).
10
+
11
+ The SACM model is available at [Google Drive](https://drive.google.com/file/d/1_C5gATKv8Mh7SyKNE_ubSpXlEASkEYja/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1PnLepP7bAd-8L5NcUGBx4A?pwd=SAFM).
12
+
13
+
14
+
15
+ To leverage the CAAA for auto-annotation, you should first categorize the image pairs (each pair contains a forged image and its authentic image) into aligned SPG and SDG. Then construct the dir structure as follows:
16
+
17
+ ```
18
+ roots (dir of SPG or SDG pairs)
19
+ |
20
+ |---dir1
21
+ | |----0.jpg (authentic image)
22
+ | |----1.jpg (manipulated image)
23
+ |
24
+ |---dir2
25
+ | |----0.jpg (authentic image)
26
+ | |----1.jpg (manipulated image)
27
+ |
28
+ ..........
29
+ ```
30
+
31
+ Then run the scripts for auto-annotation.
32
+
33
+
34
+ Commands to run the classifier to categorize the image pairs into SPG or SDG:
35
+ ```
36
+ CUDA_VISIBLE_DEVICES=0 python classify_convxl.py
37
+ ```
38
+
39
+
40
+ Commands to run the DASS to auto-annotate the image pairs in SPG:
41
+ ```
42
+ CUDA_VISIBLE_DEVICES=0 python Auto_Annotate_SPG.py --pth DASS.pth
43
+ ```
44
+
45
+
46
+ Commands to run the SACM to auto-annotate the image pairs in SDG:
47
+
48
+ ```
49
+ CUDA_VISIBLE_DEVICES=0 python Auto_Annotate_SDG.py --pth SAFM.pth
50
+ ```
modelsforCIML/classify_convxl.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import logging
8
+ from tqdm import tqdm
9
+ import torch.optim as optim
10
+ import torch.distributed as dist
11
+ import random
12
+ import pickle
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+ from torch.autograd import Variable
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import albumentations as A
18
+ from albumentations.pytorch import ToTensorV2
19
+ import torchvision
20
+ import argparse
21
# Command-line options for the SPG/SDG pair classifier. Most flags appear to
# be inherited from a training script and are unused in this inference path.
parser = argparse.ArgumentParser()
parser.add_argument('--img_dir', type=str)
parser.add_argument('--model_name', type=str, default='cls')
parser.add_argument('--att', type=str, default='None')
parser.add_argument('--num', type=str, default='1')
parser.add_argument('--n_class', type=int, default=2)
parser.add_argument('--bs', type=int, default=4)     # batch size
parser.add_argument('--es', type=int, default=0)
parser.add_argument('--ep', type=int, default=10)
parser.add_argument('--xk', type=int, default=0)
parser.add_argument('--numw', type=int, default=16)  # DataLoader workers
parser.add_argument('--load', type=int, default=0)
parser.add_argument('--pilt', type=int, default=0)
parser.add_argument('--base', type=int, default=1)
parser.add_argument('--lr_base', type=float, default=3e-4)
parser.add_argument('--cp', type=float, default=1.0)
parser.add_argument('--mode', type=str, default='0123')
parser.add_argument('--local-rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--adds', type=str, default='123')
parser.add_argument('--lossw', type=str, default='1,2,3,4')  # note: Auto_Annotate_SPG.py spells this '--loss-'
args = parser.parse_args()
42
+
43
+ from tqdm import tqdm
44
+
45
class CVPR24EVALDataset(Dataset):
    """Loads (authentic, manipulated) image pairs as one 6-channel tensor.

    Expects roots/<dir>/0.jpg (authentic) and roots/<dir>/1.jpg (manipulated).
    """

    def __init__(self, roots):
        self.indexs = [(os.path.join(roots, d,'0.jpg'), os.path.join(roots, d,'1.jpg')) for d in os.listdir(roots)]
        self.roots = roots
        self.indexs.sort()  # deterministic iteration order
        self.lens = len(self.indexs)
        # ToTensor (scales to [0,1]) -> resize to 512x512 -> normalize.
        # NOTE(review): mean 0.455 vs the usual ImageNet 0.456 — must match
        # how the classifier was trained; confirm.
        self.rsztsr = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Resize((512,512)),torchvision.transforms.Normalize(mean=((0.485, 0.455, 0.406)), std=((0.229, 0.224, 0.225)))])

    def __len__(self):
        return self.lens

    def __getitem__(self, idx):
        # Returns (tensor, authentic_path, manipulated_path, error_flag).
        try:
            img1 = cv2.cvtColor(cv2.imread(self.indexs[idx][0]),cv2.COLOR_BGR2RGB)
            img2 = cv2.cvtColor(cv2.imread(self.indexs[idx][1]),cv2.COLOR_BGR2RGB)
            img1 = self.rsztsr(img1)
            img2 = self.rsztsr(img2)
            imgs = torch.cat((img1, img2), 0)  # 6-channel stacked input
            return (imgs, self.indexs[idx][0], self.indexs[idx][1], False)
        except:
            # NOTE(review): bare except — any failure (missing/corrupt file)
            # collapses to a flag tuple the caller must check; consider
            # narrowing the exception and logging the path.
            print('error')
            return (None, None, None, True)
67
+
68
device = torch.device("cuda")

roots1 = './'  # scan pair directories under the current directory
test_data = CVPR24EVALDataset(roots1)
72
+
73
def get_logger(filename, verbosity=1, name=None):
    """Create a logger writing to *filename* (truncated) and to the console.

    verbosity maps to a level: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    NOTE(review): calling this twice for the same *name* appends duplicate
    handlers (doubled log lines) — guard with `if not logger.handlers:` if reused.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("[%(asctime)s][%(filename)s][%(levelname)s] %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    return logger
85
+
86
class AverageMeter(object):
    """Running-average helper: keeps the last value, sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold in *val*, weighted by *n* occurrences."""
        new_sum = self.sum + val * n
        new_count = self.count + n
        self.val = val
        self.sum = new_sum
        self.count = new_count
        self.avg = new_sum / new_count
99
+
100
def second2time(second):
    """Format a duration in seconds as a compact clock string.

    < 1 minute -> seconds rounded to 4 decimals, e.g. '42.5'
    < 1 hour   -> 'M:S' with seconds rounded to 1 decimal
    otherwise  -> 'H:M:S' with integer fields

    Fix: the original guarded the last branch with `second < 60*60*60`
    (60 hours) and implicitly returned None for anything longer; the final
    branch is now an unconditional `else`, so a string is always returned.
    """
    if second < 60:
        return str('{}'.format(round(second, 4)))
    elif second < 60*60:
        m = second//60
        s = second % 60
        return str('{}:{}'.format(int(m), round(s, 1)))
    else:
        h = second//(60*60)
        m = second % (60*60)//60
        s = second % (60*60) % 60
        return str('{}:{}:{}'.format(int(h), int(m), int(s)))
112
+
113
def inial_logger(file):
    """Create the shared 'log' logger: INFO+ to *file*, DEBUG+ to the console.

    NOTE(review): repeated calls keep appending handlers to the same named
    logger, duplicating every line of output — guard with
    `if not logger.handlers:` if this is ever called more than once.
    """
    logger = logging.getLogger('log')
    logger.setLevel(level=logging.DEBUG)
    formatter = logging.Formatter('%(message)s')
    file_handler = logging.FileHandler(file)
    file_handler.setLevel(level=logging.INFO)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
126
+
127
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from mmseg.utils import get_root_logger
from convnext import ConvNeXt

# ConvNeXt-XL classifier on 6-channel pair input. num_classes=8 logits are
# reshaped below to (b, 4, 2): four independent binary heads, softmaxed per head.
model=ConvNeXt(in_chans=6, depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0.8, layer_scale_init_value=1.0, num_classes=8).to(device)

model = nn.DataParallel(model)
loaders = torch.load('convxl.pth',map_location='cpu')['state_dict']
model.load_state_dict(loaders)
model = model.cuda()
model.eval()

all_dict = {}
SPG = []
SDG = []
# NOTE(review): SPG/SDG/NotAlignedSPG are accumulated but never persisted in
# this view — only all_dict is pickled below; confirm downstream usage.
NotAlignedSPG = []

with torch.no_grad():
    for idx in tqdm(range(len(test_data))):
        (imgs,auth,temp,flags) = test_data.__getitem__(idx)
        if flags:
            continue  # skip pairs the dataset flagged as unreadable
        pred = model(imgs.unsqueeze(0))
        b,c = pred.shape
        # Per-head softmax over the (b, c//2, 2) reshape; heads [0] and [1]
        # are read as the SPG/SDG votes below.
        pred = F.softmax(pred.reshape(b,c//2,2),dim=-1).cpu().numpy()
        all_dict[temp]=(auth, pred)
        if ((pred[0,0,1]>0.5) and (pred[0,1,1]>0.5)): # SPG
            SPG.append((auth, temp))
        if ((pred[0,0,0]>0.5) and (pred[0,1,0]>0.5)): # SDG
            SDG.append((auth, temp))
        if ((pred[0,0,1]>0.5) and (pred[0,1,0]>0.5)): # NotAlignedSPG
            NotAlignedSPG.append((auth, temp))

# Persist the raw per-pair predictions keyed by the manipulated-image path.
with open('convxl_cls.pk','wb') as f:
    pickle.dump(all_dict, f)
166
+
167
+
168
+
modelsforCIML/convbuper.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from timm.models.layers import trunc_normal_, DropPath
14
+ from mmseg.models.decode_heads import UPerHead,FCNHead
15
+ from functools import partial
16
+ from itertools import chain
17
+ from typing import Sequence
18
+
19
class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel scale (Layer Scale); disabled when init value <= 0.
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        # Residual connection with stochastic depth on the transformed branch.
        x = input + self.drop_path(x)
        return x
55
+
56
class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf

    Backbone variant: returns a tuple of per-stage feature maps (selected by
    `out_indices`) rather than classification logits.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        out_indices (list(int)): Stages whose outputs are returned. Default: [0, 1, 2, 3]
    """
    def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        # Stochastic-depth rates rise linearly from 0 to drop_path_rate across all blocks.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        # One channels-first LayerNorm per stage output, applied before returning features.
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # NOTE(review): assumes every Conv2d/Linear carries a bias — a
        # bias-free layer would make the constant_ call raise. This holds for
        # the layers constructed above.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        NOTE(review): despite the error message, any non-None value —
        including a str path — raises; checkpoint loading is not implemented here.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward_features(self, x):
        # Collect normalized feature maps from the stages listed in out_indices.
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        x = self.forward_features(x)
        return x
150
+
151
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last and channels_first layouts.

    channels_last expects inputs shaped (batch, height, width, channels) and
    defers to ``F.layer_norm``; channels_first expects
    (batch, channels, height, width) and normalizes over dim 1 by hand.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalization across the channel dimension.
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        # channels_last: the built-in functional layer norm does the work.
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
176
+
177
+
178
class ConvBUPer(nn.Module):
    """ConvNeXt-B backbone + UPerHead decoder for 2-class segmentation.

    NOTE(review): auxiliary_head (FCNHead) is constructed but never called in
    forward — it only matters if trained with deep supervision elsewhere.
    The SyncBN norm_cfg requires an initialized distributed process group at
    train time.
    """
    def __init__(self,):
        super(ConvBUPer, self).__init__()
        self.backbone = ConvNeXt(in_chans=3, depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], drop_path_rate=0.4)
        self.decode_head = UPerHead(
            in_channels=[128, 256, 512, 1024],  # matches the backbone's per-stage dims
            in_index=[0,1,2,3],
            pool_scales=(1,2,3,6),
            channels=512,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=dict(type='SyncBN'),
        )
        self.auxiliary_head = FCNHead(
            in_channels=512,
            in_index=2,  # taps the third backbone stage
            channels=256,
            num_convs=1,
            concat_input=False,
            dropout_ratio=0.1,
            num_classes=2,
            align_corners=False,
            norm_cfg=dict(type='SyncBN'),
        )

    def forward(self,x):
        # Multi-scale features from the backbone, fused by the UPer decode head.
        outs = self.backbone(x)
        outs = self.decode_head(outs)
        return outs
208
+
209
+
modelsforCIML/convnext.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ from functools import partial
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from timm.models.layers import trunc_normal_, DropPath
14
+
15
+ class Block(nn.Module):
16
+ r""" ConvNeXt Block. There are two equivalent implementations:
17
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19
+ We use (2) as we find it slightly faster in PyTorch
20
+
21
+ Args:
22
+ dim (int): Number of input channels.
23
+ drop_path (float): Stochastic depth rate. Default: 0.0
24
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25
+ """
26
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.norm = LayerNorm(dim, eps=1e-6)
30
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31
+ self.act = nn.GELU()
32
+ self.pwconv2 = nn.Linear(4 * dim, dim)
33
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34
+ requires_grad=True) if layer_scale_init_value > 0 else None
35
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36
+
37
+ def forward(self, x):
38
+ input = x
39
+ x = self.dwconv(x)
40
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41
+ x = self.norm(x)
42
+ x = self.pwconv1(x)
43
+ x = self.act(x)
44
+ x = self.pwconv2(x)
45
+ if self.gamma is not None:
46
+ x = self.gamma * x
47
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48
+
49
+ x = input + self.drop_path(x)
50
+ return x
51
+
52
+ class ConvNeXt(nn.Module):
53
+ r""" ConvNeXt
54
+ A PyTorch impl of : `A ConvNet for the 2020s` -
55
+ https://arxiv.org/pdf/2201.03545.pdf
56
+
57
+ Args:
58
+ in_chans (int): Number of input image channels. Default: 3
59
+ num_classes (int): Number of classes for classification head. Default: 1000
60
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
61
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
62
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
63
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
64
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
65
+ """
66
+ def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
67
+ drop_path_rate=0., layer_scale_init_value=1e-6, num_classes=8,
68
+ ):
69
+ super().__init__()
70
+
71
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
72
+ stem = nn.Sequential(
73
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
74
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
75
+ )
76
+ self.downsample_layers.append(stem)
77
+ for i in range(3):
78
+ downsample_layer = nn.Sequential(
79
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
80
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
81
+ )
82
+ self.downsample_layers.append(downsample_layer)
83
+
84
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
85
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
86
+ cur = 0
87
+ for i in range(4):
88
+ stage = nn.Sequential(
89
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
90
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
91
+ )
92
+ self.stages.append(stage)
93
+ cur += depths[i]
94
+
95
+ self.fc = nn.Sequential(nn.Dropout(p=0.3), nn.AdaptiveAvgPool2d(1), nn.Flatten(1), nn.Linear(dims[-1], num_classes))
96
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
97
+ for i_layer in range(3,4):
98
+ layer = norm_layer(dims[i_layer])
99
+ layer_name = f'norm'
100
+ self.add_module(layer_name, layer)
101
+
102
+ self.apply(self._init_weights)
103
+
104
+ def _init_weights(self, m):
105
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
106
+ trunc_normal_(m.weight, std=.02)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ def init_weights(self, pretrained=None):
110
+ """Initialize the weights in backbone.
111
+ Args:
112
+ pretrained (str, optional): Path to pre-trained weights.
113
+ Defaults to None.
114
+ """
115
+
116
+ def _init_weights(m):
117
+ if isinstance(m, nn.Linear):
118
+ trunc_normal_(m.weight, std=.02)
119
+ if isinstance(m, nn.Linear) and m.bias is not None:
120
+ nn.init.constant_(m.bias, 0)
121
+ elif isinstance(m, nn.LayerNorm):
122
+ nn.init.constant_(m.bias, 0)
123
+ nn.init.constant_(m.weight, 1.0)
124
+
125
+ self.apply(_init_weights)
126
+
127
+ def forward_features(self, x):
128
+ for i in range(4):
129
+ x = self.downsample_layers[i](x)
130
+ x = self.stages[i](x)
131
+ if i==3:
132
+ norm_layer = getattr(self, f'norm')
133
+ x_out = norm_layer(x)
134
+ return self.fc(x_out)
135
+
136
+
137
+ def forward(self, x):
138
+ x = self.forward_features(x)
139
+ return x
140
+
141
+ class LayerNorm(nn.Module):
142
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
143
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
144
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
145
+ with shape (batch_size, channels, height, width).
146
+ """
147
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
148
+ super().__init__()
149
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
150
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
151
+ self.eps = eps
152
+ self.data_format = data_format
153
+ if self.data_format not in ["channels_last", "channels_first"]:
154
+ raise NotImplementedError
155
+ self.normalized_shape = (normalized_shape, )
156
+
157
+ def forward(self, x):
158
+ if self.data_format == "channels_last":
159
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
160
+ elif self.data_format == "channels_first":
161
+ u = x.mean(1, keepdim=True)
162
+ s = (x - u).pow(2).mean(1, keepdim=True)
163
+ x = (x - u) / torch.sqrt(s + self.eps)
164
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
165
+ return x
modelsforCIML/dass.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from mmseg.models.decode_heads import UPerLab,FCNHead
14
+
15
+ # --------------------------------------------------------
16
+ # InternImage
17
+ # Copyright (c) 2022 OpenGVLab
18
+ # Licensed under The MIT License [see LICENSE for details]
19
+ # --------------------------------------------------------
20
+
21
+ from collections import OrderedDict
22
+ import torch.utils.checkpoint as checkpoint
23
+ from timm.models.layers import trunc_normal_, DropPath
24
+ from mmcv.cnn import constant_init, trunc_normal_init
25
+ import torch.nn.functional as F
26
+ from torch.nn.modules.utils import _pair as to_2tuple
27
+ from mmcv.cnn import build_norm_layer
28
+ from mmcv.runner import BaseModule
29
+ import math
30
+ import warnings
31
+
32
+
33
+ class Mlp(nn.Module):
34
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
35
+ super().__init__()
36
+ out_features = out_features or in_features
37
+ hidden_features = hidden_features or in_features
38
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
39
+ self.dwconv = DWConv(hidden_features)
40
+ self.act = act_layer()
41
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
42
+ self.drop = nn.Dropout(drop)
43
+ self.linear = linear
44
+ if self.linear:
45
+ self.relu = nn.ReLU(inplace=True)
46
+
47
+ def forward(self, x):
48
+ x = self.fc1(x)
49
+ if self.linear:
50
+ x = self.relu(x)
51
+ x = self.dwconv(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ class AttentionModule(nn.Module):
60
+ def __init__(self, dim):
61
+ super().__init__()
62
+ self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
63
+ self.conv_spatial = nn.Conv2d(
64
+ dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
65
+ self.conv1 = nn.Conv2d(dim, dim, 1)
66
+
67
+ def forward(self, x):
68
+ u = x.clone()
69
+ attn = self.conv0(x)
70
+ attn = self.conv_spatial(attn)
71
+ attn = self.conv1(attn)
72
+ return u * attn
73
+
74
+
75
+ class SpatialAttention(nn.Module):
76
+ def __init__(self, d_model):
77
+ super().__init__()
78
+ self.d_model = d_model
79
+ self.proj_1 = nn.Conv2d(d_model, d_model, 1)
80
+ self.activation = nn.GELU()
81
+ self.spatial_gating_unit = AttentionModule(d_model)
82
+ self.proj_2 = nn.Conv2d(d_model, d_model, 1)
83
+
84
+ def forward(self, x):
85
+ shorcut = x.clone()
86
+ x = self.proj_1(x)
87
+ x = self.activation(x)
88
+ x = self.spatial_gating_unit(x)
89
+ x = self.proj_2(x)
90
+ x = x + shorcut
91
+ return x
92
+
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self,
97
+ dim,
98
+ mlp_ratio=4.,
99
+ drop=0.,
100
+ drop_path=0.,
101
+ act_layer=nn.GELU,
102
+ linear=False,
103
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
104
+ super().__init__()
105
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
106
+ self.attn = SpatialAttention(dim)
107
+ self.drop_path = DropPath(
108
+ drop_path) if drop_path > 0. else nn.Identity()
109
+
110
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
111
+ mlp_hidden_dim = int(dim * mlp_ratio)
112
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
113
+ act_layer=act_layer, drop=drop, linear=linear)
114
+ layer_scale_init_value = 1e-2
115
+ self.layer_scale_1 = nn.Parameter(
116
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True)
117
+ self.layer_scale_2 = nn.Parameter(
118
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True)
119
+
120
+ def forward(self, x, H, W):
121
+ B, N, C = x.shape
122
+ x = x.permute(0, 2, 1).view(B, C, H, W)
123
+ x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
124
+ * self.attn(self.norm1(x)))
125
+ x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
126
+ * self.mlp(self.norm2(x)))
127
+ x = x.view(B, C, N).permute(0, 2, 1)
128
+ return x
129
+
130
+
131
+ class OverlapPatchEmbed(nn.Module):
132
+ """ Image to Patch Embedding
133
+ """
134
+
135
+ def __init__(self,
136
+ patch_size=7,
137
+ stride=4,
138
+ in_chans=3,
139
+ embed_dim=768,
140
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
141
+ super().__init__()
142
+ patch_size = to_2tuple(patch_size)
143
+
144
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
145
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
146
+ self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
147
+
148
+ def forward(self, x):
149
+ x = self.proj(x)
150
+ _, _, H, W = x.shape
151
+ x = self.norm(x)
152
+
153
+ x = x.flatten(2).transpose(1, 2)
154
+
155
+ return x, H, W
156
+
157
+
158
+ class VAN(BaseModule):
159
+ def __init__(self,
160
+ in_chans=9,
161
+ embed_dims=[64, 128, 256, 512],
162
+ mlp_ratios=[8, 8, 4, 4],
163
+ drop_rate=0.,
164
+ drop_path_rate=0.,
165
+ depths=[3, 4, 6, 3],
166
+ num_stages=4,
167
+ linear=False,
168
+ pretrained=None,
169
+ init_cfg=None,
170
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
171
+ super(VAN, self).__init__(init_cfg=init_cfg)
172
+
173
+ assert not (init_cfg and pretrained), \
174
+ 'init_cfg and pretrained cannot be set at the same time'
175
+ if isinstance(pretrained, str):
176
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
177
+ 'please use "init_cfg" instead')
178
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
179
+ elif pretrained is not None:
180
+ raise TypeError('pretrained must be a str or None')
181
+
182
+ self.depths = depths
183
+ self.num_stages = num_stages
184
+ self.linear = linear
185
+
186
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
187
+ sum(depths))] # stochastic depth decay rule
188
+ cur = 0
189
+
190
+ for i in range(num_stages):
191
+ patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3,
192
+ stride=4 if i == 0 else 2,
193
+ in_chans=in_chans if i == 0 else embed_dims[i - 1],
194
+ embed_dim=embed_dims[i])
195
+
196
+ block = nn.ModuleList([Block(dim=embed_dims[i],
197
+ mlp_ratio=mlp_ratios[i],
198
+ drop=drop_rate,
199
+ drop_path=dpr[cur + j],
200
+ linear=linear,
201
+ norm_cfg=norm_cfg)
202
+ for j in range(depths[i])])
203
+ norm = nn.LayerNorm(embed_dims[i])
204
+ cur += depths[i]
205
+
206
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
207
+ setattr(self, f"block{i + 1}", block)
208
+ setattr(self, f"norm{i + 1}", norm)
209
+
210
+ def init_weights(self):
211
+ print('init cfg', self.init_cfg)
212
+ if self.init_cfg is None:
213
+ for m in self.modules():
214
+ if isinstance(m, nn.Linear):
215
+ trunc_normal_init(m, std=.02, bias=0.)
216
+ elif isinstance(m, nn.LayerNorm):
217
+ constant_init(m, val=1.0, bias=0.)
218
+ elif isinstance(m, nn.Conv2d):
219
+ fan_out = m.kernel_size[0] * m.kernel_size[
220
+ 1] * m.out_channels
221
+ fan_out //= m.groups
222
+ normal_init(
223
+ m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
224
+ else:
225
+ super(VAN, self).init_weights()
226
+
227
+ def forward(self, x):
228
+ B = x.shape[0]
229
+ outs = []
230
+
231
+ for i in range(self.num_stages):
232
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
233
+ block = getattr(self, f"block{i + 1}")
234
+ norm = getattr(self, f"norm{i + 1}")
235
+ x, H, W = patch_embed(x)
236
+ for blk in block:
237
+ x = blk(x, H, W)
238
+ x = norm(x)
239
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
240
+ outs.append(x)
241
+
242
+ return outs
243
+
244
+
245
+ class DWConv(nn.Module):
246
+ def __init__(self, dim=768):
247
+ super(DWConv, self).__init__()
248
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
249
+
250
+ def forward(self, x):
251
+ x = self.dwconv(x)
252
+ return x
253
+
254
+ class DASS(nn.Module):
255
+ def __init__(self,in_chans=6):
256
+ super(DASS, self).__init__()
257
+ self.backbone = VAN(in_chans=in_chans, embed_dims=[96, 192, 480, 768], drop_rate=0.0, drop_path_rate=0.4, depths=[3, 3, 24, 3], norm_cfg=dict(type='SyncBN', requires_grad=True))
258
+ self.decode_head = UPerLab(
259
+ in_channels=[96, 192, 480, 768],
260
+ in_index=[0,1,2,3],
261
+ pool_scales=(1,2,3,6),
262
+ channels=512,
263
+ dropout_ratio=0.1,
264
+ num_classes=2,
265
+ norm_cfg=dict(type='SyncBN'),
266
+ #norm_cfg=dict(type='SyncBN'),
267
+ )
268
+ self.auxiliary_head = FCNHead(
269
+ in_channels=480,
270
+ in_index=2,
271
+ channels=256,
272
+ num_convs=1,
273
+ concat_input=False,
274
+ dropout_ratio=0.1,
275
+ num_classes=2,
276
+ align_corners=False,
277
+ norm_cfg=dict(type='SyncBN'),
278
+ )
279
+
280
+ def forward(self,x):
281
+ outs = self.backbone(x)
282
+ if self.training:
283
+ out1, out3 = self.decode_head(outs)
284
+ out2 = self.auxiliary_head(outs)
285
+ return F.upsample_bilinear(out1,scale_factor=4.0),F.upsample_bilinear(out2,scale_factor=16.0),F.upsample_bilinear(out3,scale_factor=4.0)
286
+ else:
287
+ out1 = self.decode_head(outs)
288
+ return F.upsample_bilinear(out1,scale_factor=4.0)
289
+
modelsforCIML/mmseg/__init__.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import mmcv
5
+ from packaging.version import parse
6
+
7
+ from .version import __version__, version_info
8
+
9
+ MMCV_MIN = '1.3.13'
10
+ MMCV_MAX = '1.8.0'
11
+
12
+
13
+ def digit_version(version_str: str, length: int = 4):
14
+ """Convert a version string into a tuple of integers.
15
+
16
+ This method is usually used for comparing two versions. For pre-release
17
+ versions: alpha < beta < rc.
18
+
19
+ Args:
20
+ version_str (str): The version string.
21
+ length (int): The maximum number of version levels. Default: 4.
22
+
23
+ Returns:
24
+ tuple[int]: The version info in digits (integers).
25
+ """
26
+ version = parse(version_str)
27
+ assert version.release, f'failed to parse version {version_str}'
28
+ release = list(version.release)
29
+ release = release[:length]
30
+ if len(release) < length:
31
+ release = release + [0] * (length - len(release))
32
+ if version.is_prerelease:
33
+ mapping = {'a': -3, 'b': -2, 'rc': -1}
34
+ val = -4
35
+ # version.pre can be None
36
+ if version.pre:
37
+ if version.pre[0] not in mapping:
38
+ warnings.warn(f'unknown prerelease version {version.pre[0]}, '
39
+ 'version checking may go wrong')
40
+ else:
41
+ val = mapping[version.pre[0]]
42
+ release.extend([val, version.pre[-1]])
43
+ else:
44
+ release.extend([val, 0])
45
+
46
+ elif version.is_postrelease:
47
+ release.extend([1, version.post])
48
+ else:
49
+ release.extend([0, 0])
50
+ return tuple(release)
51
+
52
+
53
+ mmcv_min_version = digit_version(MMCV_MIN)
54
+ mmcv_max_version = digit_version(MMCV_MAX)
55
+ mmcv_version = digit_version(mmcv.__version__)
56
+
57
+
58
+ assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
59
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
60
+ f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.'
61
+
62
+ __all__ = ['__version__', 'version_info', 'digit_version']
modelsforCIML/mmseg/core/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
3
+ build_optimizer_constructor)
4
+ from .evaluation import * # noqa: F401, F403
5
+ from .hook import * # noqa: F401, F403
6
+ from .optimizers import * # noqa: F401, F403
7
+ from .seg import * # noqa: F401, F403
8
+ from .utils import * # noqa: F401, F403
9
+
10
+ __all__ = [
11
+ 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
12
+ ]
modelsforCIML/mmseg/core/builder.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+
4
+ from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
5
+ from mmcv.utils import Registry, build_from_cfg
6
+
7
+ OPTIMIZER_BUILDERS = Registry(
8
+ 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
9
+
10
+
11
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from a config dict.

    ``cfg['type']`` is looked up first in mmseg's own ``OPTIMIZER_BUILDERS``
    registry, then in mmcv's parent registry.

    Args:
        cfg (dict): Config with a ``type`` key naming a registered
            optimizer-constructor class.

    Returns:
        The instantiated optimizer constructor.

    Raises:
        KeyError: If ``type`` is registered in neither registry.
    """
    constructor_type = cfg.get('type')
    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    else:
        raise KeyError(f'{constructor_type} is not registered '
                       'in the optimizer builder registry.')
20
+
21
+
22
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from a config dict.

    Args:
        model (nn.Module): Model whose parameters the optimizer will manage.
        cfg (dict): Optimizer config; may contain ``constructor`` (defaults
            to ``'DefaultOptimizerConstructor'``) and ``paramwise_cfg`` keys,
            which are popped before the remainder is passed through as the
            optimizer config.

    Returns:
        The optimizer produced by the resolved constructor.
    """
    # Deep-copy so popping keys below does not mutate the caller's config.
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
modelsforCIML/mmseg/core/evaluation/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .class_names import get_classes, get_palette
3
+ from .eval_hooks import DistEvalHook, EvalHook
4
+ from .metrics import (eval_metrics, intersect_and_union, mean_dice,
5
+ mean_fscore, mean_iou, pre_eval_to_metrics)
6
+
7
+ __all__ = [
8
+ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
9
+ 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics',
10
+ 'intersect_and_union'
11
+ ]
modelsforCIML/mmseg/core/evaluation/class_names.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import mmcv
3
+
4
+
5
def cityscapes_classes():
    """Cityscapes class names for external use.

    Returns:
        list[str]: 19 class names; index i is assumed to correspond to the
            i-th color in ``cityscapes_palette`` — verify against callers.
    """
    return [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle'
    ]
13
+
14
+
15
+ def ade_classes():
16
+ """ADE20K class names for external use."""
17
+ return [
18
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
19
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
20
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
21
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
22
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
23
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
24
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
25
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
26
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
27
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
28
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
29
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
30
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
31
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
32
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
33
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
34
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
35
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
36
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
37
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
38
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
39
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
40
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
41
+ 'clock', 'flag'
42
+ ]
43
+
44
+
45
def voc_classes():
    """Pascal VOC class names for external use.

    Returns:
        list[str]: 21 class names (``background`` plus 20 object classes);
            index i is assumed to correspond to ``voc_palette()[i]``.
    """
    return [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor'
    ]
53
+
54
+
55
+ def cocostuff_classes():
56
+ """CocoStuff class names for external use."""
57
+ return [
58
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
59
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
60
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
61
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
62
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
63
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
64
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
65
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
66
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
67
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
68
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
69
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
70
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
71
+ 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
72
+ 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
73
+ 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
74
+ 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
75
+ 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
76
+ 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
77
+ 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
78
+ 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
79
+ 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
80
+ 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
81
+ 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
82
+ 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
83
+ 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
84
+ 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
85
+ 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
86
+ 'window-blind', 'window-other', 'wood'
87
+ ]
88
+
89
+
90
def loveda_classes():
    """LoveDA class names for external use.

    Returns:
        list[str]: The 7 LoveDA semantic classes, ``background`` first.
    """
    return [
        'background', 'building', 'road', 'water', 'barren', 'forest',
        'agricultural'
    ]
96
+
97
+
98
def potsdam_classes():
    """Potsdam class names for external use.

    Returns:
        list[str]: The 6 Potsdam classes (same list as ``vaihingen_classes``).
    """
    return [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter'
    ]
104
+
105
+
106
def vaihingen_classes():
    """Vaihingen class names for external use.

    Returns:
        list[str]: The 6 Vaihingen classes (same list as ``potsdam_classes``).
    """
    return [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter'
    ]
112
+
113
+
114
def isaid_classes():
    """iSAID class names for external use.

    Returns:
        list[str]: 16 class names, ``background`` first. The mixed
            capitalisation is preserved as-is — callers may match on these
            exact strings.
    """
    return [
        'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
        'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
        'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
        'Soccer_ball_field', 'plane', 'Harbor'
    ]
122
+
123
+
124
def stare_classes():
    """STARE (retinal vessel) class names for external use.

    Returns:
        list[str]: Binary segmentation classes, ``background`` first.
    """
    return ['background', 'vessel']
127
+
128
+
129
def occludedface_classes():
    """Occluded-face class names for external use.

    Returns:
        list[str]: Binary segmentation classes, ``background`` first.
    """
    return ['background', 'face']
132
+
133
+
134
def cityscapes_palette():
    """Cityscapes palette for external use.

    Returns:
        list[list[int]]: 19 RGB triplets; entry i is the display color for
            ``cityscapes_classes()[i]``.
    """
    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
            [0, 0, 230], [119, 11, 32]]
141
+
142
+
143
+ def ade_palette():
144
+ """ADE20K palette for external use."""
145
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
146
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
147
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
148
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
149
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
150
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
151
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
152
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
153
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
154
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
155
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
156
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
157
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
158
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
159
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
160
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
161
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
162
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
163
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
164
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
165
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
166
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
167
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
168
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
169
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
170
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
171
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
172
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
173
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
174
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
175
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
176
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
177
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
178
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
179
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
180
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
181
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
182
+ [102, 255, 0], [92, 0, 255]]
183
+
184
+
185
def voc_palette():
    """Pascal VOC palette for external use.

    Returns:
        list[list[int]]: 21 RGB triplets; entry i is the display color for
            ``voc_classes()[i]``.
    """
    return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
192
+
193
+
194
def cocostuff_palette():
    """Return the 171-entry COCO-Stuff color palette as RGB triplets.

    One color per COCO-Stuff category, in dataset order. Note the first
    two entries are intentionally identical in the upstream definition.
    """
    return [
        [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
        [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
        [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
        [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
        [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
        [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
        [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
        [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
        [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
        [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
        [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
        [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
        [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
        [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
        [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
        [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
        [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
        [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
        [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
        [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
        [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
        [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
        [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
        [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
        [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
        [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
        [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
        [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
        [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
        [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
        [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
        [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
        [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
        [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
        [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
        [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
        [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
        [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
        [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
        [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
        [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
        [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
        [64, 192, 96], [64, 160, 64], [64, 64, 0],
    ]
239
+
240
+
241
def loveda_palette():
    """Return the 7-entry LoveDA color palette as RGB triplets."""
    return [
        [255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
        [159, 129, 183], [0, 255, 0], [255, 195, 128],
    ]
245
+
246
+
247
def potsdam_palette():
    """Return the 6-entry ISPRS Potsdam color palette as RGB triplets."""
    return [
        [255, 255, 255], [0, 0, 255], [0, 255, 255],
        [0, 255, 0], [255, 255, 0], [255, 0, 0],
    ]
251
+
252
+
253
def vaihingen_palette():
    """Return the 6-entry ISPRS Vaihingen color palette as RGB triplets.

    Identical to the Potsdam palette: both datasets share the same six
    ISPRS land-cover categories.
    """
    return [
        [255, 255, 255], [0, 0, 255], [0, 255, 255],
        [0, 255, 0], [255, 255, 0], [255, 0, 0],
    ]
257
+
258
+
259
def isaid_palette():
    """Return the 16-entry iSAID color palette as RGB triplets.

    Index 0 is the background color.
    """
    return [
        [0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0],
        [0, 63, 127], [0, 63, 191], [0, 63, 255], [0, 127, 63],
        [0, 127, 127], [0, 0, 127], [0, 0, 191], [0, 0, 255],
        [0, 191, 127], [0, 127, 191], [0, 127, 255], [0, 100, 155],
    ]
266
+
267
+
268
def stare_palette():
    """Return the 2-entry STARE palette (background, vessel) as RGB triplets."""
    return [[120, 120, 120], [6, 230, 230]]
271
+
272
+
273
def occludedface_palette():
    """Return the 2-entry occluded-face palette (background, face) as RGB triplets."""
    return [[0, 0, 0], [128, 0, 0]]
276
+
277
+
278
# Maps a canonical dataset name to every alias accepted by get_classes() /
# get_palette(). The canonical name is the prefix of the corresponding
# ``<name>_classes()`` / ``<name>_palette()`` helper in this module.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
    'loveda': ['loveda'],
    'potsdam': ['potsdam'],
    'vaihingen': ['vaihingen'],
    'cocostuff': [
        'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
        'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
        'coco_stuff164k'
    ],
    'isaid': ['isaid', 'iSAID'],
    'stare': ['stare', 'STARE'],
    'occludedface': ['occludedface'],
}
294
+
295
+
296
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Dataset name or one of its aliases (see
            ``dataset_aliases``).

    Returns:
        list[str]: Class names of the dataset.

    Raises:
        ValueError: If ``dataset`` is not a recognized name/alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Invert dataset_aliases so every alias maps to its canonical name.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }

    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve the per-dataset helper by name instead of eval(): same lookup
    # semantics (module globals), but without executing a built string.
    return globals()[alias2name[dataset] + '_classes']()
311
+
312
+
313
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Dataset name or one of its aliases (see
            ``dataset_aliases``).

    Returns:
        list[list[int]]: One ``[R, G, B]`` triplet per class.

    Raises:
        ValueError: If ``dataset`` is not a recognized name/alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Invert dataset_aliases so every alias maps to its canonical name.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }

    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve the per-dataset helper by name instead of eval(): same lookup
    # semantics (module globals), but without executing a built string.
    return globals()[alias2name[dataset] + '_palette']()
modelsforCIML/mmseg/core/evaluation/eval_hooks.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import warnings
4
+
5
+ import torch.distributed as dist
6
+ from mmcv.runner import DistEvalHook as _DistEvalHook
7
+ from mmcv.runner import EvalHook as _EvalHook
8
+ from torch.nn.modules.batchnorm import _BatchNorm
9
+
10
+
11
class EvalHook(_EvalHook):
    """Single-GPU evaluation hook with efficient-test support.

    Args:
        by_epoch (bool): Evaluate per epoch when True, per iteration when
            False. Default: False.
        efficient_test (bool): Deprecated; kept only for backward
            compatibility (use ``pre_eval`` instead). Default: False.
        pre_eval (bool): Use progressive (pre-eval) mode when evaluating
            the model. Default: False.
    """

    # Metrics where a larger value indicates a better checkpoint.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        # Populated with the most recent prediction results after each run.
        self.latest_results = None
        if efficient_test:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        # Imported lazily to avoid a circular import at module load time.
        from mmseg.apis import single_gpu_test
        self.latest_results = single_gpu_test(
            runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
        runner.log_buffer.clear()
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, self.latest_results)
        if self.save_best:
            self._save_ckpt(runner, key_score)
59
+
60
+
61
class DistEvalHook(_DistEvalHook):
    """Distributed evaluation hook with efficient-test support.

    Args:
        by_epoch (bool): Evaluate per epoch when True, per iteration when
            False. Default: False.
        efficient_test (bool): Deprecated; kept only for backward
            compatibility (use ``pre_eval`` instead). Default: False.
        pre_eval (bool): Use progressive (pre-eval) mode when evaluating
            the model. Default: False.
    """

    # Metrics where a larger value indicates a better checkpoint.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        # Populated with the most recent prediction results after each run.
        self.latest_results = None
        if efficient_test:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        # PyTorch DDP does not synchronize BatchNorm buffers (running_mean /
        # running_var), which can make evaluation inconsistent across ranks.
        # Broadcast rank 0's buffers to every other rank first.
        if self.broadcast_bn_buffer:
            for _, module in runner.model.named_modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        tmpdir = self.tmpdir
        if tmpdir is None:
            # Fallback scratch dir inside the work dir for result collection.
            tmpdir = osp.join(runner.work_dir, '.eval_hook')

        # Imported lazily to avoid a circular import at module load time.
        from mmseg.apis import multi_gpu_test
        self.latest_results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
            pre_eval=self.pre_eval)
        runner.log_buffer.clear()

        # Only rank 0 aggregates metrics and decides on checkpointing.
        if runner.rank == 0:
            print('\n')
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, self.latest_results)

            if self.save_best:
                self._save_ckpt(runner, key_score)
modelsforCIML/mmseg/core/evaluation/metrics.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from collections import OrderedDict
3
+
4
+ import mmcv
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
def f_score(precision, recall, beta=1):
    """calculate the f-score value.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1 (the balanced F1 score).

    Returns:
        float | torch.Tensor: The f-score value (same type as the inputs).

    Note:
        When ``precision`` and ``recall`` are both zero the result is NaN
        for tensor inputs (0/0); plain-float inputs raise
        ``ZeroDivisionError``.
    """
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
    score = (1 + beta**2) * (precision * recall) / (
        (beta**2 * precision) + recall)
    return score
24
+
25
+
26
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (torch.Tensor | str): Prediction segmentation map
            (already a tensor) or predict result filename (``.npy``).
        label (ndarray | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Applied to the
            ground truth after loading. Default: dict(). (Never mutated;
            the shared default is safe.)
        reduce_zero_label (bool): Whether ignore zero label; maps label 0 to
            255 and shifts all other labels down by one. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """

    # Predictions passed as arrays are expected to already be tensors;
    # only filename inputs are loaded here.
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    if reduce_zero_label:
        # 0 becomes ignore (255); every other id shifts down by one.
        # The interim value 254 (originally 255) is folded back to 255.
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
    if label_map is not None:
        # Remap against a frozen copy so chained mappings (a->b, b->c)
        # don't cascade within a single call.
        label_copy = label.clone()
        for old_id, new_id in label_map.items():
            label[label_copy == old_id] = new_id

    # Drop ignored pixels before computing the per-class histograms.
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
90
+
91
+
92
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Calculate Total Intersection and Union over a set of images.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    # Four float64 accumulators: intersect, union, pred-area, label-area.
    totals = [
        torch.zeros((num_classes, ), dtype=torch.float64) for _ in range(4)
    ]
    for result, gt_seg_map in zip(results, gt_seg_maps):
        per_image = intersect_and_union(result, gt_seg_map, num_classes,
                                        ignore_index, label_map,
                                        reduce_zero_label)
        for total, area in zip(totals, per_image):
            total += area
    return tuple(totals)
134
+
135
+
136
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Thin wrapper around :func:`eval_metrics` with ``metrics=['mIoU']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]:
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
173
+
174
+
175
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Thin wrapper around :func:`eval_metrics` with ``metrics=['mDice']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
213
+
214
+
215
def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore).

    Thin wrapper around :func:`eval_metrics` with ``metrics=['mFscore']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
258
+
259
+
260
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics over a set of images.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU', 'mDice'
            and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Recall weight for 'mFscore'. Default: 1.

    Returns:
        dict[str, float | ndarray]: 'aAcc' plus the per-category arrays for
            each requested metric (see :func:`total_area_to_metrics`).
    """
    # Accumulate per-class areas over all images, then derive the metrics.
    (total_area_intersect, total_area_union, total_area_pred_label,
     total_area_label) = total_intersect_and_union(results, gt_seg_maps,
                                                   num_classes, ignore_index,
                                                   label_map,
                                                   reduce_zero_label)
    return total_area_to_metrics(total_area_intersect, total_area_union,
                                 total_area_pred_label, total_area_label,
                                 metrics, nan_to_num, beta)
295
+
296
+
297
def pre_eval_to_metrics(pre_eval_results,
                        metrics=['mIoU'],
                        nan_to_num=None,
                        beta=1):
    """Convert pre-eval results to metrics.

    Args:
        pre_eval_results (list[tuple[torch.Tensor]]): per image eval results
            for computing evaluation metric; each tuple is
            (intersect, union, pred_area, label_area).
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU', 'mDice'
            and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Recall weight for 'mFscore'. Default: 1.

    Returns:
        dict[str, float | ndarray]: 'aAcc' plus the per-category arrays for
            each requested metric (see :func:`total_area_to_metrics`).
    """
    # Transpose [(A_1, B_1, C_1, D_1), ...] into
    # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) and sum each column.
    columns = tuple(zip(*pre_eval_results))
    assert len(columns) == 4

    total_intersect, total_union, total_pred, total_label = (
        sum(column) for column in columns)

    return total_area_to_metrics(total_intersect, total_union, total_pred,
                                 total_label, metrics, nan_to_num, beta)
332
+
333
+
334
def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Derive evaluation metrics from accumulated per-class areas.

    Args:
        total_area_intersect (torch.Tensor): The intersection of prediction
            and ground truth histogram on all classes.
        total_area_union (torch.Tensor): The union of prediction and ground
            truth histogram on all classes.
        total_area_pred_label (torch.Tensor): The prediction histogram on
            all classes.
        total_area_label (torch.Tensor): The ground truth histogram on all
            classes.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU', 'mDice'
            and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Recall weight for 'mFscore'. Default: 1.

    Returns:
        dict[str, ndarray]: 'aAcc' (overall accuracy, 0-dim) plus per-class
            arrays of shape (num_classes, ) for each requested metric.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    if not set(metrics).issubset({'mIoU', 'mDice', 'mFscore'}):
        raise KeyError('metrics {} is not supported'.format(metrics))

    ret_metrics = OrderedDict(
        aAcc=total_area_intersect.sum() / total_area_label.sum())
    for metric in metrics:
        # Per-class accuracy equals per-class recall; shared by all metrics.
        per_class_acc = total_area_intersect / total_area_label
        if metric == 'mIoU':
            ret_metrics['IoU'] = total_area_intersect / total_area_union
            ret_metrics['Acc'] = per_class_acc
        elif metric == 'mDice':
            ret_metrics['Dice'] = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics['Acc'] = per_class_acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = per_class_acc
            ret_metrics['Fscore'] = torch.tensor(
                [f_score(p, r, beta) for p, r in zip(precision, recall)])
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    # Hand back numpy arrays, optionally replacing NaNs (e.g. absent classes).
    ret_metrics = {
        name: value.numpy()
        for name, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict(
            (name, np.nan_to_num(value, nan=nan_to_num))
            for name, value in ret_metrics.items())
    return ret_metrics
modelsforCIML/mmseg/core/hook/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .wandblogger_hook import MMSegWandbHook
3
+
4
+ __all__ = ['MMSegWandbHook']
modelsforCIML/mmseg/core/hook/wandblogger_hook.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+
4
+ import mmcv
5
+ import numpy as np
6
+ from mmcv.runner import HOOKS
7
+ from mmcv.runner.dist_utils import master_only
8
+ from mmcv.runner.hooks.checkpoint import CheckpointHook
9
+ from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
10
+
11
+ from mmseg.core import DistEvalHook, EvalHook
12
+
13
+
14
+ @HOOKS.register_module()
15
+ class MMSegWandbHook(WandbLoggerHook):
16
+ """Enhanced Wandb logger hook for MMSegmentation.
17
+
18
+ Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not
19
+ only automatically log all the metrics but also log the following extra
20
+ information - saves model checkpoints as W&B Artifact, and
21
+ logs model prediction as interactive W&B Tables.
22
+
23
+ - Metrics: The MMSegWandbHook will automatically log training
24
+ and validation metrics along with system metrics (CPU/GPU).
25
+
26
+ - Checkpointing: If `log_checkpoint` is True, the checkpoint saved at
27
+ every checkpoint interval will be saved as W&B Artifacts.
28
+ This depends on the : class:`mmcv.runner.CheckpointHook` whose priority
29
+ is higher than this hook. Please refer to
30
+ https://docs.wandb.ai/guides/artifacts/model-versioning
31
+ to learn more about model versioning with W&B Artifacts.
32
+
33
+ - Checkpoint Metadata: If evaluation results are available for a given
34
+ checkpoint artifact, it will have a metadata associated with it.
35
+ The metadata contains the evaluation metrics computed on validation
36
+ data with that checkpoint along with the current epoch. It depends
37
+ on `EvalHook` whose priority is more than MMSegWandbHook.
38
+
39
+ - Evaluation: At every evaluation interval, the `MMSegWandbHook` logs the
40
+ model prediction as interactive W&B Tables. The number of samples
41
+ logged is given by `num_eval_images`. Currently, the `MMSegWandbHook`
42
+ logs the predicted segmentation masks along with the ground truth at
43
+ every evaluation interval. This depends on the `EvalHook` whose
44
+ priority is more than `MMSegWandbHook`. Also note that the data is just
45
+ logged once and subsequent evaluation tables uses reference to the
46
+ logged data to save memory usage. Please refer to
47
+ https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
48
+
49
+ ```
50
+ Example:
51
+ log_config = dict(
52
+ ...
53
+ hooks=[
54
+ ...,
55
+ dict(type='MMSegWandbHook',
56
+ init_kwargs={
57
+ 'entity': "YOUR_ENTITY",
58
+ 'project': "YOUR_PROJECT_NAME"
59
+ },
60
+ interval=50,
61
+ log_checkpoint=True,
62
+ log_checkpoint_metadata=True,
63
+ num_eval_images=100,
64
+ bbox_score_thr=0.3)
65
+ ])
66
+ ```
67
+
68
+ Args:
69
+ init_kwargs (dict): A dict passed to wandb.init to initialize
70
+ a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
71
+ for possible key-value pairs.
72
+ interval (int): Logging interval (every k iterations).
73
+ Default 10.
74
+ log_checkpoint (bool): Save the checkpoint at every checkpoint interval
75
+ as W&B Artifacts. Use this for model versioning where each version
76
+ is a checkpoint.
77
+ Default: False
78
+ log_checkpoint_metadata (bool): Log the evaluation metrics computed
79
+ on the validation data with the checkpoint, along with current
80
+ epoch as a metadata to that checkpoint.
81
+ Default: True
82
+ num_eval_images (int): Number of validation images to be logged.
83
+ Default: 100
84
+ """
85
+
86
+ def __init__(self,
87
+ init_kwargs=None,
88
+ interval=50,
89
+ log_checkpoint=False,
90
+ log_checkpoint_metadata=False,
91
+ num_eval_images=100,
92
+ **kwargs):
93
+ super(MMSegWandbHook, self).__init__(init_kwargs, interval, **kwargs)
94
+
95
+ self.log_checkpoint = log_checkpoint
96
+ self.log_checkpoint_metadata = (
97
+ log_checkpoint and log_checkpoint_metadata)
98
+ self.num_eval_images = num_eval_images
99
+ self.log_evaluation = (num_eval_images > 0)
100
+ self.ckpt_hook: CheckpointHook = None
101
+ self.eval_hook: EvalHook = None
102
+ self.test_fn = None
103
+
104
+ @master_only
105
+ def before_run(self, runner):
106
+ super(MMSegWandbHook, self).before_run(runner)
107
+
108
+ # Check if EvalHook and CheckpointHook are available.
109
+ for hook in runner.hooks:
110
+ if isinstance(hook, CheckpointHook):
111
+ self.ckpt_hook = hook
112
+ if isinstance(hook, EvalHook):
113
+ from mmseg.apis import single_gpu_test
114
+ self.eval_hook = hook
115
+ self.test_fn = single_gpu_test
116
+ if isinstance(hook, DistEvalHook):
117
+ from mmseg.apis import multi_gpu_test
118
+ self.eval_hook = hook
119
+ self.test_fn = multi_gpu_test
120
+
121
+ # Check conditions to log checkpoint
122
+ if self.log_checkpoint:
123
+ if self.ckpt_hook is None:
124
+ self.log_checkpoint = False
125
+ self.log_checkpoint_metadata = False
126
+ runner.logger.warning(
127
+ 'To log checkpoint in MMSegWandbHook, `CheckpointHook` is'
128
+ 'required, please check hooks in the runner.')
129
+ else:
130
+ self.ckpt_interval = self.ckpt_hook.interval
131
+
132
+ # Check conditions to log evaluation
133
+ if self.log_evaluation or self.log_checkpoint_metadata:
134
+ if self.eval_hook is None:
135
+ self.log_evaluation = False
136
+ self.log_checkpoint_metadata = False
137
+ runner.logger.warning(
138
+ 'To log evaluation or checkpoint metadata in '
139
+ 'MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg '
140
+ 'is required, please check whether the validation '
141
+ 'is enabled.')
142
+ else:
143
+ self.eval_interval = self.eval_hook.interval
144
+ self.val_dataset = self.eval_hook.dataloader.dataset
145
+ # Determine the number of samples to be logged.
146
+ if self.num_eval_images > len(self.val_dataset):
147
+ self.num_eval_images = len(self.val_dataset)
148
+ runner.logger.warning(
149
+ f'The num_eval_images ({self.num_eval_images}) is '
150
+ 'greater than the total number of validation samples '
151
+ f'({len(self.val_dataset)}). The complete validation '
152
+ 'dataset will be logged.')
153
+
154
+ # Check conditions to log checkpoint metadata
155
+ if self.log_checkpoint_metadata:
156
+ assert self.ckpt_interval % self.eval_interval == 0, \
157
+ 'To log checkpoint metadata in MMSegWandbHook, the interval ' \
158
+ f'of checkpoint saving ({self.ckpt_interval}) should be ' \
159
+ 'divisible by the interval of evaluation ' \
160
+ f'({self.eval_interval}).'
161
+
162
+ # Initialize evaluation table
163
+ if self.log_evaluation:
164
+ # Initialize data table
165
+ self._init_data_table()
166
+ # Add data to the data table
167
+ self._add_ground_truth(runner)
168
+ # Log ground truth data
169
+ self._log_data_table()
170
+
171
    # for the reason of this double-layered structure, refer to
    # https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076
    def after_train_iter(self, runner):
        """Dispatch per-iteration logging.

        Works around the iter-based eval hook invoking all logger hooks'
        ``after_train_iter`` once more before evaluation.
        """
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call super method at first, it will clear the log_buffer
            return super(MMSegWandbHook, self).after_train_iter(runner)
        else:
            super(MMSegWandbHook, self).after_train_iter(runner)
            self._after_train_iter(runner)
183
+
184
+ @master_only
185
+ def _after_train_iter(self, runner):
186
+ if self.by_epoch:
187
+ return
188
+
189
+ # Save checkpoint and metadata
190
+ if (self.log_checkpoint
191
+ and self.every_n_iters(runner, self.ckpt_interval)
192
+ or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
193
+ if self.log_checkpoint_metadata and self.eval_hook:
194
+ metadata = {
195
+ 'iter': runner.iter + 1,
196
+ **self._get_eval_results()
197
+ }
198
+ else:
199
+ metadata = None
200
+ aliases = [f'iter_{runner.iter+1}', 'latest']
201
+ model_path = osp.join(self.ckpt_hook.out_dir,
202
+ f'iter_{runner.iter+1}.pth')
203
+ self._log_ckpt_as_artifact(model_path, aliases, metadata)
204
+
205
+ # Save prediction table
206
+ if self.log_evaluation and self.eval_hook._should_evaluate(runner):
207
+ # Currently the results of eval_hook is not reused by wandb, so
208
+ # wandb will run evaluation again internally. We will consider
209
+ # refactoring this function afterwards
210
+ results = self.test_fn(runner.model, self.eval_hook.dataloader)
211
+ # Initialize evaluation table
212
+ self._init_pred_table()
213
+ # Log predictions
214
+ self._log_predictions(results, runner)
215
+ # Log the table
216
+ self._log_eval_table(runner.iter + 1)
217
+
218
    @master_only
    def after_run(self, runner):
        """Finish the W&B run so pending artifacts and tables are flushed."""
        self.wandb.finish()
221
+
222
    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact
                (e.g. ``['iter_1000', 'latest']``).
            metadata (dict, optional): Metadata associated with this artifact,
                such as the evaluation metrics at that iteration.
        """
        # One artifact name per run; each call adds a new version of it.
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)
234
+
235
    def _get_eval_results(self):
        """Get model evaluation results.

        Re-evaluates the cached ``latest_results`` of the mmseg EvalHook on
        the validation dataset; logging is silenced.

        Returns:
            dict: Metric name to value, as produced by ``dataset.evaluate``.
        """
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results
241
+
242
    def _init_data_table(self):
        """Initialize the W&B Table holding validation images (one row per
        sampled image: name + wandb.Image with GT mask)."""
        columns = ['image_name', 'image']
        self.data_table = self.wandb.Table(columns=columns)
246
+
247
    def _init_pred_table(self):
        """Initialize the W&B Table for model evaluation (one row per
        sampled image: name, GT image reference, prediction overlay)."""
        columns = ['image_name', 'ground_truth', 'prediction']
        self.eval_table = self.wandb.Table(columns=columns)
251
+
252
    def _add_ground_truth(self, runner):
        """Populate the data table with a fixed random sample of validation
        images and their ground-truth masks.

        Disables evaluation logging (and returns early) if no image loader
        is found in the pipeline or a mask is not 2-D.
        """
        # Get image loading pipeline.
        from mmseg.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        if img_loader is None:
            self.log_evaluation = False
            runner.logger.warning(
                'LoadImageFromFile is required to add images '
                'to W&B Tables.')
            return

        # Select the images to be logged.
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that same validation set is logged each time.
        # NOTE(review): this reseeds NumPy's *global* RNG, which affects any
        # later np.random use in the process — confirm this is acceptable.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        classes = self.val_dataset.CLASSES
        # Map class ids to names once; reused for every mask overlay.
        self.class_id_to_label = {id: name for id, name in enumerate(classes)}
        self.class_set = self.wandb.Classes([{
            'id': id,
            'name': name
        } for id, name in self.class_id_to_label.items()])

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.img_infos[idx]
            image_name = img_info['filename']

            # Get image and convert from BGR to RGB.
            img_meta = img_loader(
                dict(img_info=img_info, img_prefix=self.val_dataset.img_dir))
            image = mmcv.bgr2rgb(img_meta['img'])

            # Get segmentation mask.
            seg_mask = self.val_dataset.get_gt_seg_map_by_idx(idx)
            # Dict of masks to be logged.
            wandb_masks = None
            # Only 2-D (H, W) class-index masks are supported by W&B.
            if seg_mask.ndim == 2:
                wandb_masks = {
                    'ground_truth': {
                        'mask_data': seg_mask,
                        'class_labels': self.class_id_to_label
                    }
                }

                # Log a row to the data table.
                self.data_table.add_data(
                    image_name,
                    self.wandb.Image(
                        image, masks=wandb_masks, classes=self.class_set))
            else:
                runner.logger.warning(
                    f'The segmentation mask is {seg_mask.ndim}D which '
                    'is not supported by W&B.')
                self.log_evaluation = False
                return
313
+
314
+ def _log_predictions(self, results, runner):
315
+ table_idxs = self.data_table_ref.get_index()
316
+ assert len(table_idxs) == len(self.eval_image_indexs)
317
+ assert len(results) == len(self.val_dataset)
318
+
319
+ for ndx, eval_image_index in enumerate(self.eval_image_indexs):
320
+ # Get the result
321
+ pred_mask = results[eval_image_index]
322
+
323
+ if pred_mask.ndim == 2:
324
+ wandb_masks = {
325
+ 'prediction': {
326
+ 'mask_data': pred_mask,
327
+ 'class_labels': self.class_id_to_label
328
+ }
329
+ }
330
+
331
+ # Log a row to the data table.
332
+ self.eval_table.add_data(
333
+ self.data_table_ref.data[ndx][0],
334
+ self.data_table_ref.data[ndx][1],
335
+ self.wandb.Image(
336
+ self.data_table_ref.data[ndx][1],
337
+ masks=wandb_masks,
338
+ classes=self.class_set))
339
+ else:
340
+ runner.logger.warning(
341
+ 'The predictio segmentation mask is '
342
+ f'{pred_mask.ndim}D which is not supported by W&B.')
343
+ self.log_evaluation = False
344
+ return
345
+
346
    def _log_data_table(self):
        """Log the W&B Tables for validation data as artifact and calls
        `use_artifact` on it so that the evaluation table can use the reference
        of already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        # use_artifact + wait() blocks until the upload finishes so that
        # get() below returns a resolvable reference table.
        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')
360
+
361
    def _log_eval_table(self, iter):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times creating new version. Use this
        to compare models at different intervals interactively.

        Args:
            iter (int): Current training iteration. NOTE(review): this
                parameter is currently unused in the body (and shadows the
                ``iter`` builtin) — kept for call-site compatibility.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        self.wandb.run.log_artifact(pred_artifact)
modelsforCIML/mmseg/core/optimizers/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .layer_decay_optimizer_constructor import (
3
+ LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
4
+
5
+ __all__ = [
6
+ 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
7
+ ]
modelsforCIML/mmseg/core/optimizers/layer_decay_optimizer_constructor.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import warnings
4
+
5
+ from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
6
+
7
+ from mmseg.utils import get_root_logger
8
+ from ..builder import OPTIMIZER_BUILDERS
9
+
10
+
11
def get_layer_id_for_convnext(var_name, max_layer_id):
    """Map a ConvNeXt parameter name to its layer id for ``layer_wise``
    decay_type.

    Args:
        var_name (str): The key of the model (state-dict parameter name).
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
            ``LearningRateDecayOptimizerConstructor``.
    """
    prefixes = ('backbone.', 'backbone2.')
    tokens = ('cls_token', 'mask_token', 'pos_embed')
    # Embedding tokens always belong to the first (slowest-decaying) layer.
    if any(var_name == p + t for p in prefixes for t in tokens):
        return 0

    parts = var_name.split('.')
    if var_name.startswith(tuple(p + 'downsample_layers' for p in prefixes)):
        # Each downsampling layer shares the id of the stage it feeds.
        return {0: 0, 1: 2, 2: 3, 3: max_layer_id}[int(parts[2])]

    if var_name.startswith(tuple(p + 'stages' for p in prefixes)):
        stage_id = int(parts[2])
        block_id = int(parts[3])
        if stage_id == 0:
            return 1
        if stage_id == 1:
            return 2
        if stage_id == 2:
            # The deep third stage is chunked: every 3 blocks share an id.
            return 3 + block_id // 3
        return max_layer_id

    # Anything outside the backbone (e.g. the decode head) gets the largest
    # id, i.e. the full base learning rate.
    return max_layer_id + 1
53
+
54
+
55
def get_stage_id_for_convnext(var_name, max_stage_id):
    """Map a ConvNeXt parameter name to its stage id for ``stage_wise``
    decay_type.

    Args:
        var_name (str): The key of the model (state-dict parameter name).
        max_stage_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
            ``LearningRateDecayOptimizerConstructor``.
    """
    for prefix in ('backbone.', 'backbone2.'):
        # Embedding tokens and all downsampling layers go into stage 0.
        if var_name in (prefix + 'cls_token', prefix + 'mask_token',
                        prefix + 'pos_embed'):
            return 0
        if var_name.startswith(prefix + 'downsample_layers'):
            return 0
        if var_name.startswith(prefix + 'stages'):
            # Stage k of the backbone maps to id k + 1.
            return int(var_name.split('.')[2]) + 1
    # Non-backbone parameters (e.g. the decode head).
    return max_stage_id - 1
79
+
80
+
81
def get_layer_id_for_vit(var_name, max_layer_id):
    """Map a ViT-style parameter name to its layer id for layer-wise decay.

    Args:
        var_name (str): The key of the model (state-dict parameter name).
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: Returns the layer id of the key.
    """
    for prefix in ('backbone.', 'backbone2.'):
        # Tokens and the patch embedding belong to the first layer.
        if var_name in (prefix + 'cls_token', prefix + 'mask_token',
                        prefix + 'pos_embed'):
            return 0
        if var_name.startswith(prefix + 'patch_embed'):
            return 0
        if var_name.startswith(prefix + 'layers'):
            # Transformer block k maps to id k + 1.
            return int(var_name.split('.')[2]) + 1
    # Non-backbone parameters (e.g. the decode head).
    return max_layer_id - 1
103
+
104
+
105
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for ConvNeXt,
    BEiT and MAE.
    """

    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """
        logger = get_root_logger()

        parameter_groups = {}
        logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
        # +2 accounts for the embedding (id 0) and the head (id num_layers+1).
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        logger.info('Build LearningRateDecayOptimizerConstructor '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # Biases, 1-D params (norm scales) and tokens get no weight decay.
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
            # NOTE(review): assumes `module` has a `.backbone` attribute;
            # verify for models without one.
            if 'layer_wise' in decay_type:
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_convnext(
                        name, self.paramwise_cfg.get('num_layers'))
                    logger.info(f'set param {name} as id {layer_id}')
                elif 'BEiT' in module.backbone.__class__.__name__ or \
                        'MAE' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_vit(name, num_layers)
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_stage_id_for_convnext(name, num_layers)
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
                # Deeper layers (larger id) decay less; the last layer gets
                # scale decay_rate**0 == 1, i.e. the full base LR.
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        # Only rank 0 prints the (potentially large) group summary.
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
189
+
190
+
191
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for BEiT,
    and it will be deprecated.
    Please use ``LearningRateDecayOptimizerConstructor`` instead.
    """

    def __init__(self, optimizer_cfg, paramwise_cfg):
        warnings.warn('DeprecationWarning: Original '
                      'LayerDecayOptimizerConstructor of BEiT '
                      'will be deprecated. Please use '
                      'LearningRateDecayOptimizerConstructor instead, '
                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
        # Force ViT-style layer-wise decay for backward compatibility.
        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
        warnings.warn('DeprecationWarning: Layer_decay_rate will '
                      'be deleted, please use decay_rate instead.')
        # NOTE(review): mutates the caller's paramwise_cfg in place and
        # raises KeyError if 'layer_decay_rate' is absent — confirm callers
        # always provide it.
        paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
        super(LayerDecayOptimizerConstructor,
              self).__init__(optimizer_cfg, paramwise_cfg)
modelsforCIML/mmseg/core/seg/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .builder import build_pixel_sampler
3
+ from .sampler import BasePixelSampler, OHEMPixelSampler
4
+
5
+ __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
modelsforCIML/mmseg/core/seg/builder.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmcv.utils import Registry, build_from_cfg
3
+
4
# Registry holding all pixel samplers (e.g. OHEMPixelSampler).
PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map.

    Args:
        cfg (dict): Config with a ``type`` key naming a registered sampler.
        **default_args: Defaults merged into ``cfg`` before instantiation.
    """
    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
modelsforCIML/mmseg/core/seg/sampler/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base_pixel_sampler import BasePixelSampler
3
+ from .ohem_pixel_sampler import OHEMPixelSampler
4
+
5
+ __all__ = ['BasePixelSampler', 'OHEMPixelSampler']
modelsforCIML/mmseg/core/seg/sampler/base_pixel_sampler.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
class BasePixelSampler(metaclass=ABCMeta):
    """Base class of pixel sampler.

    Subclasses must implement :meth:`sample`, which returns a per-pixel
    weight map used to weight the segmentation loss.
    """

    def __init__(self, **kwargs):
        # Extra kwargs are accepted (and ignored) so configs can pass
        # sampler-specific options through uniformly.
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
modelsforCIML/mmseg/core/seg/sampler/ohem_pixel_sampler.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from ..builder import PIXEL_SAMPLERS
7
+ from .base_pixel_sampler import BasePixelSampler
8
+
9
+
10
@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`. Supplies ``ignore_index`` and
            ``loss_decode``.
        thresh (float, optional): The threshold for hard example selection.
            Below which, are prediction with low confidence. If not
            specified, the hard examples will be pixels of top ``min_kept``
            loss. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super(OHEMPixelSampler, self).__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or with low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W)
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            # At least min_kept pixels are kept per image in the batch.
            batch_kept = self.min_kept * seg_label.size(0)
            valid_mask = seg_label != self.context.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())
            valid_seg_weight = seg_weight[valid_mask]
            if self.thresh is not None:
                # Confidence-based mining: select pixels whose predicted
                # probability for the GT class is below the threshold.
                seg_prob = F.softmax(seg_logit, dim=1)

                tmp_seg_label = seg_label.clone().unsqueeze(1)
                # Remap ignored pixels to class 0 only so gather() has valid
                # indices; they are excluded later through valid_mask.
                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
                sort_prob, sort_indices = seg_prob[valid_mask].sort()

                if sort_prob.numel() > 0:
                    # Probability of the batch_kept-th hardest pixel: raising
                    # the threshold to it guarantees >= batch_kept selections.
                    min_threshold = sort_prob[min(batch_kept,
                                                  sort_prob.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
            else:
                # Loss-based mining: keep the batch_kept highest-loss pixels,
                # summing over all configured decode losses.
                if not isinstance(self.context.loss_decode, nn.ModuleList):
                    losses_decode = [self.context.loss_decode]
                else:
                    losses_decode = self.context.loss_decode
                losses = 0.0
                for loss_module in losses_decode:
                    losses += loss_module(
                        seg_logit,
                        seg_label,
                        weight=None,
                        ignore_index=self.context.ignore_index,
                        reduction_override='none')

                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                _, sort_indices = losses[valid_mask].sort(descending=True)
                valid_seg_weight[sort_indices[:batch_kept]] = 1.

            seg_weight[valid_mask] = valid_seg_weight

            return seg_weight
modelsforCIML/mmseg/core/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .dist_util import check_dist_init, sync_random_seed
3
+ from .misc import add_prefix
4
+
5
+ __all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed']
modelsforCIML/mmseg/core/utils/dist_util.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numpy as np
3
+ import torch
4
+ import torch.distributed as dist
5
+ from mmcv.runner import get_dist_info
6
+
7
+
8
def check_dist_init():
    """Return True iff torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
10
+
11
+
12
def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed. All workers must call
    this function, otherwise it will deadlock. This method is generally used in
    `DistributedSampler`, because the seed should be identical across all
    processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.
    Returns:
        int: Seed to be used.
    """

    if seed is None:
        # Bound by 2**31 so the value fits an int32 tensor for broadcast.
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    # Single-process run: nothing to synchronize.
    if world_size == 1:
        return seed

    # Rank 0's seed wins; all other ranks receive it via broadcast.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
modelsforCIML/mmseg/core/utils/misc.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
def add_prefix(inputs, prefix):
    """Return a copy of ``inputs`` with every key prefixed.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to prepend to each key.

    Returns:
        dict: New dict whose keys are ``'<prefix>.<key>'``; values are
            shared with the input dict.
    """
    return {f'{prefix}.{key}': value for key, value in inputs.items()}
modelsforCIML/mmseg/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
3
+ build_head, build_loss, build_segmentor)
4
+ from .decode_heads import * # noqa: F401,F403
5
+ from .losses import *
6
+
7
+ __all__ = [
8
+ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
9
+ 'build_head', 'build_loss', 'build_segmentor'
10
+ ]
modelsforCIML/mmseg/models/builder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ from mmcv.cnn import MODELS as MMCV_MODELS
5
+ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
6
+ from mmcv.utils import Registry
7
+
8
# All mmseg model components live in one registry, scoped as a child of
# mmcv's MODELS so cross-library lookups resolve correctly.
MODELS = Registry('models', parent=MMCV_MODELS)
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)

# Aliases into the shared MODELS registry: the distinct names only document
# intent at registration/build call sites.
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
16
+
17
+
18
def build_backbone(cfg):
    """Build backbone from an mmcv-style config dict (``type`` key names a
    registered backbone)."""
    return BACKBONES.build(cfg)
21
+
22
+
23
def build_neck(cfg):
    """Build neck from an mmcv-style config dict."""
    return NECKS.build(cfg)
26
+
27
+
28
def build_head(cfg):
    """Build head from an mmcv-style config dict."""
    return HEADS.build(cfg)
31
+
32
+
33
def build_loss(cfg):
    """Build loss from an mmcv-style config dict."""
    return LOSSES.build(cfg)
36
+
37
+
38
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build segmentor.

    Args:
        cfg (dict): Segmentor config; may itself carry ``train_cfg`` /
            ``test_cfg`` keys.
        train_cfg (dict, optional): Deprecated — specify inside ``cfg``.
        test_cfg (dict, optional): Deprecated — specify inside ``cfg``.
    """
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # Reject specifying the same cfg both outside and inside the model dict.
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return SEGMENTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
modelsforCIML/mmseg/models/decode_heads/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .fcn_head import FCNHead
3
+ from .uper_lab import UPerLab
4
+ from .uper_head import UPerHead
5
+ from .sep_aspp_head import DepthwiseSeparableASPPHead
6
+
7
+ __all__ = [
8
+ 'FCNHead', 'UPerLab', 'UPerHead', 'DepthwiseSeparableASPPHead'
9
+ ]
modelsforCIML/mmseg/models/decode_heads/aspp_head.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule
5
+
6
+ from mmseg.ops import resize
7
+ from ..builder import HEADS
8
+ from .decode_head import BaseDecodeHead
9
+
10
+
11
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Holds one ConvModule per dilation rate; forward applies them all to the
    same input in parallel.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for dilation in dilations:
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    # Dilation 1 uses a plain 1x1 conv; others use dilated
                    # 3x3 convs padded to preserve spatial size.
                    1 if dilation == 1 else 3,
                    dilation=dilation,
                    padding=0 if dilation == 1 else dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Apply every branch to ``x`` and return the list of outputs."""
        aspp_outs = []
        for aspp_module in self:
            aspp_outs.append(aspp_module(x))

        return aspp_outs
51
+
52
+
53
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        # Global-context branch: global average pool then 1x1 projection.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses the pooled branch plus one output per dilation, hence the
        # (len(dilations) + 1) * channels input width.
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        # Upsample the 1x1 pooled context back to the feature resolution so
        # it can be concatenated with the dilated branches.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)
        return feats

    def forward(self, inputs):
        """Forward function: features then per-pixel classification."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
modelsforCIML/mmseg/models/decode_heads/decode_head.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+ from abc import ABCMeta, abstractmethod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmcv.runner import BaseModule, auto_fp16, force_fp32
8
+
9
+ from mmseg.core import build_pixel_sampler
10
+ from mmseg.ops import resize
11
+ from ..builder import build_loss
12
+ from ..losses import accuracy
13
+
14
+
15
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for decode heads.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg. Default: None
            (falls back to ``num_classes``).
        threshold (float): Threshold for binary segmentation in the case of
            ``out_channels == 1``. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is a property of the corresponding loss function
            which is shown in the training log. If you want a loss item to be
            included in the backward graph, `loss_` must be the prefix of its
            name. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super(BaseDecodeHead, self).__init__(init_cfg)
        # Validates and sets self.in_channels / self.input_transform.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert seg_logits into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults '
                          'to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
                but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # Final per-pixel classifier.
        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Channels are concatenated, so they sum up.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label, addstr=''):
        """Compute segmentation loss.

        Args:
            seg_logit (Tensor): Raw logits, resized to the label's spatial
                size before loss computation.
            seg_label (Tensor): Ground-truth segmentation map of shape
                (N, 1, H, W); the channel dim is squeezed before the loss.
            addstr (str): Suffix appended to every loss/accuracy key, used to
                distinguish auxiliary heads (e.g. '_uper'). Default: ''.

        Returns:
            dict[str, Tensor]: Loss items keyed by ``loss_name + addstr``
                plus ``'acc_seg' + addstr``.
        """
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            # BUGFIX: the membership test must use the actual dict key
            # (loss_name + addstr); checking the bare loss_name meant losses
            # sharing a loss_name overwrote instead of accumulating whenever
            # addstr was non-empty.
            loss_key = loss_decode.loss_name + addstr
            if loss_key not in loss:
                loss[loss_key] = loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_key] += loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)

        loss['acc_seg' + addstr] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss
modelsforCIML/mmseg/models/decode_heads/fcn_head.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule
5
+
6
+ from ..builder import HEADS
7
+ from .decode_head import BaseDecodeHead
8
+
9
+
10
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            # With no convs the head must pass features through unchanged.
            assert self.in_channels == self.channels

        pad = (kernel_size // 2) * dilation
        layers = [
            ConvModule(
                self.in_channels if idx == 0 else self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=pad,
                dilation=dilation,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for idx in range(num_convs)
        ]
        self.convs = nn.Sequential(*layers) if layers else nn.Identity()
        if self.concat_input:
            # Fuses the raw input with the conv stack output.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Compute the decoder feature map fed to ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Feature map of shape (batch_size, self.channels, H, W).
        """
        x = self._transform_inputs(inputs)
        feats = self.convs(x)
        if self.concat_input:
            feats = self.conv_cat(torch.cat([x, feats], dim=1))
        return feats

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))
modelsforCIML/mmseg/models/decode_heads/psp_head.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule
5
+
6
+ from mmseg.ops import resize
7
+ from ..builder import HEADS
8
+ from .decode_head import BaseDecodeHead
9
+
10
+
11
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # One (adaptive-pool -> 1x1 conv) branch per pyramid scale.
        for scale in pool_scales:
            branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs))
            self.append(branch)

    def forward(self, x):
        """Run every pyramid branch and upsample results back to x's size."""
        target_size = x.size()[2:]
        outs = []
        for branch in self:
            outs.append(
                resize(
                    branch(x),
                    size=target_size,
                    mode='bilinear',
                    align_corners=self.align_corners))
        return outs
60
+
61
+
62
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # Fuses the raw input concatenated with all pyramid branch outputs.
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Compute the decoder feature map fed to ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Feature map of shape (batch_size, self.channels, H, W).
        """
        x = self._transform_inputs(inputs)
        # Keep the untouched input as the first entry, then append the
        # pyramid-pooled branches.
        pooled = [x]
        pooled.extend(self.psp_modules(x))
        return self.bottleneck(torch.cat(pooled, dim=1))

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))
modelsforCIML/mmseg/models/decode_heads/sep_aspp_head.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
5
+
6
+ from mmseg.ops import resize
7
+ from ..builder import HEADS
8
+ from .aspp_head import ASPPHead, ASPPModule
9
+
10
+
11
class DepthwiseSeparableASPPModule(ASPPModule):
    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
    conv."""

    def __init__(self, **kwargs):
        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
        # Replace every dilated (rate > 1) branch built by the parent class
        # with its depthwise separable counterpart; the 1x1 branch is kept.
        for idx, rate in enumerate(self.dilations):
            if rate <= 1:
                continue
            self[idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
27
+
28
+
29
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If is 0,
            the no decoder will be used.
        c1_channels (int): The intermediate channels of c1 decoder.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            # Projects the low-level (c1) feature before fusion.
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            self.c1_bottleneck = None
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + c1_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs, trans=True):
        """Forward function.

        Args:
            inputs (list[Tensor]): When ``trans`` is True, multi-level
                backbone features handled by ``self._transform_inputs``;
                otherwise a pair ``[c1_feature, aspp_input]`` as passed by
                ``UPerLab``. In both cases ``inputs[0]`` feeds the c1 branch.
            trans (bool): Whether to apply ``_transform_inputs`` to select
                the ASPP input. Default: True.
        """
        # BUGFIX: the original unconditionally overwrote the transformed
        # input with ``inputs[1]``, discarding ``_transform_inputs`` and
        # breaking the trans=True (standalone DeepLabV3+) path. The raw
        # ``inputs[1]`` selection is only meant for callers (UPerLab) that
        # pass a pre-built [c1_feature, aspp_input] pair with trans=False.
        if trans:
            x = self._transform_inputs(inputs)
        else:
            x = inputs[1]
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
modelsforCIML/mmseg/models/decode_heads/uper_head.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule
5
+
6
+ from mmseg.ops import resize
7
+ from ..builder import HEADS
8
+ from .decode_head import BaseDecodeHead
9
+ from .psp_head import PPM
10
+
11
+
12
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # PSP Module, applied on the deepest backbone level only.
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # FPN Module: lateral 1x1 convs plus 3x3 smoothing convs for every
        # level except the top one (which is handled by the PSP module).
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False))

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        top = inputs[-1]
        pooled = [top]
        pooled.extend(self.psp_modules(top))
        return self.bottleneck(torch.cat(pooled, dim=1))

    def _forward_feature(self, inputs):
        """Build the multi-level FPN feature maps.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            list[Tensor]: FPN outputs, all resized to the spatial size of
                ``fpn_outs[1]``.
        """
        inputs = self._transform_inputs(inputs)

        # Lateral projections of all levels, PSP output on top.
        laterals = [
            conv(inputs[idx]) for idx, conv in enumerate(self.lateral_convs)
        ]
        laterals.append(self.psp_forward(inputs))

        # Top-down pathway: add each level to the (upsampled) level above.
        num_levels = len(laterals)
        for idx in range(num_levels - 1, 0, -1):
            target_shape = laterals[idx - 1].shape[2:]
            laterals[idx - 1] = laterals[idx - 1] + resize(
                laterals[idx],
                size=target_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # Smooth every merged level; the PSP output is used as-is.
        fpn_outs = [
            self.fpn_convs[idx](laterals[idx])
            for idx in range(num_levels - 1)
        ]
        fpn_outs.append(laterals[-1])

        # NOTE(review): all levels are resized to fpn_outs[1]'s size (not
        # fpn_outs[0]'s as in upstream UPerHead) and the list is returned
        # without cls_seg — apparently intentional for UPerLab-style
        # consumers; confirm against callers.
        for idx in range(num_levels - 1, -1, -1):
            fpn_outs[idx] = resize(
                fpn_outs[idx],
                size=fpn_outs[1].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        return fpn_outs

    def forward(self, inputs):
        """Forward function: returns the list of FPN feature maps."""
        return self._forward_feature(inputs)
modelsforCIML/mmseg/models/decode_heads/uper_lab.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.cnn import ConvModule
5
+ from .sep_aspp_head import DepthwiseSeparableASPPHead
6
+ from mmseg.ops import resize
7
+ from ..builder import HEADS
8
+ from .decode_head import BaseDecodeHead
9
+ from .psp_head import PPM
10
+
11
+
12
@HEADS.register_module()
class UPerLab(BaseDecodeHead):
    """UPerNet-style FPN whose fused output feeds a DeepLabV3+ head.

    During training it additionally supervises the finest FPN level via
    ``cls_seg`` (auxiliary loss with the ``'_uper'`` suffix).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerLab, self).__init__(
            input_transform='multiple_select', **kwargs)
        # Final DeepLabV3+ head consuming [c1_feature, fused FPN features].
        self.deeplab = DepthwiseSeparableASPPHead(in_channels=2048, in_index=3, channels=512, dilations=(1, 12, 24, 36), c1_in_channels=256, c1_channels=48, dropout_ratio=0.1, num_classes=2, norm_cfg=dict(type='SyncBN', requires_grad=True), align_corners=False)
        # Projects the finest FPN level (512ch) to the c1 input (256ch).
        self.convert = nn.Conv2d(512, 256, 1, 1, 0)
        # PSP Module on the deepest backbone level.
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # FPN Module: lateral + smoothing convs for all but the top level.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False))

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        top = inputs[-1]
        pooled = [top]
        pooled.extend(self.psp_modules(top))
        return self.bottleneck(torch.cat(pooled, dim=1))

    def forward(self, inputs):
        """Forward function.

        Returns:
            Training: tuple of (deeplab logits, auxiliary logits).
            Inference: deeplab logits only.
        """
        inputs = self._transform_inputs(inputs)

        # Lateral projections, PSP output on top.
        laterals = [
            conv(inputs[idx]) for idx, conv in enumerate(self.lateral_convs)
        ]
        laterals.append(self.psp_forward(inputs))

        # Top-down pathway.
        num_levels = len(laterals)
        for idx in range(num_levels - 1, 0, -1):
            target_shape = laterals[idx - 1].shape[2:]
            laterals[idx - 1] = laterals[idx - 1] + resize(
                laterals[idx],
                size=target_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        fpn_outs = [
            self.fpn_convs[idx](laterals[idx])
            for idx in range(num_levels - 1)
        ]
        fpn_outs.append(laterals[-1])

        # Auxiliary prediction and c1 projection are taken from the finest
        # level BEFORE any resizing.
        if self.training:
            cls_aux = self.cls_seg(fpn_outs[0])
        feat0 = self.convert(fpn_outs[0])

        # Bring every level to fpn_outs[1]'s spatial size, then concatenate.
        for idx in range(num_levels - 1, 0, -1):
            fpn_outs[idx] = resize(
                fpn_outs[idx],
                size=fpn_outs[1].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs[0] = resize(
            fpn_outs[0],
            size=fpn_outs[1].shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        fused = torch.cat(fpn_outs, dim=1)

        if self.training:
            return (self.deeplab([feat0, fused], trans=False), cls_aux)
        return self.deeplab([feat0, fused], trans=False)

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Compute the main DeepLab loss plus the '_uper' auxiliary loss."""
        seg_logits, aux_logits = self(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        losses.update(self.losses(aux_logits, gt_semantic_seg, addstr='_uper'))
        return losses
modelsforCIML/mmseg/models/losses/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .accuracy import Accuracy, accuracy
3
+ from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
4
+ cross_entropy, mask_cross_entropy)
5
+ from .dice_loss import DiceLoss
6
+ # from .focal_loss import FocalLoss
7
+ from .lovasz_loss import LovaszLoss
8
+ from .tversky_loss import TverskyLoss
9
+ from .utils import reduce_loss, weight_reduce_loss, weighted_loss
10
+
11
+ __all__ = [
12
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
13
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
14
+ 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
15
+ 'TverskyLoss'
16
+ ]
modelsforCIML/mmseg/models/losses/accuracy.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
def accuracy(pred, target, topk=1, thresh=None, ignore_index=-100):
    """Calculate accuracy according to the prediction and target.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
        target (torch.Tensor): The target of each prediction, shape (N, ...)
        topk (int | tuple[int], optional): If the predictions in ``topk``
            matches the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thresh (float, optional): If not None, predictions with scores under
            this threshold are considered incorrect. Default to None.
        ignore_index (int | None): The label index to be ignored.
            Default: -100.

    Returns:
        float | tuple[float]: A single accuracy value if ``topk`` is an int,
            otherwise a list with one accuracy per ``topk`` entry.
    """
    assert isinstance(topk, (int, tuple))
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )

    maxk = max(topk)
    if pred.size(0) == 0:
        # Empty batch: report zero accuracy for every requested k.
        zeros = [pred.new_tensor(0.) for _ in topk]
        return zeros[0] if return_single else zeros
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
    pred_value, pred_label = pred.topk(maxk, dim=1)
    # (N, maxk, ...) -> (maxk, N, ...) so that correct[:k] keeps the top-k
    # ranks for every sample.
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct.
        correct = correct & (pred_value > thresh).t()
    if ignore_index is not None:
        correct = correct[:, target != ignore_index]

    # eps avoids ZeroDivisionError when all pixels of an image are ignored.
    eps = torch.finfo(torch.float32).eps
    if ignore_index is not None:
        total_num = target[target != ignore_index].numel() + eps
    else:
        total_num = target.numel() + eps

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
        res.append(correct_k.mul_(100.0 / total_num))
    return res[0] if return_single else res
62
+
63
+
64
class Accuracy(nn.Module):
    """Accuracy calculation module."""

    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
        """Module to calculate the accuracy.

        Args:
            topk (tuple, optional): The criterion used to calculate the
                accuracy. Defaults to (1,).
            thresh (float, optional): If not None, predictions with scores
                under this threshold are considered incorrect. Default to
                None.
            ignore_index (int | None): The label index to be ignored.
                Default: None.
        """
        super().__init__()
        self.topk = topk
        self.thresh = thresh
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Forward function to calculate accuracy.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            tuple[float]: The accuracies under different topk criterions.
        """
        # Delegate to the functional form with this module's settings.
        return accuracy(
            pred, target, self.topk, self.thresh, self.ignore_index)
modelsforCIML/mmseg/models/losses/cross_entropy_loss.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ..builder import LOSSES
9
+ from .utils import get_class_weight, weight_reduce_loss
10
+
11
+
12
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Wrapper around :func:`F.cross_entropy` with weighted reduction.

    Args:
        pred (torch.Tensor): The prediction logits.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): Per-class rescaling weight
            (size C). Default: None.
        reduction (str, optional): 'none', 'mean' or 'sum'.
            Default: 'mean'.
        avg_factor (int, optional): Explicit divisor used to average the
            loss. Default: None.
        ignore_index (int): Target value excluded from the loss and, when
            ``avg_non_ignore`` is True with 'mean' reduction, from the
            averaging denominator. Default: -100.
        avg_non_ignore (bool): Average only over non-ignored targets,
            matching PyTorch's own mean behaviour. Default: False.
        `New in version 0.23.0.`

    Returns:
        torch.Tensor: The computed loss.
    """
    # Compute per-element losses first; reduction is deferred so that
    # sample weights and ``avg_non_ignore`` can be honoured.
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # Mirror PyTorch's official cross_entropy, which averages over
    # non-ignored elements only (see torch/nn/functional.py).
    if avg_non_ignore and reduction == 'mean' and avg_factor is None:
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
64
+
65
+
66
+ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
67
+ """Expand onehot labels to match the size of prediction."""
68
+ bin_labels = labels.new_zeros(target_shape)
69
+ valid_mask = (labels >= 0) & (labels != ignore_index)
70
+ inds = torch.nonzero(valid_mask, as_tuple=True)
71
+
72
+ if inds[0].numel() > 0:
73
+ if labels.dim() == 3:
74
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
75
+ else:
76
+ bin_labels[inds[0], labels[valid_mask]] = 1
77
+
78
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
79
+
80
+ if label_weights is None:
81
+ bin_label_weights = valid_mask
82
+ else:
83
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
84
+ bin_label_weights = bin_label_weights * valid_mask
85
+
86
+ return bin_labels, bin_label_weights, valid_mask
87
+
88
+
89
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.
            `New in version 0.23.0.`

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        # As the ignore_index often set as 255, so the
        # binary class label check should mask out
        # ignore_index
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        # After the squeeze, pred and label share the same rank and the
        # code falls through to the valid-mask branch below.
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels`
        # has been treated for valid (non-ignore) pixels
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask
    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    # Per-element loss; ignored pixels are zeroed through `weight` rather
    # than excluded here, since BCE-with-logits has no ignore_index.
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss
155
+
156
+
157
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Binary cross-entropy on the mask channel selected by ``label``.

    Args:
        pred (torch.Tensor): Class-wise mask predictions with the class
            dimension at index 1 (e.g. shape (N, C, ...)).
        target (torch.Tensor): Ground-truth mask for the selected class.
        label (torch.Tensor): Class index of each ROI; used to pick the
            mask channel when the prediction is not class-agnostic.
        reduction (str, optional): Only 'mean' is currently supported.
        avg_factor (int, optional): Reserved; must be None.
        class_weight (list[float], optional): Passed through as ``weight``
            to ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder for API consistency with the
            other losses; must be None.

    Returns:
        torch.Tensor: 1-element tensor holding the mean loss.
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size(0)
    roi_inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # For every ROI, keep only the mask predicted for its own class.
    pred_slice = pred[roi_inds, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')
    return loss[None]
194
+
195
+
196
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss supporting softmax, sigmoid (binary) and mask
    variants.

    Args:
        use_sigmoid (bool, optional): Use sigmoid (binary CE) instead of
            softmax. Defaults to False.
        use_mask (bool, optional): Use the mask cross-entropy variant.
            Defaults to False. Mutually exclusive with ``use_sigmoid``.
        reduction (str, optional): 'none', 'mean' or 'sum'.
            Defaults to 'mean'.
        class_weight (list[float] | str, optional): Per-class weight, or a
            file path to read it from. Defaults to None.
        loss_weight (float, optional): Scale applied to the loss.
            Defaults to 1.0.
        loss_name (str, optional): Name of the loss item; the ``loss_``
            prefix is required for the item to enter the backward graph.
            Defaults to 'loss_ce'.
        avg_non_ignore (bool): Average only over non-ignored targets.
            Default: False. `New in version 0.23.0.`
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False):
        super(CrossEntropyLoss, self).__init__()
        # The two specialised variants are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if self.reduction == 'mean' and not self.avg_non_ignore:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Pick the criterion once at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name

    def extra_repr(self):
        """Show the ``avg_non_ignore`` flag in the module repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Compute the (scaled) cross-entropy loss."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        class_weight = (
            cls_score.new_tensor(self.class_weight)
            if self.class_weight is not None else None)
        # Note: for BCE loss, label < 0 is invalid.
        return self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs)

    @property
    def loss_name(self):
        """Loss Name.

        Returns the name of this loss item, used to combine different loss
        items by simple summation. The ``loss_`` prefix is required for the
        item to be included in the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
modelsforCIML/mmseg/models/losses/dice_loss.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
3
+ segmentron/solver/loss.py (Apache-2.0 License)"""
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ..builder import LOSSES
9
+ from .utils import get_class_weight, weighted_loss
10
+
11
+
12
@weighted_loss
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255):
    """Multi-class dice loss: mean of per-class binary dice losses.

    ``ignore_index`` names a class channel to skip; the average is still
    taken over the full ``num_classes`` (matching the original behaviour).
    """
    assert pred.shape[0] == target.shape[0]
    num_classes = pred.shape[1]
    total_loss = 0
    for cls in range(num_classes):
        if cls == ignore_index:
            continue
        cls_loss = binary_dice_loss(
            pred[:, cls],
            target[..., cls],
            valid_mask=valid_mask,
            smooth=smooth,
            exponent=exponent)
        if class_weight is not None:
            cls_loss = cls_loss * class_weight[cls]
        total_loss = total_loss + cls_loss
    return total_loss / num_classes
35
+
36
+
37
@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    """Per-sample dice loss for a single (binary) class."""
    assert pred.shape[0] == target.shape[0]
    flat_pred = pred.reshape(pred.shape[0], -1)
    flat_target = target.reshape(target.shape[0], -1)
    flat_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    # Intersection counts only valid pixels; the denominator follows the
    # generalized dice formulation with an ``exponent`` power.
    num = 2 * torch.sum(flat_pred * flat_target * flat_mask, dim=1) + smooth
    den = torch.sum(
        flat_pred.pow(exponent) + flat_target.pow(exponent), dim=1) + smooth

    return 1 - num / den
48
+
49
+
50
@LOSSES.register_module()
class DiceLoss(nn.Module):
    """Dice loss.

    Proposed in `V-Net: Fully Convolutional Neural Networks for Volumetric
    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth (float): Smoothing term added to numerator and denominator
            to avoid NaN. Default: 1.
        exponent (float): Power used in the denominator
            \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): 'none', 'mean' or 'sum'; only effective
            when per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Per-class weight, or a
            file path to read it from. Defaults to None.
        loss_weight (float, optional): Scale applied to the loss.
            Default: 1.0.
        ignore_index (int | None): Label index to be ignored. Default: 255.
        loss_name (str, optional): Name of the loss item; the ``loss_``
            prefix is required for the item to enter the backward graph.
            Defaults to 'loss_dice'.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_dice',
                 **kwargs):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the (scaled) dice loss."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        class_weight = (
            pred.new_tensor(self.class_weight)
            if self.class_weight is not None else None)

        probs = F.softmax(pred, dim=1)
        num_classes = probs.shape[1]
        # Clamp so F.one_hot never sees out-of-range indices (e.g. the
        # ignore label); those positions are zeroed via valid_mask anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        return self.loss_weight * dice_loss(
            probs,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)

    @property
    def loss_name(self):
        """Loss Name.

        Returns the name of this loss item, used to combine different loss
        items by simple summation. The ``loss_`` prefix is required for the
        item to be included in the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
modelsforCIML/mmseg/models/losses/focal_loss.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Modified from https://github.com/open-mmlab/mmdetection
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
7
+
8
+ from ..builder import LOSSES
9
+ from .utils import weight_reduce_loss
10
+
11
+
12
# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """Pure-PyTorch `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): Prediction logits with shape (N, C), where C
            is the number of classes.
        target (torch.Tensor): One-hot learning label, shape (N, C).
        one_hot_target (None): Placeholder; must be None here.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): Exponent of the modulating factor.
            Defaults to 2.0.
        alpha (float | list[float], optional): Foreground/background
            balancing term. Defaults to 0.5.
        class_weight (list[float], optional): Per-class weight.
            Defaults to None.
        valid_mask (torch.Tensor, optional): 1 marks valid samples and 0
            marks ignored samples. Default: None.
        reduction (str, optional): Reduction to a scalar.
            Defaults to 'mean'.
        avg_factor (int, optional): Divisor used to average the loss.
            Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    prob = pred.sigmoid()
    target = target.type_as(pred)
    # (1 - p_t): probability mass assigned to the wrong side.
    one_minus_pt = (1 - prob) * target + prob * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)

    loss = focal_weight * F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')
    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # Sample weights of shape (N,): make them broadcast over
            # the class axis.
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    return weight_reduce_loss(loss, final_weight, reduction, avg_factor)
69
+
70
+
71
def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.
    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction. It's shape
            should be (N, )
        one_hot_target (torch.Tensor): The learning label with shape (N, C)
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
        # a list is given, we set the input alpha as 0.5. This means setting
        # equal weight for foreground class and background class. By
        # multiplying the loss by 2, the effect of setting alpha as 0.5 is
        # undone. The alpha of type list is used to regulate the loss in the
        # post-processing process.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        # Per-class alpha applied afterwards through the weight tensor.
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
134
+
135
+
136
@LOSSES.register_module()
class FocalLoss(nn.Module):
    """Sigmoid focal loss module; see ``__init__`` for the full contract."""

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
        Args:
            use_sigmoid (bool, optional): Whether to the prediction is
                used for sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background, each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255
        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
               (pred.size(0) == target.size(0) and
                pred.shape[2:] == target.shape[1:]), \
               "The shape of pred doesn't match the shape of target"

        original_shape = pred.shape

        # Flatten all spatial dimensions so the loss works on (N, C).
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...]
            # transform it's shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...]
            # transform it's shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot()
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                # CUDA path: mmcv's fused op expects index targets plus a
                # separate one-hot tensor for list-type alpha handling.
                if target.dim() == 1:
                    one_hot_target = F.one_hot(target, num_classes=num_classes)
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                # CPU fallback: pure-PyTorch implementation on one-hot
                # targets; one_hot_target stays None by contract.
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes)
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # Restore the caller's original layout.
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
modelsforCIML/mmseg/models/losses/lovasz_loss.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
3
+ ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
4
+ Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
5
+
6
+ import mmcv
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ..builder import LOSSES
12
+ from .utils import get_class_weight, weight_reduce_loss
13
+
14
+
15
+ def lovasz_grad(gt_sorted):
16
+ """Computes gradient of the Lovasz extension w.r.t sorted errors.
17
+
18
+ See Alg. 1 in paper.
19
+ """
20
+ p = len(gt_sorted)
21
+ gts = gt_sorted.sum()
22
+ intersection = gts - gt_sorted.float().cumsum(0)
23
+ union = gts + (1 - gt_sorted).float().cumsum(0)
24
+ jaccard = 1. - intersection / union
25
+ if p > 1: # cover 1-pixel case
26
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
27
+ return jaccard
28
+
29
+
30
def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flatten binary predictions and labels, dropping ignored positions.

    Both tensors are viewed as 1-D; when ``ignore_index`` is given,
    positions whose label equals it are filtered out of both outputs.

    Args:
        logits (torch.Tensor): predictions of any shape.
        labels (torch.Tensor): labels of the same number of elements.
        ignore_index (int | None): label value to drop. Default: None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: flattened (logits, labels).
    """
    flat_logits = logits.view(-1)
    flat_labels = labels.view(-1)
    if ignore_index is None:
        return flat_logits, flat_labels
    keep = flat_labels != ignore_index
    return flat_logits[keep], flat_labels[keep]
41
+
42
+
43
def flatten_probs(probs, labels, ignore_index=None):
    """Flatten predictions to [P, C] and labels to [P] for the batch.

    Args:
        probs (torch.Tensor): [B, C, H, W] probabilities, or [B, H, W] for
            a sigmoid output (a singleton class dim is inserted).
        labels (torch.Tensor): [B, H, W] ground truth labels.
        ignore_index (int | None): label value to drop. Default: None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ([P_valid, C] probs, [P_valid]
        labels) with ignored pixels removed.
    """
    if probs.dim() == 3:
        # assumes output of a sigmoid layer: add a singleton class dim
        B, H, W = probs.size()
        probs = probs.view(B, 1, H, W)
    B, C, H, W = probs.size()
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels
    valid = (labels != ignore_index)
    # Boolean-mask indexing keeps a [P_valid, C] shape even when exactly
    # one pixel is valid; the previous ``probs[valid.nonzero().squeeze()]``
    # collapsed that case to a 1-D tensor and broke downstream per-class
    # column indexing.
    vprobs = probs[valid]
    vlabels = labels[valid]
    return vprobs, vlabels
58
+
59
+
60
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on flattened predictions.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss.
    """
    if len(labels) == 0:
        # Only void pixels: return a zero that keeps the graph connected.
        return logits.sum() * 0.
    # Map {0, 1} labels to {-1, +1} signs and form hinge-style errors.
    signs = labels.float() * 2. - 1.
    hinge_errors = 1. - logits * signs
    errors_sorted, order = torch.sort(hinge_errors, dim=0, descending=True)
    gt_sorted = labels[order.data]
    # Dot the thresholded errors with the Lovasz extension gradient.
    return torch.dot(F.relu(errors_sorted), lovasz_grad(gt_sorted))
82
+
83
+
84
def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, to be consistent
            with other losses. Default: None.
        per_image (bool, optional): If True, compute the loss per image
            instead of per batch. Default: False.
        class_weight (list[float], optional): Placeholder, to be consistent
            with other losses. Default: None.
        reduction (str, optional): "none", "mean" or "sum". Only used when
            ``per_image`` is True. Default: 'mean'.
        avg_factor (int, optional): Average factor used to average the loss.
            Only used when ``per_image`` is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # A single Lovasz problem over the flattened batch.
        return lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))
    per_image_losses = [
        lovasz_hinge_flat(*flatten_binary_logits(
            img_logits.unsqueeze(0), img_labels.unsqueeze(0), ignore_index))
        for img_logits, img_labels in zip(logits, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
127
+
128
+
129
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss on flattened predictions.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0
        return probs * 0.
    C = probs.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground mask for class c
        if (classes == 'present' and fg.sum() == 0):
            # skip classes absent from the ground truth
            continue
        if C == 1:
            # single-channel (sigmoid) output: column 0 is the prediction
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        # Sort errors descending, then evaluate the Lovasz extension
        # gradient on the correspondingly permuted ground truth (Alg. 1).
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss *= class_weight[c]
        losses.append(loss)
    # NOTE(review): if 'present' filters out every class, ``losses`` is empty
    # and torch.stack raises -- presumably callers guarantee at least one
    # labelled class; confirm against call sites.
    return torch.stack(losses).mean()
170
+
171
+
172
def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If True, compute the loss per image
            instead of per batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): "none", "mean" or "sum". Only used when
            ``per_image`` is True. Default: 'mean'.
        avg_factor (int, optional): Average factor used to average the loss.
            Only used when ``per_image`` is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # A single Lovasz problem over the flattened batch.
        return lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)
    per_image_losses = [
        lovasz_softmax_flat(
            *flatten_probs(prob.unsqueeze(0), label.unsqueeze(0),
                           ignore_index),
            classes=classes,
            class_weight=class_weight)
        for prob, label in zip(probs, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
223
+
224
+
225
@LOSSES.register_module()
class LovaszLoss(nn.Module):
    """LovaszLoss.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_lovasz'.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_lovasz'):
        super(LovaszLoss, self).__init__()
        assert loss_type in ('binary', 'multi_class'), "loss_type should be \
                                                    'binary' or 'multi_class'."

        # Select the criterion once up front; forward() also uses this
        # choice to decide whether a softmax is needed.
        if loss_type == 'binary':
            self.cls_criterion = lovasz_hinge
        else:
            self.cls_criterion = lovasz_softmax
        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
        if not per_image:
            # In whole-batch mode the Lovasz loss is already a scalar, so
            # any further reduction is disallowed.
            assert reduction == 'none', "reduction should be 'none' when \
                                                        per_image is False."

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): Raw logits, [B, C, H, W] for the
                multi-class criterion.
            label (torch.Tensor): Ground truth labels.
            weight: Unused here; accepted for API compatibility.
            avg_factor (int, optional): Averaging factor (per_image mode).
            reduction_override (str, optional): Overrides self.reduction.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            # Move class weights onto the score's device/dtype.
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # if multi-class loss, transform logits to probs
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
modelsforCIML/mmseg/models/losses/tversky_loss.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """Modified from
3
+ https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
4
+ (Apache-2.0 License)"""
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ..builder import LOSSES
10
+ from .utils import get_class_weight, weighted_loss
11
+
12
+
13
@weighted_loss
def tversky_loss(pred,
                 target,
                 valid_mask,
                 alpha=0.3,
                 beta=0.7,
                 smooth=1,
                 class_weight=None,
                 ignore_index=255):
    """Multi-class Tversky loss, averaged over all classes.

    Args:
        pred (torch.Tensor): [N, C, ...] class probabilities, channel FIRST.
        target (torch.Tensor): one-hot ground truth with the class dimension
            LAST, i.e. [N, ..., C] -- note the ``target[..., i]`` indexing.
        valid_mask (torch.Tensor): [N, ...] 0/1 mask of non-ignored pixels.
        alpha (float): Coefficient of false positives. Default: 0.3.
        beta (float): Coefficient of false negatives. Default: 0.7.
        smooth (float): Smoothing term avoiding division by zero. Default: 1.
        class_weight (Tensor | None): Optional per-class loss weights.
        ignore_index (int): Class index to skip. Default: 255.
    """
    assert pred.shape[0] == target.shape[0]
    total_loss = 0
    num_classes = pred.shape[1]
    for i in range(num_classes):
        if i != ignore_index:
            # NOTE(review): this local name shadows the enclosing function
            # name -- harmless, but keep in mind when editing.
            tversky_loss = binary_tversky_loss(
                pred[:, i],
                target[..., i],
                valid_mask=valid_mask,
                alpha=alpha,
                beta=beta,
                smooth=smooth)
            if class_weight is not None:
                tversky_loss *= class_weight[i]
            total_loss += tversky_loss
    # Averaged over ALL channels, including a skipped ignore_index channel
    # when ignore_index < num_classes.
    return total_loss / num_classes
38
+
39
+
40
@weighted_loss
def binary_tversky_loss(pred,
                        target,
                        valid_mask,
                        alpha=0.3,
                        beta=0.7,
                        smooth=1):
    """Per-sample Tversky loss for a single (binary) class.

    Args:
        pred (torch.Tensor): [N, ...] predicted probabilities.
        target (torch.Tensor): [N, ...] 0/1 ground truth.
        valid_mask (torch.Tensor): [N, ...] 0/1 mask of non-ignored pixels.
        alpha (float): False-positive coefficient. Default: 0.3.
        beta (float): False-negative coefficient. Default: 0.7.
        smooth (float): Smoothing term avoiding division by zero. Default: 1.

    Returns:
        torch.Tensor: [N], one Tversky loss value per sample.
    """
    assert pred.shape[0] == target.shape[0]
    batch = pred.shape[0]
    flat_pred = pred.reshape(batch, -1)
    flat_target = target.reshape(batch, -1)
    flat_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    # Soft confusion counts, restricted to valid pixels.
    true_pos = (flat_pred * flat_target * flat_mask).sum(dim=1)
    false_pos = (flat_pred * (1 - flat_target) * flat_mask).sum(dim=1)
    false_neg = ((1 - flat_pred) * flat_target * flat_mask).sum(dim=1)
    tversky_index = (true_pos + smooth) / (
        true_pos + alpha * false_pos + beta * false_neg + smooth)

    return 1 - tversky_index
58
+
59
+
60
@LOSSES.register_module()
class TverskyLoss(nn.Module):
    """TverskyLoss. This loss is proposed in `Tversky loss function for image
    segmentation using 3D fully convolutional deep networks.

    <https://arxiv.org/abs/1706.05721>`_.
    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        alpha(float, in [0, 1]):
            The coefficient of false positives. Default: 0.3.
        beta (float, in [0, 1]):
            The coefficient of false negatives. Default: 0.7.
            Note: alpha + beta = 1.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_tversky'.
    """

    def __init__(self,
                 smooth=1,
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 alpha=0.3,
                 beta=0.7,
                 loss_name='loss_tversky'):
        super(TverskyLoss, self).__init__()
        self.smooth = smooth
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        # Compare with a tolerance rather than exact float equality, so that
        # computed pairs (e.g. alpha = 1/3, beta = 2/3) are not rejected by
        # rounding. Also fixes the garbled message ("but be" -> "must be").
        assert abs(alpha + beta - 1.0) < 1e-6, \
            'Sum of alpha and beta must be 1.0!'
        self.alpha = alpha
        self.beta = beta
        self._loss_name = loss_name

    def forward(self, pred, target, **kwargs):
        """Compute the Tversky loss.

        Args:
            pred (torch.Tensor): [N, C, ...] raw logits; softmax is applied
                here.
            target (torch.Tensor): [N, ...] integer class labels.

        Returns:
            torch.Tensor: The weighted Tversky loss.
        """
        if self.class_weight is not None:
            # Move class weights onto the prediction's device/dtype.
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        # Clamp so out-of-range labels (e.g. the 255 ignore value) can be
        # one-hot encoded; those pixels are zeroed again via ``valid_mask``.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * tversky_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            alpha=self.alpha,
            beta=self.beta,
            smooth=self.smooth,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
modelsforCIML/mmseg/models/losses/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import functools
3
+
4
+ import mmcv
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
def get_class_weight(class_weight):
    """Get class weight for loss function.

    Args:
        class_weight (list[float] | str | None): If class_weight is a str,
            take it as a file name and read from it (``.npy`` via numpy,
            anything else via ``mmcv.load``).

    Returns:
        The weights unchanged, or the loaded contents of the file.
    """
    if not isinstance(class_weight, str):
        return class_weight
    # String arguments are treated as file paths.
    if class_weight.endswith('.npy'):
        return np.load(class_weight)
    # pkl, json or yaml
    return mmcv.load(class_weight)
26
+
27
+
28
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    # Validate/encode via PyTorch's own reduction enum:
    # none -> 0, elementwise_mean -> 1, sum -> 2.
    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 1:
        return loss.mean()
    if reduction_enum == 2:
        return loss.sum()
    if reduction_enum == 0:
        return loss
46
+
47
+
48
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    if weight is not None:
        # The weight must broadcast cleanly over the loss tensor.
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # Without an explicit averaging factor, fall back to plain reduction.
    if avg_factor is None:
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # eps guards against division by zero when every label of an image
        # belongs to the ignore index (avg_factor == 0).
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction != 'none':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
81
+
82
+
83
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        # then apply weighting/reduction in one shared place
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
modelsforCIML/mmseg/ops/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .encoding import Encoding
3
+ from .wrappers import Upsample, resize
4
+
5
+ __all__ = ['Upsample', 'resize', 'Encoding']
modelsforCIML/mmseg/ops/encoding.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+
7
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        # uniform init bound scaled by fan-in
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes], negative smoothing factors
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Scaled squared L2 distance between each feature and each codeword.

        Args:
            x (Tensor): [batch_size, height*width, channels] features.
            codewords (Tensor): [num_codes, channels].
            scale (Tensor): [num_codes] smoothing factors.

        Returns:
            Tensor: [batch_size, height*width, num_codes].
        """
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        reshaped_scale = scale.view((1, 1, num_codes))
        # broadcast features against codewords:
        # [batch_size, height*width, num_codes, channels]
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Aggregate residuals weighted by the soft assignments.

        Args:
            assignment_weights (Tensor): [batch_size, height*width,
                num_codes].
            x (Tensor): [batch_size, height*width, channels].
            codewords (Tensor): [num_codes, channels].

        Returns:
            Tensor: [batch_size, num_codes, channels] encoded features.
        """
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)

        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        # sum weighted residuals over the spatial (height*width) dimension
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height*width, num_codes]
        # (softmax over the codeword dimension)
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
modelsforCIML/mmseg/ops/wrappers.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` that can warn about misaligned sizes.

    With ``align_corners=True``, upsampling lines up exactly only when the
    output size is ``n * (input - 1) + 1``; a warning is issued for other
    upscaling targets.

    Args:
        input (Tensor): [N, C, H, W] input tensor.
        size, scale_factor, mode, align_corners: forwarded verbatim to
            ``F.interpolate``.
        warning (bool): Whether to emit the alignment warning. Default: True.
    """
    if warning and size is not None and align_corners:
        in_h, in_w = (int(v) for v in input.shape[2:])
        out_h, out_w = (int(v) for v in size)
        if out_h > in_h or out_w > in_w:
            all_above_one = (out_h > 1 and out_w > 1 and in_h > 1
                             and in_w > 1)
            if (all_above_one and (out_h - 1) % (in_h - 1)
                    and (out_w - 1) % (in_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(in_h, in_w)} is `x+1` and '
                    f'out size {(out_h, out_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)
28
+
29
+
30
class Upsample(nn.Module):
    """Thin ``nn.Module`` wrapper around :func:`resize`.

    Args:
        size (tuple | None): Explicit output size; takes precedence over
            ``scale_factor``.
        scale_factor (float | tuple | None): Multiplier(s) for the spatial
            dimensions; normalized to float(s) at construction.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): Passed through to :func:`resize`.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        # Normalize scale_factor to float(s) up front, mirroring
        # torch.nn.Upsample.
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(f) for f in scale_factor)
        elif scale_factor:
            self.scale_factor = float(scale_factor)
        else:
            self.scale_factor = None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # An explicit target size wins; otherwise scale the trailing two
        # (spatial) dimensions by scale_factor.
        target = self.size
        if not target:
            target = [int(dim * self.scale_factor) for dim in x.shape[-2:]]
        return resize(x, target, None, self.mode, self.align_corners)
modelsforCIML/mmseg/utils/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .collect_env import collect_env
3
+ from .logger import get_root_logger
4
+ from .misc import find_latest_checkpoint
5
+ from .set_env import setup_multi_processes
6
+ from .util_distribution import build_ddp, build_dp, get_device
7
+
8
+ __all__ = [
9
+ 'get_root_logger', 'collect_env', 'find_latest_checkpoint',
10
+ 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
11
+ ]