Upload 58 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- modelsforCIML/Auto_Annotate_SDG.py +113 -0
- modelsforCIML/Auto_Annotate_SPG.py +248 -0
- modelsforCIML/CAAA_OK.png +3 -0
- modelsforCIML/Readme.md +50 -0
- modelsforCIML/classify_convxl.py +168 -0
- modelsforCIML/convbuper.py +209 -0
- modelsforCIML/convnext.py +165 -0
- modelsforCIML/dass.py +289 -0
- modelsforCIML/mmseg/__init__.py +62 -0
- modelsforCIML/mmseg/core/__init__.py +12 -0
- modelsforCIML/mmseg/core/builder.py +33 -0
- modelsforCIML/mmseg/core/evaluation/__init__.py +11 -0
- modelsforCIML/mmseg/core/evaluation/class_names.py +327 -0
- modelsforCIML/mmseg/core/evaluation/eval_hooks.py +132 -0
- modelsforCIML/mmseg/core/evaluation/metrics.py +396 -0
- modelsforCIML/mmseg/core/hook/__init__.py +4 -0
- modelsforCIML/mmseg/core/hook/wandblogger_hook.py +370 -0
- modelsforCIML/mmseg/core/optimizers/__init__.py +7 -0
- modelsforCIML/mmseg/core/optimizers/layer_decay_optimizer_constructor.py +211 -0
- modelsforCIML/mmseg/core/seg/__init__.py +5 -0
- modelsforCIML/mmseg/core/seg/builder.py +9 -0
- modelsforCIML/mmseg/core/seg/sampler/__init__.py +5 -0
- modelsforCIML/mmseg/core/seg/sampler/base_pixel_sampler.py +13 -0
- modelsforCIML/mmseg/core/seg/sampler/ohem_pixel_sampler.py +85 -0
- modelsforCIML/mmseg/core/utils/__init__.py +5 -0
- modelsforCIML/mmseg/core/utils/dist_util.py +46 -0
- modelsforCIML/mmseg/core/utils/misc.py +18 -0
- modelsforCIML/mmseg/models/__init__.py +10 -0
- modelsforCIML/mmseg/models/builder.py +49 -0
- modelsforCIML/mmseg/models/decode_heads/__init__.py +9 -0
- modelsforCIML/mmseg/models/decode_heads/aspp_head.py +122 -0
- modelsforCIML/mmseg/models/decode_heads/decode_head.py +295 -0
- modelsforCIML/mmseg/models/decode_heads/fcn_head.py +88 -0
- modelsforCIML/mmseg/models/decode_heads/psp_head.py +117 -0
- modelsforCIML/mmseg/models/decode_heads/sep_aspp_head.py +105 -0
- modelsforCIML/mmseg/models/decode_heads/uper_head.py +128 -0
- modelsforCIML/mmseg/models/decode_heads/uper_lab.py +120 -0
- modelsforCIML/mmseg/models/losses/__init__.py +16 -0
- modelsforCIML/mmseg/models/losses/accuracy.py +92 -0
- modelsforCIML/mmseg/models/losses/cross_entropy_loss.py +296 -0
- modelsforCIML/mmseg/models/losses/dice_loss.py +137 -0
- modelsforCIML/mmseg/models/losses/focal_loss.py +327 -0
- modelsforCIML/mmseg/models/losses/lovasz_loss.py +323 -0
- modelsforCIML/mmseg/models/losses/tversky_loss.py +137 -0
- modelsforCIML/mmseg/models/losses/utils.py +126 -0
- modelsforCIML/mmseg/ops/__init__.py +5 -0
- modelsforCIML/mmseg/ops/encoding.py +75 -0
- modelsforCIML/mmseg/ops/wrappers.py +51 -0
- modelsforCIML/mmseg/utils/__init__.py +11 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
models[[:space:]]for[[:space:]]CIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
models[[:space:]]for[[:space:]]CIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
modelsforCIML/CAAA_OK.png filter=lfs diff=lfs merge=lfs -text
|
modelsforCIML/Auto_Annotate_SDG.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Auto-annotation script for SDG (stay-background) image pairs.

Runs the SAFM model over (authentic, manipulated) image pairs and writes
predicted manipulation masks to ./SDG_preds.

@author: liuyaqi
"""
import os
import cv2
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import numpy as np
import time
import logging
import argparse
from PIL import Image
from tqdm import tqdm
import albumentations as A
import torch.distributed as dist
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import safm_convb as safm

# Command-line options. Only --pth is consumed below; the rest are kept
# for compatibility with the launcher scripts.
parser = argparse.ArgumentParser()
parser.add_argument('--nm', type=str, default='ori')
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--pth', type=str, default='SAFM.pth')
parser.add_argument('--thres', type=float, default=0.5)
parser.add_argument('--numw', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--input_scale', type=int, default=512)
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
args = parser.parse_args()
| 35 |
+
|
| 36 |
+
class CVPR24EvalDataset(Dataset):
    """Dataset of SDG image pairs for mask prediction.

    Each sample directory under ``roots/img_dir`` must contain:
        0.jpg  - SDG authentic image
        1.jpg  - SDG manipulated image

    ``__getitem__`` returns the two normalized image tensors plus the
    pair's directory name.
    """

    def __init__(self, roots, img_dir, sz=512, fan=False):
        self.fan = fan  # NOTE(review): unused in this file; kept for callers
        self.roots = os.path.join(roots, img_dir)
        # One subdirectory per image pair (see class docstring).
        self.indexs = [os.path.join(self.roots, x) for x in os.listdir(self.roots)]
        self.indexs.sort()
        self.lens = len(self.indexs)
        self.tsr = ToTensorV2()
        self.lbl = torch.FloatTensor([1])  # NOTE(review): unused in this file
        self.rsz = torchvision.transforms.Compose([torchvision.transforms.Resize((sz, sz))])
        # Resize + per-channel normalization. The 0.455 mean is the constant
        # used consistently across this repo's scripts; do not "fix" to the
        # standard ImageNet 0.456, the released weights expect this value.
        self.toctsr = torchvision.transforms.Compose([
            torchvision.transforms.Resize((sz, sz)),
            torchvision.transforms.Normalize(mean=(0.485, 0.455, 0.406),
                                             std=(0.229, 0.224, 0.225))])

    def __len__(self):
        return self.lens

    def __getitem__(self, idx):
        this_r = self.indexs[idx]

        def _load(name):
            # BGR -> RGB, scale to [0, 1], then resize + normalize.
            bgr = cv2.imread(os.path.join(this_r, name))
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            return self.toctsr(self.tsr(image=rgb)['image'].float() / 255.0)

        # os.path.basename is platform-safe (original split on '/').
        return (_load('0.jpg'), _load('1.jpg'), os.path.basename(this_r))
|
| 71 |
+
|
| 72 |
+
# ---- Inference over the SDG pairs -------------------------------------
test_data = CVPR24EvalDataset('./', 'SDG')
test_loader = DataLoader(dataset=test_data, batch_size=1, num_workers=4)

model = safm.SAFM(2, 512)
model = model.cuda()
model = nn.DataParallel(model)
loader = torch.load(args.pth, map_location='cpu')
model.load_state_dict(loader)
model.eval()

# exist_ok avoids a crash (and a TOCTOU race) when the dir already exists.
os.makedirs('SDG_preds', exist_ok=True)

with torch.no_grad():
    for (im1, im2, fnm) in tqdm(test_loader):
        im1 = im1.cuda()
        im2 = im2.cuda()
        # Test-time augmentation on the manipulated image:
        # plain pass + flips along dims 2 and 3, un-flipped afterwards.
        _, pred, _, _ = model(im1, im2)
        _, pred2, _, _ = model(im1, torch.flip(im2, [2]))
        pred2 = torch.flip(pred2, [2])
        _, pred3, _, _ = model(im1, torch.flip(im2, [3]))
        pred3 = torch.flip(pred3, [3])

        # Softmax over the summed logits; channel 1 = "manipulated".
        preds = F.softmax((pred + pred2 + pred3), dim=1)[:, 1:2].squeeze().cpu().numpy()
        # Keep only confident masks: most pixels above the loose threshold
        # must also clear the strict one.
        s1 = (preds > (1 / 16)).sum()
        s2 = (preds > (15 / 16)).sum()
        if (s2 / (s1 + 1e-6) > 0.5):
            cv2.imwrite('SDG_preds/' + fnm[0] + '.png', (preds * 255).astype(np.uint8))
modelsforCIML/Auto_Annotate_SPG.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import torch#用户ID:7fb702cd-1293-4470-a3b2-4ba88c3b3d4a
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import logging
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
import random
|
| 11 |
+
import pickle
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.cuda.amp import autocast
|
| 16 |
+
import segmentation_models_pytorch as smp
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import albumentations as A
|
| 19 |
+
from albumentations.pytorch import ToTensorV2
|
| 20 |
+
import torchvision
|
| 21 |
+
import argparse
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
parser.add_argument('--data_root', type=str, default='../../')
|
| 24 |
+
parser.add_argument('--train_name', type=str, default='CHDOC_JPEG0')
|
| 25 |
+
parser.add_argument('--model_name', type=str, default='exp')
|
| 26 |
+
parser.add_argument('--att', type=str, default='None')
|
| 27 |
+
parser.add_argument('--num', type=str, default='1')
|
| 28 |
+
parser.add_argument('--n_class', type=int, default=2)
|
| 29 |
+
parser.add_argument('--bs', type=int, default=1)
|
| 30 |
+
parser.add_argument('--es', type=int, default=0)
|
| 31 |
+
parser.add_argument('--ep', type=int, default=1)
|
| 32 |
+
parser.add_argument('--xk', type=int, default=0)
|
| 33 |
+
parser.add_argument('--numw', type=int, default=8)
|
| 34 |
+
parser.add_argument('--load', type=int, default=0)
|
| 35 |
+
parser.add_argument('--pilt', type=int, default=0)
|
| 36 |
+
parser.add_argument('--base', type=int, default=1)
|
| 37 |
+
parser.add_argument('--lr_base', type=float, default=3e-4)
|
| 38 |
+
parser.add_argument('--cp', type=float, default=1.0)
|
| 39 |
+
parser.add_argument('--mode', type=str, default='0123')
|
| 40 |
+
parser.add_argument('--adds', type=str, default='123')
|
| 41 |
+
parser.add_argument('--loss-', type=str, default='1,2,3,4')
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
def getdir(path):
    """Create *path* (including parents) if it does not already exist.

    ``exist_ok=True`` makes this race-free: the original check-then-create
    could raise when several workers created the directory concurrently.
    """
    os.makedirs(path, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class CVPR24REDataset(Dataset):
    """Dataset of SPG image pairs for mask prediction.

    Each sample directory under ``roots/img_dir`` must contain:
        0.jpg  - authentic image
        1.jpg  - manipulated image

    ``__getitem__`` returns a 9-channel tensor (authentic, manipulated,
    |difference|), an all-zero placeholder mask, the pair's directory
    name, and the manipulated image's original height/width.
    """

    def __init__(self, roots, img_dir, times=3, repeats=1):
        # The original assigned self.roots twice; once is enough.
        self.roots = os.path.join(roots, img_dir)
        self.indexs = [os.path.join(self.roots, x) for x in os.listdir(self.roots)]
        self.lens = len(self.indexs)
        self.rsz = A.Compose([A.Resize(1024, 1024)])
        self.transforms = A.Compose([ToTensorV2()])
        # Normalization constants are tiled `times` times because the input
        # stacks `times` RGB triplets (default 3 -> 9 channels).
        self.toctsr = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(mean=(0.485, 0.455, 0.406) * times,
                                             std=(0.229, 0.224, 0.225) * times)])

    def __len__(self):
        return self.lens

    def __getitem__(self, idx):
        this_r = self.indexs[idx]
        auth_path = os.path.join(this_r, '0.jpg')
        manip_path = os.path.join(this_r, '1.jpg')
        img1 = cv2.cvtColor(cv2.imread(auth_path), cv2.COLOR_BGR2RGB)
        img2 = cv2.cvtColor(cv2.imread(manip_path), cv2.COLOR_BGR2RGB)
        h, w = img2.shape[:2]
        # Placeholder mask: inference only, no ground truth available.
        mask = np.zeros((h, w), dtype=np.uint8)
        img1 = self.rsz(image=img1)['image']
        rsts = self.rsz(image=img2, mask=mask)
        img2 = rsts['image']
        mask = rsts['mask']
        imgs = np.concatenate((img1, img2), 2)
        rsts = self.transforms(image=imgs, mask=mask)
        imgs = rsts['image']
        # Stack (auth, manip, |auth - manip|) and scale to [0, 1].
        imgs = (torch.cat((imgs, torch.abs(imgs[:3] - imgs[3:])), 0).float() / 255.0)
        imgs = self.toctsr(imgs)
        mask = rsts['mask'].long()
        # os.path.basename is platform-safe; the original split('/')[-2]
        # on the 1.jpg path resolved to the same directory name.
        # (The debug print of the sample path has been removed.)
        return (imgs, mask, os.path.basename(this_r), h, w)
|
| 96 |
+
|
| 97 |
+
ngpu = torch.cuda.device_count()
|
| 98 |
+
ngpub = ngpu * args.base
|
| 99 |
+
if False:
|
| 100 |
+
gpus = True
|
| 101 |
+
device = torch.device("cuda",args.local_rank)
|
| 102 |
+
torch.cuda.set_device(args.local_rank)
|
| 103 |
+
dist.init_process_group(backend='nccl')
|
| 104 |
+
else:
|
| 105 |
+
gpus = False
|
| 106 |
+
device = torch.device("cuda")
|
| 107 |
+
|
| 108 |
+
roots1 = './'
|
| 109 |
+
|
| 110 |
+
test_data1 = CVPR24REDataset('your_data_dir/', 'SPG')
|
| 111 |
+
test_data2 = CVPR24REDataset('your_data_dir/', 'SPG')
|
| 112 |
+
|
| 113 |
+
class AverageMeter(object):
    """Tracks the latest value, running sum, count and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear every statistic back to zero.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # Record `val` observed `n` times, then refresh the running mean.
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
| 126 |
+
|
| 127 |
+
def second2time(second):
    """Format a duration in seconds as a short human-readable string.

    < 1 minute -> seconds rounded to 4 places
    < 1 hour   -> "M:S" with seconds rounded to 1 place
    otherwise  -> "H:M:S" with integer parts

    The original returned None for durations >= 216000 s (60*60*60);
    that range now falls through to the H:M:S format as well.
    """
    if second < 60:
        return str('{}'.format(round(second, 4)))
    elif second < 60 * 60:
        m = second // 60
        s = second % 60
        return str('{}:{}'.format(int(m), round(s, 1)))
    else:
        h = second // (60 * 60)
        m = second % (60 * 60) // 60
        s = second % (60 * 60) % 60
        return str('{}:{}:{}'.format(int(h), int(m), int(s)))
|
| 139 |
+
|
| 140 |
+
def inial_logger(file):
    """Build the shared 'log' logger: INFO+ goes to *file*, DEBUG+ to the console."""
    logger = logging.getLogger('log')
    logger.setLevel(level=logging.DEBUG)
    plain = logging.Formatter('%(message)s')

    # File output: INFO and above.
    to_file = logging.FileHandler(file)
    to_file.setLevel(level=logging.INFO)
    to_file.setFormatter(plain)

    # Console output: everything from DEBUG up.
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.DEBUG)
    to_console.setFormatter(plain)

    for handler in (to_file, to_console):
        logger.addHandler(handler)
    return logger
|
| 153 |
+
|
| 154 |
+
from functools import partial
|
| 155 |
+
import torch
|
| 156 |
+
import torch.nn as nn
|
| 157 |
+
import torch.nn.functional as F
|
| 158 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 159 |
+
from mmseg.utils import get_root_logger
|
| 160 |
+
|
| 161 |
+
from dass import DASS
|
| 162 |
+
|
| 163 |
+
model=DASS(in_chans=9).to(device)
|
| 164 |
+
|
| 165 |
+
model = nn.DataParallel(model)
|
| 166 |
+
loader = torch.load('DASS.pth',map_location='cpu')['state_dict']
|
| 167 |
+
model.load_state_dict(loader)
|
| 168 |
+
|
| 169 |
+
model_name = args.model_name
|
| 170 |
+
save_ckpt_dir = os.path.join('./outputs/', model_name, 'ckpt')
|
| 171 |
+
save_log_dir = os.path.join('./outputs/', model_name)
|
| 172 |
+
try:
|
| 173 |
+
if not os.path.exists(save_ckpt_dir):
|
| 174 |
+
os.makedirs(save_ckpt_dir)
|
| 175 |
+
except:
|
| 176 |
+
pass
|
| 177 |
+
try:
|
| 178 |
+
if not os.path.exists(save_log_dir):
|
| 179 |
+
os.makedirs(save_log_dir)
|
| 180 |
+
except:
|
| 181 |
+
pass
|
| 182 |
+
import gc
|
| 183 |
+
param = {}
|
| 184 |
+
param['batch_size'] = args.bs # 批大小
|
| 185 |
+
param['epochs'] = args.ep # 训练轮数,请和scheduler的策略对应,不然复现不出效果,对于t0=3,t_mut=2的scheduler来讲,44的时候会达到最优
|
| 186 |
+
param['disp_inter'] = 1 # 显示间隔(epoch)
|
| 187 |
+
param['save_inter'] = 4 # 保存间隔(epoch)
|
| 188 |
+
param['iter_inter'] = 64 # 显示迭代间隔(batch)
|
| 189 |
+
param['min_inter'] = 10
|
| 190 |
+
param['model_name'] = model_name # 模型名称
|
| 191 |
+
param['save_log_dir'] = save_log_dir # 日志保存路径
|
| 192 |
+
param['save_ckpt_dir'] = save_ckpt_dir # 权重保存路径
|
| 193 |
+
param['T0']=int(24/ngpub) #cosine warmup的参数
|
| 194 |
+
param['load_ckpt_dir'] = None
|
| 195 |
+
import time
|
| 196 |
+
|
| 197 |
+
def collate_batch(batch_list):
    """Collate a list of (data, label) pairs into two stacked batches.

    Each sample's tensors are concatenated and flattened so the result
    has one row per sample: shapes (batch_size, -1) for data and labels.
    """
    assert type(batch_list) == list, f"Error"
    batch_size = len(batch_list)

    def _stack(pos):
        # Concatenate the pos-th tensor of every sample, one row per sample.
        return torch.cat([sample[pos] for sample in batch_list]).reshape(batch_size, -1)

    return _stack(0), _stack(1)
|
| 203 |
+
|
| 204 |
+
def train_net_qyl(param, model, test_data1, test_data2, plot=False,device='cuda'):
    # Despite its name, this function only runs inference: it evaluates
    # `model` over test_data2 with flip-based test-time augmentation and
    # writes the predicted masks to SPG_preds/.
    # `test_data1`, `plot`, and most of the unpacked `param` entries are
    # unused here — presumably leftovers of the original training loop.
    # Initialize parameters
    global gpus
    model_name = param['model_name']
    epochs = param['epochs']
    batch_size = param['batch_size']
    iter_inter = param['iter_inter']
    save_log_dir = param['save_log_dir']
    save_ckpt_dir = param['save_ckpt_dir']
    load_ckpt_dir = param['load_ckpt_dir']
    T0=param['T0']
    lr_base = args.lr_base
    # NOTE(review): both branches are identical today; presumably kept as a
    # placeholder for a distributed sampler — confirm before simplifying.
    if gpus:
        # valid_loader1 = DataLoader(dataset=test_data1, batch_size=batch_size, num_workers=args.numw, shuffle=False)
        valid_loader2 = DataLoader(dataset=test_data2, batch_size=batch_size, num_workers=args.numw, shuffle=False)
    else:
        # valid_loader1 = DataLoader(dataset=test_data1, batch_size=batch_size, num_workers=args.numw, shuffle=False)
        valid_loader2 = DataLoader(dataset=test_data2, batch_size=batch_size, num_workers=args.numw, shuffle=False)
    # NOTE(review): the optimizer is created but never stepped (inference only).
    optimizer = optim.AdamW(model.parameters(), lr=1e-4 ,weight_decay=5e-2)
    if True:
        model.eval()
        with torch.no_grad():
            for batch_idx, batch_samples in enumerate(tqdm(valid_loader2)):
                data, target, fnms, h, w = batch_samples
                # h/w are 1-element tensors; .item() assumes batch_size == 1.
                h = h.item()
                w = w.item()
                data, target = Variable(data.to(device)), Variable(target.to(device))
                if True:
                    # Test-time augmentation: original + flips along dims 2 and 3.
                    d2 = torch.flip(data,dims=[2])
                    d3 = torch.flip(data,dims=[3])
                    data = torch.cat((data,d2,d3),0)
                    pred = model(data)
                    # Un-flip the flipped predictions before averaging.
                    pred[1:2] = torch.flip(pred[1:2], dims=[2])
                    pred[2:3] = torch.flip(pred[2:3], dims=[3])
                    pred = pred.mean(0,keepdim=True)
                    # Channel 1 is the "manipulated" probability, scaled to 0..255.
                    pred= (F.softmax(pred,dim=1)[:,1:2].cpu().numpy()*255).astype(np.uint8)
                for (p, fnm) in zip(pred, fnms):
                    ds = 'SPG_preds/'
                    getdir(ds)
                    # Resize the mask back to the source image resolution.
                    p = cv2.resize(p.squeeze(),(w,h))
                    cv2.imwrite(ds+'/'+fnm+'.png', p)
|
| 245 |
+
|
| 246 |
+
train_net_qyl(param, model, test_data1, test_data2, device=device)
|
| 247 |
+
|
| 248 |
+
|
modelsforCIML/CAAA_OK.png
ADDED
|
Git LFS Details
|
modelsforCIML/Readme.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### This is the official implementation of Category-Aware Auto-Annotation (CAAA)
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+

|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
The classifiers are available at [Google Drive](https://drive.google.com/file/d/1OMGtuzqhjwcvDaP3OO1njPfAS_2s0vg8/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1-NidYwgVZUA0Pi0KE3ngGw?pwd=conv).
|
| 8 |
+
|
| 9 |
+
The DASS model is available at [Google Drive](https://drive.google.com/file/d/1PXL9e8XiRGlSIcGhhppLXJtVG2rdQh5a/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1lmksoTe2b2xObGkhUbd5-A?pwd=DASS).
|
| 10 |
+
|
| 11 |
+
The SACM model is available at [Google Drive](https://drive.google.com/file/d/1_C5gATKv8Mh7SyKNE_ubSpXlEASkEYja/view?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1PnLepP7bAd-8L5NcUGBx4A?pwd=SAFM).
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
To leverage the CAAA for auto-annotation, you should first categorize the image pairs (each pair contains a forged image and its authentic image) into aligned SPG and SDG. Then construct the dir structure as follows:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
roots (dir of SPG or SDG pairs)
|
| 19 |
+
|
|
| 20 |
+
|---dir1
|
| 21 |
+
| |----0.jpg (authentic image)
|
| 22 |
+
| |----1.jpg (manipulated image)
|
| 23 |
+
|
|
| 24 |
+
|---dir2
|
| 25 |
+
| |----0.jpg (authentic image)
|
| 26 |
+
| |----1.jpg (manipulated image)
|
| 27 |
+
|
|
| 28 |
+
..........
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Then run the scripts for auto-annotation.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
Commands to run the classifier to categorize the image pairs into SPG or SDG:
|
| 35 |
+
```
|
| 36 |
+
CUDA_VISIBLE_DEVICES=0 python classify_convxl.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
Commands to run the DASS to auto-annotate the image pairs in SPG:
|
| 41 |
+
```
|
| 42 |
+
CUDA_VISIBLE_DEVICES=0 python Auto_Annotate_SPG.py --pth DASS.pth
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
Commands to run the SACM to auto-annotate the image pairs in SDG:
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
CUDA_VISIBLE_DEVICES=0 python Auto_Annotate_SDG.py --pth SAFM.pth
|
| 50 |
+
```
|
modelsforCIML/classify_convxl.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import logging
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import random
|
| 12 |
+
import pickle
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from torch.autograd import Variable
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
import albumentations as A
|
| 18 |
+
from albumentations.pytorch import ToTensorV2
|
| 19 |
+
import torchvision
|
| 20 |
+
import argparse
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument('--img_dir', type=str)
|
| 23 |
+
parser.add_argument('--model_name', type=str, default='cls')
|
| 24 |
+
parser.add_argument('--att', type=str, default='None')
|
| 25 |
+
parser.add_argument('--num', type=str, default='1')
|
| 26 |
+
parser.add_argument('--n_class', type=int, default=2)
|
| 27 |
+
parser.add_argument('--bs', type=int, default=4)
|
| 28 |
+
parser.add_argument('--es', type=int, default=0)
|
| 29 |
+
parser.add_argument('--ep', type=int, default=10)
|
| 30 |
+
parser.add_argument('--xk', type=int, default=0)
|
| 31 |
+
parser.add_argument('--numw', type=int, default=16)
|
| 32 |
+
parser.add_argument('--load', type=int, default=0)
|
| 33 |
+
parser.add_argument('--pilt', type=int, default=0)
|
| 34 |
+
parser.add_argument('--base', type=int, default=1)
|
| 35 |
+
parser.add_argument('--lr_base', type=float, default=3e-4)
|
| 36 |
+
parser.add_argument('--cp', type=float, default=1.0)
|
| 37 |
+
parser.add_argument('--mode', type=str, default='0123')
|
| 38 |
+
parser.add_argument('--local-rank', default=-1, type=int, help='node rank for distributed training')
|
| 39 |
+
parser.add_argument('--adds', type=str, default='123')
|
| 40 |
+
parser.add_argument('--lossw', type=str, default='1,2,3,4')
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
from tqdm import tqdm
|
| 44 |
+
|
| 45 |
+
class CVPR24EVALDataset(Dataset):
    """Image pairs (0.jpg authentic, 1.jpg manipulated) under *roots* for classification.

    ``__getitem__`` returns (6-channel tensor, authentic path, manipulated
    path, error_flag). ``error_flag`` is True when either image failed to
    load; the other fields are then None.
    """

    def __init__(self, roots):
        self.indexs = [(os.path.join(roots, d, '0.jpg'), os.path.join(roots, d, '1.jpg')) for d in os.listdir(roots)]
        self.roots = roots
        self.indexs.sort()
        self.lens = len(self.indexs)
        # Resize to the network input size and normalize per channel
        # (0.455 is the repo-wide training constant; do not change it).
        self.rsztsr = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((512, 512)),
            torchvision.transforms.Normalize(mean=(0.485, 0.455, 0.406),
                                             std=(0.229, 0.224, 0.225))])

    def __len__(self):
        return self.lens

    def __getitem__(self, idx):
        try:
            img1 = cv2.cvtColor(cv2.imread(self.indexs[idx][0]), cv2.COLOR_BGR2RGB)
            img2 = cv2.cvtColor(cv2.imread(self.indexs[idx][1]), cv2.COLOR_BGR2RGB)
            img1 = self.rsztsr(img1)
            img2 = self.rsztsr(img2)
            imgs = torch.cat((img1, img2), 0)
            return (imgs, self.indexs[idx][0], self.indexs[idx][1], False)
        # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
        # are no longer swallowed; load failures still flag the sample.
        except Exception:
            print('error')
            return (None, None, None, True)
|
| 67 |
+
|
| 68 |
+
device = torch.device("cuda")
|
| 69 |
+
|
| 70 |
+
roots1 = './'
|
| 71 |
+
test_data = CVPR24EVALDataset(roots1)
|
| 72 |
+
|
| 73 |
+
def get_logger(filename, verbosity=1, name=None):
    """Create a logger writing to *filename* (mode 'w') and to the console.

    verbosity maps 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("[%(asctime)s][%(filename)s][%(levelname)s] %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # One handler for the log file, one for the console.
    file_handler = logging.FileHandler(filename, "w")
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
|
| 85 |
+
|
| 86 |
+
class AverageMeter(object):
    """Running statistics for a scalar metric: last value, sum, count, mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Fold in `val`, weighted by `n` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
| 99 |
+
|
| 100 |
+
def second2time(second):
    """Format a duration in seconds as a short human-readable string.

    < 1 minute -> seconds rounded to 4 places
    < 1 hour   -> "M:S" with seconds rounded to 1 place
    otherwise  -> "H:M:S" with integer parts

    The original returned None for durations >= 216000 s (60*60*60);
    that range now falls through to the H:M:S format as well.
    """
    if second < 60:
        return str('{}'.format(round(second, 4)))
    elif second < 60 * 60:
        m = second // 60
        s = second % 60
        return str('{}:{}'.format(int(m), round(s, 1)))
    else:
        h = second // (60 * 60)
        m = second % (60 * 60) // 60
        s = second % (60 * 60) % 60
        return str('{}:{}:{}'.format(int(h), int(m), int(s)))
|
| 112 |
+
|
| 113 |
+
def inial_logger(file):
    """Return the shared 'log' logger: INFO+ to *file*, DEBUG+ to the console."""
    logger = logging.getLogger('log')
    logger.setLevel(level=logging.DEBUG)
    message_only = logging.Formatter('%(message)s')

    file_out = logging.FileHandler(file)
    file_out.setLevel(level=logging.INFO)
    file_out.setFormatter(message_only)
    logger.addHandler(file_out)

    console_out = logging.StreamHandler()
    console_out.setLevel(logging.DEBUG)
    console_out.setFormatter(message_only)
    logger.addHandler(console_out)

    return logger
|
| 126 |
+
|
| 127 |
+
from functools import partial
|
| 128 |
+
import torch
|
| 129 |
+
import torch.nn as nn
|
| 130 |
+
import torch.nn.functional as F
|
| 131 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 132 |
+
from mmseg.utils import get_root_logger
|
| 133 |
+
from convnext import ConvNeXt
|
| 134 |
+
|
| 135 |
+
model=ConvNeXt(in_chans=6, depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0.8, layer_scale_init_value=1.0, num_classes=8).to(device)
|
| 136 |
+
|
| 137 |
+
model = nn.DataParallel(model)
|
| 138 |
+
loaders = torch.load('convxl.pth',map_location='cpu')['state_dict']
|
| 139 |
+
model.load_state_dict(loaders)
|
| 140 |
+
model = model.cuda()
|
| 141 |
+
model.eval()
|
| 142 |
+
|
| 143 |
+
# Accumulators: per-pair softmax scores plus the three category buckets.
all_dict = {}
SPG = []
SDG = []
NotAlignedSPG = []

with torch.no_grad():
    for idx in tqdm(range(len(test_data))):
        # Direct __getitem__ (no DataLoader); effectively batch size 1.
        (imgs,auth,temp,flags) = test_data.__getitem__(idx)
        if flags:
            # The image pair failed to load; skip it.
            continue
        pred = model(imgs.unsqueeze(0))
        # The classifier emits 2*k logits; view them as k binary heads
        # and softmax each head independently.
        b,c = pred.shape
        pred = F.softmax(pred.reshape(b,c//2,2),dim=-1).cpu().numpy()
        all_dict[temp]=(auth, pred)
        # Heads 0 and 1 jointly decide the pair's category.
        if ((pred[0,0,1]>0.5) and (pred[0,1,1]>0.5)): # SPG
            SPG.append((auth, temp))
        if ((pred[0,0,0]>0.5) and (pred[0,1,0]>0.5)): # SDG
            SDG.append((auth, temp))
        if ((pred[0,0,1]>0.5) and (pred[0,1,0]>0.5)): # NotAlignedSPG
            NotAlignedSPG.append((auth, temp))

# NOTE(review): only all_dict is persisted; the SPG/SDG/NotAlignedSPG lists
# are built but never saved — confirm whether that is intentional.
with open('convxl_cls.pk','wb') as f:
    pickle.dump(all_dict, f)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
modelsforCIML/convbuper.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 14 |
+
from mmseg.models.decode_heads import UPerHead,FCNHead
|
| 15 |
+
from functools import partial
|
| 16 |
+
from itertools import chain
|
| 17 |
+
from typing import Sequence
|
| 18 |
+
|
| 19 |
+
class Block(nn.Module):
    r"""ConvNeXt block (channels_last variant).

    Pipeline: depthwise 7x7 conv -> LayerNorm -> 1x1 conv (as Linear) ->
    GELU -> 1x1 conv (as Linear) -> optional layer scale -> stochastic-depth
    residual add.  The pointwise convolutions run in (N, H, W, C) layout via
    nn.Linear, which the ConvNeXt authors found slightly faster in PyTorch.

    Args:
        dim (int): Number of input/output channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Initial value of the per-channel
            layer-scale parameter ``gamma``; <= 0 disables layer scale.
            Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 conv expressed as Linear
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(y)
|
| 55 |
+
|
| 56 |
+
class ConvNeXt(nn.Module):
    r""" ConvNeXt backbone (multi-scale feature extractor).

    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf

    This variant has no classification head: ``forward`` returns a tuple of
    normalized feature maps, one per stage listed in ``out_indices`` (strides
    4, 8, 16, 32 relative to the input).

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        out_indices (list[int]): Stages whose outputs are returned. Default: [0, 1, 2, 3]
    """
    # NOTE(review): `depths`, `dims` and `out_indices` are mutable default
    # arguments; benign here (never mutated) but worth cleaning up.
    def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        # Stem: 4x4/stride-4 patchify conv followed by a channels_first LayerNorm.
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        # Between stages: norm then 2x2/stride-2 conv (halves resolution,
        # widens channels).
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        # Stochastic-depth rates increase linearly from 0 to drop_path_rate
        # across all blocks of all stages.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        # One output LayerNorm per stage, registered as 'norm0'..'norm3'.
        # These names are state_dict keys — do not rename.
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Default init applied by self.apply(): truncated-normal weights,
        # zero bias, for every Conv2d/Linear.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        NOTE(review): unlike the upstream segmentation ConvNeXt, this
        stripped copy cannot load a checkpoint here — passing a string
        raises TypeError.  Use an external loader for pretrained weights.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward_features(self, x):
        # Run all 4 stages; collect the normalized output of each stage
        # listed in out_indices.
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        # Returns a tuple of per-stage feature maps (see forward_features).
        x = self.forward_features(x)
        return x
|
| 150 |
+
|
| 151 |
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two memory layouts.

    ``channels_last`` expects (batch, height, width, channels) and delegates
    to ``F.layer_norm``; ``channels_first`` expects (batch, channels, height,
    width) and normalizes over dim 1 by hand, then applies the same
    per-channel affine transform.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize across the channel dimension manually.
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        normed = (x - mu) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class ConvBUPer(nn.Module):
    """ConvNeXt backbone + UPerHead segmentation model (2 classes).

    The backbone emits 4 feature maps (strides 4/8/16/32) which UPerHead
    fuses into a 2-class logit map; no upsampling is done here.

    NOTE(review): ``auxiliary_head`` is constructed but never called in
    ``forward`` — presumably an external training loop drives it directly;
    confirm before removing.  Attribute names (backbone / decode_head /
    auxiliary_head) are state_dict keys — do not rename.
    """
    def __init__(self,):
        super(ConvBUPer, self).__init__()
        # ConvNeXt with depths [3, 3, 27, 3] and widths [128, 256, 512, 1024].
        self.backbone = ConvNeXt(in_chans=3, depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], drop_path_rate=0.4)
        self.decode_head = UPerHead(
            in_channels=[128, 256, 512, 1024],  # one entry per backbone stage
            in_index=[0,1,2,3],
            pool_scales=(1,2,3,6),  # PPM pooling grid sizes
            channels=512,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=dict(type='SyncBN'),
        )
        self.auxiliary_head = FCNHead(
            in_channels=512,  # width of backbone stage index 2 (stride 16)
            in_index=2,
            channels=256,
            num_convs=1,
            concat_input=False,
            dropout_ratio=0.1,
            num_classes=2,
            align_corners=False,
            norm_cfg=dict(type='SyncBN'),
        )

    def forward(self,x):
        # Backbone -> tuple of 4 feature maps -> decode head -> raw logits.
        outs = self.backbone(x)
        outs = self.decode_head(outs)
        return outs
|
| 208 |
+
|
| 209 |
+
|
modelsforCIML/convnext.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from functools import partial
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 14 |
+
|
| 15 |
+
class Block(nn.Module):
    r"""ConvNeXt block.

    Two equivalent formulations exist: (1) all Conv2d ops in (N, C, H, W),
    or (2) a depthwise conv followed by Linear layers operating in
    (N, H, W, C).  This implementation uses (2), which the ConvNeXt authors
    found slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; values
            <= 0 disable layer scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # 1x1 conv expressed as Linear
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                   requires_grad=True)
                      if layer_scale_init_value > 0 else None)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        residual = x
        out = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.norm(out)
        out = self.pwconv1(out)
        out = self.act(out)
        out = self.pwconv2(out)
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + self.drop_path(out)
|
| 51 |
+
|
| 52 |
+
class ConvNeXt(nn.Module):
    r""" ConvNeXt classifier.

    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf
    with the multi-scale segmentation outputs replaced by a single
    classification head over the last stage.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        num_classes (int): Number of classes for the classification head. Default: 8
    """
    def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, num_classes=8,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages of residual blocks
        # Stochastic-depth rates rise linearly from 0 to drop_path_rate
        # across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Classification head.  NOTE(review): Dropout is applied to the
        # feature map *before* global average pooling — kept as-is for
        # checkpoint compatibility; confirm intent.
        self.fc = nn.Sequential(nn.Dropout(p=0.3), nn.AdaptiveAvgPool2d(1),
                                nn.Flatten(1), nn.Linear(dims[-1], num_classes))
        # Final norm over the last stage only.  The original built this via
        # a one-iteration loop (`for i_layer in range(3, 4)`) and an f-string
        # with no placeholder; a direct assignment registers the identical
        # state_dict key 'norm'.
        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Default init applied by self.apply(): truncated-normal weights,
        # zero bias, for every Conv2d/Linear.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.  NOTE(review): currently ignored — this
                method always re-initializes in place.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward_features(self, x):
        """Run all 4 stages, normalize the final feature map and classify."""
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # Equivalent to the original early return at i == 3.
        return self.fc(self.norm(x))

    def forward(self, x):
        # Returns raw class logits of shape (N, num_classes).
        x = self.forward_features(x)
        return x
|
| 140 |
+
|
| 141 |
+
class LayerNorm(nn.Module):
    r"""Layer normalization for either (N, H, W, C) tensors
    ("channels_last", the default) or (N, C, H, W) tensors
    ("channels_first").

    The channels_last path is delegated to ``F.layer_norm``; the
    channels_first path normalizes over dim 1 by hand and then applies the
    same learnable per-channel affine transform.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format != "channels_last" and self.data_format != "channels_first":
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_first":
            mu = x.mean(1, keepdim=True)
            var = (x - mu).pow(2).mean(1, keepdim=True)
            y = (x - mu) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * y + self.bias[:, None, None]
        # channels_last
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
modelsforCIML/dass.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from mmseg.models.decode_heads import UPerLab,FCNHead
|
| 14 |
+
|
| 15 |
+
# --------------------------------------------------------
|
| 16 |
+
# InternImage
|
| 17 |
+
# Copyright (c) 2022 OpenGVLab
|
| 18 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
import torch.utils.checkpoint as checkpoint
|
| 23 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 24 |
+
from mmcv.cnn import constant_init, trunc_normal_init
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from torch.nn.modules.utils import _pair as to_2tuple
|
| 27 |
+
from mmcv.cnn import build_norm_layer
|
| 28 |
+
from mmcv.runner import BaseModule
|
| 29 |
+
import math
|
| 30 |
+
import warnings
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Mlp(nn.Module):
    """Convolutional MLP: 1x1 conv -> depthwise 3x3 conv -> activation ->
    1x1 conv, with dropout after the activation and after the second
    projection.  When ``linear`` is True an extra ReLU follows the first
    projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        h = self.fc1(x)
        if self.linear:
            h = self.relu(h)
        h = self.act(self.dwconv(h))
        h = self.drop(h)
        h = self.fc2(h)
        return self.drop(h)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class AttentionModule(nn.Module):
    """Large-kernel attention: a 5x5 depthwise conv, a 7x7 depthwise conv
    with dilation 3, and a 1x1 conv produce an attention map that gates the
    input elementwise (output = input * attention)."""

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        gate = self.conv1(self.conv_spatial(self.conv0(x)))
        return x * gate
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SpatialAttention(nn.Module):
    """Spatial gating unit: 1x1 projection -> GELU -> large-kernel
    attention -> 1x1 projection, wrapped in a residual connection."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        residual = x
        out = self.activation(self.proj_1(x))
        out = self.proj_2(self.spatial_gating_unit(out))
        return out + residual
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Block(nn.Module):
    """VAN-style block operating on a flattened token sequence.

    ``forward`` takes tokens of shape (B, N, C) with N == H * W, reshapes
    them to a (B, C, H, W) feature map, applies spatial attention and a
    convolutional MLP — each in a residual branch scaled by a learnable
    per-channel layer-scale and stochastic depth — then flattens back to
    (B, N, C).  Attribute names are state_dict keys; do not rename.
    """

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 linear=False,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        # build_norm_layer returns (name, module); keep only the module.
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.attn = SpatialAttention(dim)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop, linear=linear)
        # Per-channel residual scaling (LayerScale), initialized at 1e-2.
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x, H, W):
        # (B, N, C) tokens -> (B, C, H, W) feature map.
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)
        # unsqueeze twice: (C,) -> (C, 1, 1) so the scale broadcasts over H, W.
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                               * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                               * self.mlp(self.norm2(x)))
        # Back to (B, N, C) tokens.
        x = x.view(B, C, N).permute(0, 2, 1)
        return x
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding: strided conv + norm, flattened into a
    token sequence.

    Returns the tokens together with the post-conv spatial size (H, W) so
    callers can restore the 2-D layout later.
    """

    def __init__(self,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        # Half-kernel padding keeps neighbouring patches overlapping.
        pad = (patch_size[0] // 2, patch_size[1] // 2)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=pad)
        self.norm = build_norm_layer(norm_cfg, embed_dim)[1]

    def forward(self, x):
        feat = self.norm(self.proj(x))
        _, _, H, W = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)  # (B, C, H*W) -> (B, H*W, C)
        return tokens, H, W
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class VAN(BaseModule):
    """Visual-attention-style pyramid backbone.

    Each of ``num_stages`` stages consists of an overlapping patch
    embedding, a stack of attention ``Block``s and a final LayerNorm.
    ``forward`` returns one feature map per stage as a list of
    (B, C_i, H_i, W_i) tensors.

    Args:
        in_chans (int): Input image channels. Default: 9.
        embed_dims (list[int]): Channel width per stage.
        mlp_ratios (list[int]): MLP expansion ratio per stage.
        drop_rate (float): Dropout rate inside the block MLPs.
        drop_path_rate (float): Maximum stochastic-depth rate; blocks get
            linearly spaced rates up to this value.
        depths (list[int]): Number of blocks per stage.
        num_stages (int): Number of stages. Default: 4.
        linear (bool): Use the "linear" MLP variant (extra ReLU).
        pretrained (str, optional): Deprecated; use ``init_cfg``.
        init_cfg (dict, optional): mmcv-style init config.
        norm_cfg (dict): Norm config for patch embeddings and blocks.
    """

    def __init__(self,
                 in_chans=9,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 linear=False,
                 pretrained=None,
                 init_cfg=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super(VAN, self).__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages
        self.linear = linear

        # Stochastic depth decay rule: linearly increasing across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
                                                sum(depths))]
        cur = 0

        for i in range(num_stages):
            # Stage 1 patchifies aggressively (7x7/stride 4); later stages
            # downsample by 2 with 3x3 convs.
            patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i])

            block = nn.ModuleList([Block(dim=embed_dims[i],
                                         mlp_ratio=mlp_ratios[i],
                                         drop=drop_rate,
                                         drop_path=dpr[cur + j],
                                         linear=linear,
                                         norm_cfg=norm_cfg)
                                   for j in range(depths[i])])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            # Registered as patch_embed1.., block1.., norm1.. — these names
            # are state_dict keys; do not change.
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

    def init_weights(self):
        """Initialize weights; defers to BaseModule when init_cfg is set."""
        # BUGFIX: the module-level import only brings constant_init and
        # trunc_normal_init from mmcv.cnn, so the nn.Conv2d branch below
        # raised NameError on `normal_init`.  Import it explicitly here.
        from mmcv.cnn import normal_init

        print('init cfg', self.init_cfg)
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    # He-style init scaled by the fan-out of the (grouped) conv.
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super(VAN, self).init_weights()

    def forward(self, x):
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            # tokens (B, H*W, C) -> feature map (B, C, H, W) for the next
            # stage's patch embedding / the output pyramid.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class DWConv(nn.Module):
    """3x3 depthwise convolution (one filter per channel, stride 1,
    padding 1 — spatial size is preserved)."""

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        return self.dwconv(x)
|
| 253 |
+
|
| 254 |
+
class DASS(nn.Module):
    """VAN backbone + UPerLab decode head + FCN auxiliary head (2 classes).

    In training mode ``forward`` returns three logit maps upsampled back to
    input resolution (decode x4, auxiliary x16, second decode output x4);
    in eval mode it returns only the decode head output upsampled x4.
    Attribute names (backbone / decode_head / auxiliary_head) are
    state_dict keys — do not rename.
    """

    def __init__(self, in_chans=6):
        super(DASS, self).__init__()
        self.backbone = VAN(in_chans=in_chans, embed_dims=[96, 192, 480, 768], drop_rate=0.0, drop_path_rate=0.4, depths=[3, 3, 24, 3], norm_cfg=dict(type='SyncBN', requires_grad=True))
        self.decode_head = UPerLab(
            in_channels=[96, 192, 480, 768],  # one entry per backbone stage
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),  # PPM pooling grid sizes
            channels=512,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=dict(type='SyncBN'),
        )
        self.auxiliary_head = FCNHead(
            in_channels=480,  # width of backbone stage index 2 (stride 16)
            in_index=2,
            channels=256,
            num_convs=1,
            concat_input=False,
            dropout_ratio=0.1,
            num_classes=2,
            align_corners=False,
            norm_cfg=dict(type='SyncBN'),
        )

    def forward(self, x):
        feats = self.backbone(x)
        # F.upsample_bilinear is deprecated; F.interpolate with
        # mode='bilinear', align_corners=True is its documented equivalent.
        if self.training:
            out1, out3 = self.decode_head(feats)
            out2 = self.auxiliary_head(feats)
            return (F.interpolate(out1, scale_factor=4.0, mode='bilinear', align_corners=True),
                    F.interpolate(out2, scale_factor=16.0, mode='bilinear', align_corners=True),
                    F.interpolate(out3, scale_factor=4.0, mode='bilinear', align_corners=True))
        out1 = self.decode_head(feats)
        return F.interpolate(out1, scale_factor=4.0, mode='bilinear', align_corners=True)
|
| 289 |
+
|
modelsforCIML/mmseg/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import mmcv
|
| 5 |
+
from packaging.version import parse
|
| 6 |
+
|
| 7 |
+
from .version import __version__, version_info
|
| 8 |
+
|
| 9 |
+
MMCV_MIN = '1.3.13'
|
| 10 |
+
MMCV_MAX = '1.8.0'
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a comparable tuple of integers.

    Pre-release ordering is alpha < beta < rc < final release, and
    post-releases sort after the final release.

    Args:
        version_str (str): The version string to convert.
        length (int): Maximum number of numeric version levels. Default: 4.

    Returns:
        tuple[int]: The version encoded as integers.
    """
    parsed = parse(version_str)
    assert parsed.release, f'failed to parse version {version_str}'
    # Truncate then zero-pad the numeric segment to exactly `length` entries.
    digits = list(parsed.release)[:length]
    digits += [0] * (length - len(digits))
    if parsed.is_prerelease:
        # Unrecognized pre-release tags sort below alpha (-4).
        rank = -4
        if parsed.pre:
            known = {'a': -3, 'b': -2, 'rc': -1}
            tag = parsed.pre[0]
            if tag in known:
                rank = known[tag]
            else:
                warnings.warn(f'unknown prerelease version {tag}, '
                              'version checking may go wrong')
            digits += [rank, parsed.pre[-1]]
        else:
            # version.pre can be None
            digits += [rank, 0]
    elif parsed.is_postrelease:
        digits += [1, parsed.post]
    else:
        digits += [0, 0]
    return tuple(digits)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)


# Fail fast at import time when the installed mmcv falls outside the
# supported half-open range [MMCV_MIN, MMCV_MAX).
assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.'

__all__ = ['__version__', 'version_info', 'digit_version']
|
modelsforCIML/mmseg/core/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
|
| 3 |
+
build_optimizer_constructor)
|
| 4 |
+
from .evaluation import * # noqa: F401, F403
|
| 5 |
+
from .hook import * # noqa: F401, F403
|
| 6 |
+
from .optimizers import * # noqa: F401, F403
|
| 7 |
+
from .seg import * # noqa: F401, F403
|
| 8 |
+
from .utils import * # noqa: F401, F403
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
|
| 12 |
+
]
|
modelsforCIML/mmseg/core/builder.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import copy
|
| 3 |
+
|
| 4 |
+
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
|
| 5 |
+
from mmcv.utils import Registry, build_from_cfg
|
| 6 |
+
|
| 7 |
+
# Project-level optimizer-builder registry; lookups can fall back to mmcv's
# registry through the `parent` link.
OPTIMIZER_BUILDERS = Registry(
    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from ``cfg``.

    The mmseg-level ``OPTIMIZER_BUILDERS`` registry is consulted first;
    mmcv's registry is the fallback.

    Raises:
        KeyError: If ``cfg['type']`` is found in neither registry.
    """
    constructor_type = cfg.get('type')
    # Prefer the project registry, then fall back to the upstream mmcv one.
    for registry in (OPTIMIZER_BUILDERS, MMCV_OPTIMIZER_BUILDERS):
        if constructor_type in registry:
            return build_from_cfg(cfg, registry)
    raise KeyError(f'{constructor_type} is not registered '
                   'in the optimizer builder registry.')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config dict.

    ``cfg`` may carry optional ``constructor`` and ``paramwise_cfg`` keys;
    both are stripped before the remaining options are passed on as the raw
    optimizer config.
    """
    cfg_copy = copy.deepcopy(cfg)  # never mutate the caller's config
    constructor_name = cfg_copy.pop('constructor',
                                    'DefaultOptimizerConstructor')
    paramwise = cfg_copy.pop('paramwise_cfg', None)
    constructor = build_optimizer_constructor(
        dict(
            type=constructor_name,
            optimizer_cfg=cfg_copy,
            paramwise_cfg=paramwise))
    return constructor(model)
|
modelsforCIML/mmseg/core/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .class_names import get_classes, get_palette
|
| 3 |
+
from .eval_hooks import DistEvalHook, EvalHook
|
| 4 |
+
from .metrics import (eval_metrics, intersect_and_union, mean_dice,
|
| 5 |
+
mean_fscore, mean_iou, pre_eval_to_metrics)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
|
| 9 |
+
'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics',
|
| 10 |
+
'intersect_and_union'
|
| 11 |
+
]
|
modelsforCIML/mmseg/core/evaluation/class_names.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import mmcv
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def cityscapes_classes():
    """Cityscapes class names for external use (19 evaluation classes)."""
    return [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle'
    ]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def ade_classes():
    """ADE20K class names for external use."""
    # NOTE: 'bed ' keeps its trailing space — it is part of the original
    # label set and must not be "fixed".
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag'
    ]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def voc_classes():
    """Pascal VOC class names for external use (background + 20 objects)."""
    return [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor'
    ]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def cocostuff_classes():
    """CocoStuff class names for external use."""
    # NOTE(review): ordering appears to be the COCO 'thing' classes followed
    # by the stuff labels — confirm against the upstream label file.
    return [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
        'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
        'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
        'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
        'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood'
    ]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def loveda_classes():
    """LoveDA class names for external use (7 land-cover classes)."""
    return [
        'background', 'building', 'road', 'water', 'barren', 'forest',
        'agricultural'
    ]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def potsdam_classes():
    """Potsdam class names for external use (6 ISPRS classes)."""
    return [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter'
    ]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def vaihingen_classes():
    """Vaihingen class names for external use (6 ISPRS classes)."""
    return [
        'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
        'clutter'
    ]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def isaid_classes():
    """iSAID class names for external use (background + 15 categories)."""
    return [
        'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
        'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
        'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
        'Soccer_ball_field', 'plane', 'Harbor'
    ]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def stare_classes():
    """STARE class names for external use (binary vessel segmentation)."""
    return ['background', 'vessel']
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def occludedface_classes():
    """Occluded-face class names for external use (binary segmentation)."""
    return ['background', 'face']
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def cityscapes_palette():
    """Cityscapes palette for external use."""
    # One RGB triple per class (presumed to follow cityscapes_classes()
    # order — verify against callers).
    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
            [0, 0, 230], [119, 11, 32]]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def ade_palette():
    """ADE20K palette for external use."""
    # One RGB triple per class (presumed to follow ade_classes() order —
    # verify against callers).
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def voc_palette():
    """Pascal VOC palette for external use."""
    # One RGB triple per class (presumed to follow voc_classes() order —
    # verify against callers).
    return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def cocostuff_palette():
    """CocoStuff palette for external use."""
    # NOTE(review): several triples repeat (e.g. the first two entries are
    # both [0, 192, 64]) — palette collisions; confirm against upstream
    # before relying on color uniqueness.
    return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
            [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
            [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
            [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
            [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
            [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
            [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
            [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0],
            [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32],
            [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
            [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64],
            [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
            [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
            [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64],
            [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32],
            [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128],
            [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
            [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
            [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0],
            [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0],
            [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96],
            [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128],
            [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64],
            [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96],
            [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0],
            [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64],
            [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96],
            [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128],
            [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0],
            [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32],
            [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64],
            [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0],
            [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32],
            [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192],
            [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64],
            [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32],
            [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64],
            [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64],
            [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32],
            [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192],
            [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0],
            [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96],
            [64, 160, 64], [64, 64, 0]]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def loveda_palette():
    """LoveDA palette for external use."""
    # One RGB triple per class (presumed to follow loveda_classes() order —
    # verify against callers).
    return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
            [159, 129, 183], [0, 255, 0], [255, 195, 128]]
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def potsdam_palette():
    """Potsdam palette for external use."""
    # One RGB triple per class (presumed to follow potsdam_classes() order —
    # verify against callers).
    return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
            [255, 255, 0], [255, 0, 0]]
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def vaihingen_palette():
    """Vaihingen palette for external use."""
    # Identical triples to potsdam_palette(): both ISPRS benchmarks share
    # the same 6-class color coding.
    return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
            [255, 255, 0], [255, 0, 0]]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def isaid_palette():
    """iSAID palette for external use."""
    # One RGB triple per class (presumed to follow isaid_classes() order —
    # verify against callers).
    return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
            [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
            [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
            [0, 127, 191], [0, 127, 255], [0, 100, 155]]
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def stare_palette():
    """STARE palette for external use (background, vessel)."""
    return [[120, 120, 120], [6, 230, 230]]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def occludedface_palette():
    """Occluded-face palette for external use (background, face)."""
    return [[0, 0, 0], [128, 0, 0]]
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Maps a canonical dataset name to the aliases accepted by get_classes()
# and get_palette(); the canonical name selects the matching
# '<name>_classes' / '<name>_palette' helper defined above.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
    'loveda': ['loveda'],
    'potsdam': ['potsdam'],
    'vaihingen': ['vaihingen'],
    'cocostuff': [
        'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
        'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
        'coco_stuff164k'
    ],
    'isaid': ['isaid', 'iSAID'],
    'stare': ['stare', 'STARE'],
    'occludedface': ['occludedface']
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Dataset name or one of its registered aliases
            (see ``dataset_aliases``).

    Returns:
        list[str]: Class names of the dataset.

    Raises:
        ValueError: If the dataset name is not recognized.
        TypeError: If ``dataset`` is not a string.
    """
    alias2name = {}
    for name, aliases in dataset_aliases.items():
        for alias in aliases:
            alias2name[alias] = name

    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve the '<name>_classes' helper by explicit name lookup instead of
    # eval(): same behavior, no arbitrary-string execution.
    return globals()[alias2name[dataset] + '_classes']()
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Dataset name or one of its registered aliases
            (see ``dataset_aliases``).

    Returns:
        list[list[int]]: One RGB triple per class.

    Raises:
        ValueError: If the dataset name is not recognized.
        TypeError: If ``dataset`` is not a string.
    """
    alias2name = {}
    for name, aliases in dataset_aliases.items():
        for alias in aliases:
            alias2name[alias] = name

    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve the '<name>_palette' helper by explicit name lookup instead of
    # eval(): same behavior, no arbitrary-string execution.
    return globals()[alias2name[dataset] + '_palette']()
|
modelsforCIML/mmseg/core/evaluation/eval_hooks.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from mmcv.runner import DistEvalHook as _DistEvalHook
|
| 7 |
+
from mmcv.runner import EvalHook as _EvalHook
|
| 8 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EvalHook(_EvalHook):
    """Single GPU EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.
        pre_eval (bool): Whether to use progressive mode to evaluate model.
            Default: False.
    Returns:
        list: The prediction results.
    """

    # Metrics treated as "larger is better" by the base hook's best-checkpoint
    # comparison.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        # Raw results of the most recent evaluation (None until first run).
        self.latest_results = None

        # ``efficient_test`` is accepted but otherwise ignored: it only emits
        # a deprecation warning steering users to ``pre_eval``.
        if efficient_test:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        # Local import — presumably avoids a circular import between
        # mmseg.apis and mmseg.core; confirm before moving to module level.
        from mmseg.apis import single_gpu_test
        results = single_gpu_test(
            runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
        self.latest_results = results
        runner.log_buffer.clear()
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        if self.save_best:
            self._save_ckpt(runner, key_score)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class DistEvalHook(_DistEvalHook):
    """Distributed EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Default: False.
        pre_eval (bool): Whether to use progressive mode to evaluate model.
            Default: False.
    Returns:
        list: The prediction results.
    """

    # Metrics treated as "larger is better" by the base hook's best-checkpoint
    # comparison.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        # Raw results of the most recent evaluation (None until first run).
        self.latest_results = None
        # ``efficient_test`` is accepted but otherwise ignored: it only emits
        # a deprecation warning steering users to ``pre_eval``.
        if efficient_test:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        # Synchronization of BatchNorm's buffer (running_mean
        # and running_var) is not supported in the DDP of pytorch,
        # which may cause the inconsistent performance of models in
        # different ranks, so we broadcast BatchNorm's buffers
        # of rank 0 to other ranks to avoid this.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        # Default collection dir for cross-rank result gathering when the
        # user did not configure one.
        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')

        # Local import — presumably avoids a circular import between
        # mmseg.apis and mmseg.core; confirm before moving to module level.
        from mmseg.apis import multi_gpu_test
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
            pre_eval=self.pre_eval)
        self.latest_results = results
        runner.log_buffer.clear()

        # Only rank 0 computes metrics and saves checkpoints; the other
        # ranks only contribute their result shards.
        if runner.rank == 0:
            print('\n')
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)

            if self.save_best:
                self._save_ckpt(runner, key_score)
|
modelsforCIML/mmseg/core/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
import mmcv
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def f_score(precision, recall, beta=1):
    """Calculate the F-score.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        float | torch.Tensor: The F-score value.
    """
    # Weighted harmonic mean of precision and recall; beta > 1 favors
    # recall, beta < 1 favors precision. NaN propagates when both inputs
    # are zero, matching the per-class metric convention used below.
    score = (1 + beta**2) * (precision * recall) / (
        (beta**2 * precision) + recall)
    return score
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and union histograms for one image.

    Args:
        pred_label (ndarray | torch.Tensor | str): Prediction segmentation
            map or predict result filename (``.npy``).
        label (ndarray | torch.Tensor | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. The parameter
            will work only when label is str. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. The parameter
            will work only when label is str. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    elif isinstance(pred_label, np.ndarray):
        # Boolean-mask indexing below requires a torch tensor; convert raw
        # ndarray predictions instead of crashing on the mixed-type index.
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    elif isinstance(label, np.ndarray):
        # Only ndarray needs conversion; passing a torch.Tensor through
        # torch.from_numpy (as before) would raise a TypeError.
        label = torch.from_numpy(label)

    if reduce_zero_label:
        # Shift classes down by one so that class 0 becomes ignored (255).
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
    if label_map:
        # Remap ids against a snapshot so chained mappings (1->2, 2->3)
        # do not cascade within a single call.
        label_copy = label.clone()
        for old_id, new_id in label_map.items():
            label[label_copy == old_id] = new_id

    # Drop all pixels marked as ignored before histogramming.
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Accumulate intersection/union histograms over a whole dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    # Four running per-class totals: intersect, union, pred, label.
    totals = [
        torch.zeros((num_classes, ), dtype=torch.float64) for _ in range(4)
    ]
    for result, gt_seg_map in zip(results, gt_seg_maps):
        per_image = intersect_and_union(result, gt_seg_map, num_classes,
                                        ignore_index, label_map,
                                        reduce_zero_label)
        for total, area in zip(totals, per_image):
            total += area
    return totals[0], totals[1], totals[2], totals[3]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection over Union (mIoU).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]:
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    # Thin wrapper around the generic evaluator restricted to mIoU.
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """
    # Thin wrapper around the generic evaluator restricted to mDice.
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape
                (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    # Thin wrapper around the generic evaluator restricted to mFscore.
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics over a whole dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
            'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Weight of recall in the F-score. Default: 1.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """
    # Aggregate per-class pixel histograms over every image first, then
    # convert the totals into the requested metrics.
    total_areas = total_intersect_and_union(results, gt_seg_maps,
                                            num_classes, ignore_index,
                                            label_map, reduce_zero_label)
    return total_area_to_metrics(total_areas[0], total_areas[1],
                                 total_areas[2], total_areas[3], metrics,
                                 nan_to_num, beta)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def pre_eval_to_metrics(pre_eval_results,
                        metrics=['mIoU'],
                        nan_to_num=None,
                        beta=1):
    """Convert per-image pre-eval results to aggregated metrics.

    Args:
        pre_eval_results (list[tuple[torch.Tensor]]): per image eval results
            for computing evaluation metric
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
            'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """
    # Transpose [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] into
    # ([A_1..A_n], [B_1..B_n], [C_1..C_n], [D_1..D_n]) and sum each field.
    columns = tuple(zip(*pre_eval_results))
    assert len(columns) == 4

    summed = [sum(column) for column in columns]
    return total_area_to_metrics(summed[0], summed[1], summed[2], summed[3],
                                 metrics, nan_to_num, beta)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Convert aggregated per-class areas into evaluation metrics.

    Args:
        total_area_intersect (ndarray): The intersection of prediction and
            ground truth histogram on all classes.
        total_area_union (ndarray): The union of prediction and ground truth
            histogram on all classes.
        total_area_pred_label (ndarray): The prediction histogram on all
            classes.
        total_area_label (ndarray): The ground truth histogram on all
            classes.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
            'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Weight of recall in the F-score. Default: 1.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    # aAcc is always reported, regardless of the requested metrics.
    ret_metrics = OrderedDict(
        aAcc=total_area_intersect.sum() / total_area_label.sum())
    for metric in metrics:
        # Per-class accuracy (== recall) is shared by several metrics.
        acc = total_area_intersect / total_area_label
        if metric == 'mIoU':
            ret_metrics['IoU'] = total_area_intersect / total_area_union
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            ret_metrics['Dice'] = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = acc
            ret_metrics['Fscore'] = torch.tensor(
                [f_score(p, r, beta) for p, r in zip(precision, recall)])
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    # Hand results back as numpy arrays, optionally sanitizing NaNs
    # (classes absent from both prediction and ground truth).
    ret_metrics = {
        name: value.numpy()
        for name, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict(
            (name, np.nan_to_num(value, nan=nan_to_num))
            for name, value in ret_metrics.items())
    return ret_metrics
|
modelsforCIML/mmseg/core/hook/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .wandblogger_hook import MMSegWandbHook
|
| 3 |
+
|
| 4 |
+
__all__ = ['MMSegWandbHook']
|
modelsforCIML/mmseg/core/hook/wandblogger_hook.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import os.path as osp
|
| 3 |
+
|
| 4 |
+
import mmcv
|
| 5 |
+
import numpy as np
|
| 6 |
+
from mmcv.runner import HOOKS
|
| 7 |
+
from mmcv.runner.dist_utils import master_only
|
| 8 |
+
from mmcv.runner.hooks.checkpoint import CheckpointHook
|
| 9 |
+
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
|
| 10 |
+
|
| 11 |
+
from mmseg.core import DistEvalHook, EvalHook
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@HOOKS.register_module()
|
| 15 |
+
class MMSegWandbHook(WandbLoggerHook):
|
| 16 |
+
"""Enhanced Wandb logger hook for MMSegmentation.
|
| 17 |
+
|
| 18 |
+
Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not
|
| 19 |
+
only automatically log all the metrics but also log the following extra
|
| 20 |
+
information - saves model checkpoints as W&B Artifact, and
|
| 21 |
+
logs model prediction as interactive W&B Tables.
|
| 22 |
+
|
| 23 |
+
- Metrics: The MMSegWandbHook will automatically log training
|
| 24 |
+
and validation metrics along with system metrics (CPU/GPU).
|
| 25 |
+
|
| 26 |
+
- Checkpointing: If `log_checkpoint` is True, the checkpoint saved at
|
| 27 |
+
every checkpoint interval will be saved as W&B Artifacts.
|
| 28 |
+
This depends on the : class:`mmcv.runner.CheckpointHook` whose priority
|
| 29 |
+
is higher than this hook. Please refer to
|
| 30 |
+
https://docs.wandb.ai/guides/artifacts/model-versioning
|
| 31 |
+
to learn more about model versioning with W&B Artifacts.
|
| 32 |
+
|
| 33 |
+
- Checkpoint Metadata: If evaluation results are available for a given
|
| 34 |
+
checkpoint artifact, it will have a metadata associated with it.
|
| 35 |
+
The metadata contains the evaluation metrics computed on validation
|
| 36 |
+
data with that checkpoint along with the current epoch. It depends
|
| 37 |
+
on `EvalHook` whose priority is more than MMSegWandbHook.
|
| 38 |
+
|
| 39 |
+
- Evaluation: At every evaluation interval, the `MMSegWandbHook` logs the
|
| 40 |
+
model prediction as interactive W&B Tables. The number of samples
|
| 41 |
+
logged is given by `num_eval_images`. Currently, the `MMSegWandbHook`
|
| 42 |
+
logs the predicted segmentation masks along with the ground truth at
|
| 43 |
+
every evaluation interval. This depends on the `EvalHook` whose
|
| 44 |
+
priority is more than `MMSegWandbHook`. Also note that the data is just
|
| 45 |
+
logged once and subsequent evaluation tables uses reference to the
|
| 46 |
+
logged data to save memory usage. Please refer to
|
| 47 |
+
https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
Example:
|
| 51 |
+
log_config = dict(
|
| 52 |
+
...
|
| 53 |
+
hooks=[
|
| 54 |
+
...,
|
| 55 |
+
dict(type='MMSegWandbHook',
|
| 56 |
+
init_kwargs={
|
| 57 |
+
'entity': "YOUR_ENTITY",
|
| 58 |
+
'project': "YOUR_PROJECT_NAME"
|
| 59 |
+
},
|
| 60 |
+
interval=50,
|
| 61 |
+
log_checkpoint=True,
|
| 62 |
+
log_checkpoint_metadata=True,
|
| 63 |
+
num_eval_images=100,
|
| 64 |
+
bbox_score_thr=0.3)
|
| 65 |
+
])
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
init_kwargs (dict): A dict passed to wandb.init to initialize
|
| 70 |
+
a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
|
| 71 |
+
for possible key-value pairs.
|
| 72 |
+
interval (int): Logging interval (every k iterations).
|
| 73 |
+
Default 10.
|
| 74 |
+
log_checkpoint (bool): Save the checkpoint at every checkpoint interval
|
| 75 |
+
as W&B Artifacts. Use this for model versioning where each version
|
| 76 |
+
is a checkpoint.
|
| 77 |
+
Default: False
|
| 78 |
+
log_checkpoint_metadata (bool): Log the evaluation metrics computed
|
| 79 |
+
on the validation data with the checkpoint, along with current
|
| 80 |
+
epoch as a metadata to that checkpoint.
|
| 81 |
+
Default: True
|
| 82 |
+
num_eval_images (int): Number of validation images to be logged.
|
| 83 |
+
Default: 100
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
    def __init__(self,
                 init_kwargs=None,
                 interval=50,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        """Configure W&B logging options.

        Args:
            init_kwargs (dict, optional): Passed through to ``wandb.init``.
            interval (int): Logging interval in iterations. Default: 50.
            log_checkpoint (bool): Save checkpoints as W&B Artifacts.
                Default: False.
            log_checkpoint_metadata (bool): Attach eval metrics and the
                current iteration to each checkpoint artifact. Only
                effective when ``log_checkpoint`` is also True.
            num_eval_images (int): Number of validation images to log;
                0 disables evaluation-table logging. Default: 100.
        """
        super(MMSegWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        # Metadata logging is meaningless without checkpoint logging, so
        # it is gated on both flags being set.
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        # These are resolved from the runner's registered hooks in
        # before_run(); they stay None if the matching hook is absent.
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None
        self.test_fn = None
|
| 103 |
+
|
| 104 |
+
    @master_only
    def before_run(self, runner):
        """Resolve companion hooks, validate config and log ground truth.

        Disables checkpoint/evaluation logging (with a warning) when the
        required CheckpointHook / EvalHook is missing from the runner.
        """
        super(MMSegWandbHook, self).before_run(runner)

        # Check if EvalHook and CheckpointHook are available.
        # NOTE(review): DistEvalHook is tested after EvalHook so that the
        # distributed test_fn wins when both isinstance checks match —
        # presumably DistEvalHook subclasses EvalHook; confirm.
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, EvalHook):
                from mmseg.apis import single_gpu_test
                self.eval_hook = hook
                self.test_fn = single_gpu_test
            if isinstance(hook, DistEvalHook):
                from mmseg.apis import multi_gpu_test
                self.eval_hook = hook
                self.test_fn = multi_gpu_test

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                # NOTE(review): these adjacent string literals concatenate
                # without a space ("is" + "required" -> "isrequired");
                # left unchanged here since it is runtime output.
                runner.logger.warning(
                    'To log checkpoint in MMSegWandbHook, `CheckpointHook` is'
                    'required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                self.val_dataset = self.eval_hook.dataloader.dataset
                # Determine the number of samples to be logged.
                # Clamp to the dataset size so indexing never goes out of
                # range when fewer validation samples exist.
                if self.num_eval_images > len(self.val_dataset):
                    self.num_eval_images = len(self.val_dataset)
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation samples '
                        f'({len(self.val_dataset)}). The complete validation '
                        'dataset will be logged.')

        # Check conditions to log checkpoint metadata
        # Metadata pairs a checkpoint with eval results, so every saved
        # checkpoint must coincide with an evaluation step.
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMSegWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add data to the data table
            self._add_ground_truth(runner)
            # Log ground truth data once; later eval tables reference it.
            self._log_data_table()
|
| 170 |
+
|
| 171 |
+
# for the reason of this double-layered structure, refer to
|
| 172 |
+
# https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076
|
| 173 |
+
    def after_train_iter(self, runner):
        """Forward metric logging and trigger artifact/table logging."""
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call super method at first, it will clear the log_buffer
            return super(MMSegWandbHook, self).after_train_iter(runner)
        else:
            # Non-train mode: log metrics first, then checkpoints and
            # prediction tables via the @master_only helper.
            super(MMSegWandbHook, self).after_train_iter(runner)
            self._after_train_iter(runner)
|
| 183 |
+
|
| 184 |
+
@master_only
|
| 185 |
+
def _after_train_iter(self, runner):
|
| 186 |
+
if self.by_epoch:
|
| 187 |
+
return
|
| 188 |
+
|
| 189 |
+
# Save checkpoint and metadata
|
| 190 |
+
if (self.log_checkpoint
|
| 191 |
+
and self.every_n_iters(runner, self.ckpt_interval)
|
| 192 |
+
or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
|
| 193 |
+
if self.log_checkpoint_metadata and self.eval_hook:
|
| 194 |
+
metadata = {
|
| 195 |
+
'iter': runner.iter + 1,
|
| 196 |
+
**self._get_eval_results()
|
| 197 |
+
}
|
| 198 |
+
else:
|
| 199 |
+
metadata = None
|
| 200 |
+
aliases = [f'iter_{runner.iter+1}', 'latest']
|
| 201 |
+
model_path = osp.join(self.ckpt_hook.out_dir,
|
| 202 |
+
f'iter_{runner.iter+1}.pth')
|
| 203 |
+
self._log_ckpt_as_artifact(model_path, aliases, metadata)
|
| 204 |
+
|
| 205 |
+
# Save prediction table
|
| 206 |
+
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
|
| 207 |
+
# Currently the results of eval_hook is not reused by wandb, so
|
| 208 |
+
# wandb will run evaluation again internally. We will consider
|
| 209 |
+
# refactoring this function afterwards
|
| 210 |
+
results = self.test_fn(runner.model, self.eval_hook.dataloader)
|
| 211 |
+
# Initialize evaluation table
|
| 212 |
+
self._init_pred_table()
|
| 213 |
+
# Log predictions
|
| 214 |
+
self._log_predictions(results, runner)
|
| 215 |
+
# Log the table
|
| 216 |
+
self._log_eval_table(runner.iter + 1)
|
| 217 |
+
|
| 218 |
+
    @master_only
    def after_run(self, runner):
        """Close out the W&B run so all pending data is flushed."""
        self.wandb.finish()
|
| 221 |
+
|
| 222 |
+
def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
|
| 223 |
+
"""Log model checkpoint as W&B Artifact.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
model_path (str): Path of the checkpoint to log.
|
| 227 |
+
aliases (list): List of the aliases associated with this artifact.
|
| 228 |
+
metadata (dict, optional): Metadata associated with this artifact.
|
| 229 |
+
"""
|
| 230 |
+
model_artifact = self.wandb.Artifact(
|
| 231 |
+
f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
|
| 232 |
+
model_artifact.add_file(model_path)
|
| 233 |
+
self.wandb.log_artifact(model_artifact, aliases=aliases)
|
| 234 |
+
|
| 235 |
+
def _get_eval_results(self):
|
| 236 |
+
"""Get model evaluation results."""
|
| 237 |
+
results = self.eval_hook.latest_results
|
| 238 |
+
eval_results = self.val_dataset.evaluate(
|
| 239 |
+
results, logger='silent', **self.eval_hook.eval_kwargs)
|
| 240 |
+
return eval_results
|
| 241 |
+
|
| 242 |
+
def _init_data_table(self):
|
| 243 |
+
"""Initialize the W&B Tables for validation data."""
|
| 244 |
+
columns = ['image_name', 'image']
|
| 245 |
+
self.data_table = self.wandb.Table(columns=columns)
|
| 246 |
+
|
| 247 |
+
def _init_pred_table(self):
|
| 248 |
+
"""Initialize the W&B Tables for model evaluation."""
|
| 249 |
+
columns = ['image_name', 'ground_truth', 'prediction']
|
| 250 |
+
self.eval_table = self.wandb.Table(columns=columns)
|
| 251 |
+
|
| 252 |
+
    def _add_ground_truth(self, runner):
        """Populate ``self.data_table`` with validation images and GT masks.

        Sets ``self.log_evaluation = False`` and returns early when no image
        loader is present in the pipeline or when a ground-truth mask is not
        2D (W&B mask overlays require 2D class-id maps).
        """
        # Get image loading pipeline
        from mmseg.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                # NOTE(review): no break — if the pipeline contains several
                # LoadImageFromFile transforms, the last one wins. Confirm
                # this is intended.
                img_loader = t

        if img_loader is None:
            self.log_evaluation = False
            runner.logger.warning(
                'LoadImageFromFile is required to add images '
                'to W&B Tables.')
            return

        # Select the images to be logged.
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that same validation set is logged each time.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        classes = self.val_dataset.CLASSES
        self.class_id_to_label = {id: name for id, name in enumerate(classes)}
        self.class_set = self.wandb.Classes([{
            'id': id,
            'name': name
        } for id, name in self.class_id_to_label.items()])

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.img_infos[idx]
            image_name = img_info['filename']

            # Get image and convert from BGR to RGB
            img_meta = img_loader(
                dict(img_info=img_info, img_prefix=self.val_dataset.img_dir))
            image = mmcv.bgr2rgb(img_meta['img'])

            # Get segmentation mask
            seg_mask = self.val_dataset.get_gt_seg_map_by_idx(idx)
            # Dict of masks to be logged.
            wandb_masks = None
            if seg_mask.ndim == 2:
                wandb_masks = {
                    'ground_truth': {
                        'mask_data': seg_mask,
                        'class_labels': self.class_id_to_label
                    }
                }

                # Log a row to the data table.
                self.data_table.add_data(
                    image_name,
                    self.wandb.Image(
                        image, masks=wandb_masks, classes=self.class_set))
            else:
                runner.logger.warning(
                    f'The segmentation mask is {seg_mask.ndim}D which '
                    'is not supported by W&B.')
                self.log_evaluation = False
                return
|
| 313 |
+
|
| 314 |
+
def _log_predictions(self, results, runner):
|
| 315 |
+
table_idxs = self.data_table_ref.get_index()
|
| 316 |
+
assert len(table_idxs) == len(self.eval_image_indexs)
|
| 317 |
+
assert len(results) == len(self.val_dataset)
|
| 318 |
+
|
| 319 |
+
for ndx, eval_image_index in enumerate(self.eval_image_indexs):
|
| 320 |
+
# Get the result
|
| 321 |
+
pred_mask = results[eval_image_index]
|
| 322 |
+
|
| 323 |
+
if pred_mask.ndim == 2:
|
| 324 |
+
wandb_masks = {
|
| 325 |
+
'prediction': {
|
| 326 |
+
'mask_data': pred_mask,
|
| 327 |
+
'class_labels': self.class_id_to_label
|
| 328 |
+
}
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
# Log a row to the data table.
|
| 332 |
+
self.eval_table.add_data(
|
| 333 |
+
self.data_table_ref.data[ndx][0],
|
| 334 |
+
self.data_table_ref.data[ndx][1],
|
| 335 |
+
self.wandb.Image(
|
| 336 |
+
self.data_table_ref.data[ndx][1],
|
| 337 |
+
masks=wandb_masks,
|
| 338 |
+
classes=self.class_set))
|
| 339 |
+
else:
|
| 340 |
+
runner.logger.warning(
|
| 341 |
+
'The predictio segmentation mask is '
|
| 342 |
+
f'{pred_mask.ndim}D which is not supported by W&B.')
|
| 343 |
+
self.log_evaluation = False
|
| 344 |
+
return
|
| 345 |
+
|
| 346 |
+
    def _log_data_table(self):
        """Log the W&B Tables for validation data as artifact and calls
        `use_artifact` on it so that the evaluation table can use the reference
        of already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        # Block until the artifact is committed so the reference fetched
        # below is valid.
        data_artifact.wait()

        # Rows of this reference table point at the uploaded images.
        self.data_table_ref = data_artifact.get('val_data')
|
| 360 |
+
|
| 361 |
+
    def _log_eval_table(self, iter):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times creating new version. Use this
        to compare models at different intervals interactively.

        Args:
            iter (int): Current iteration count. Unused in the body — the
                artifact versioning tracks history. NOTE: the parameter name
                shadows the builtin ``iter``; kept for interface
                compatibility.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        self.wandb.run.log_artifact(pred_artifact)
|
modelsforCIML/mmseg/core/optimizers/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .layer_decay_optimizer_constructor import (
|
| 3 |
+
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
|
| 7 |
+
]
|
modelsforCIML/mmseg/core/optimizers/layer_decay_optimizer_constructor.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import json
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
| 6 |
+
|
| 7 |
+
from mmseg.utils import get_root_logger
|
| 8 |
+
from ..builder import OPTIMIZER_BUILDERS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_layer_id_for_convnext(var_name, max_layer_id):
    """Map a parameter name to a layer id for ``layer_wise`` decay_type.

    Args:
        var_name (str): The key of the model parameter.
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """
    embed_tokens = ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed', 'backbone2.cls_token',
                    'backbone2.mask_token', 'backbone2.pos_embed')
    if var_name in embed_tokens:
        return 0

    if (var_name.startswith('backbone.downsample_layers')
            or var_name.startswith('backbone2.downsample_layers')):
        # Downsample layers share ids with their neighbouring stages.
        stage = int(var_name.split('.')[2])
        return {0: 0, 1: 2, 2: 3, 3: max_layer_id}[stage]

    if (var_name.startswith('backbone.stages')
            or var_name.startswith('backbone2.stages')):
        fields = var_name.split('.')
        stage = int(fields[2])
        block = int(fields[3])
        if stage == 0:
            return 1
        if stage == 1:
            return 2
        if stage == 2:
            # Stage 2 is deep; spread its blocks over several layer ids.
            return 3 + block // 3
        return max_layer_id

    # Anything outside the backbone (e.g. the decode head) gets the last id.
    return max_layer_id + 1
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_stage_id_for_convnext(var_name, max_stage_id):
    """Map a parameter name to a stage id for ``stage_wise`` decay_type.

    Args:
        var_name (str): The key of the model parameter.
        max_stage_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """
    embed_tokens = ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed', 'backbone2.cls_token',
                    'backbone2.mask_token', 'backbone2.pos_embed')
    if var_name in embed_tokens:
        return 0

    for prefix in ('backbone', 'backbone2'):
        # Downsample layers are grouped with the stem (stage 0).
        if var_name.startswith(prefix + '.downsample_layers'):
            return 0
        # Stage parameters map to stage index + 1.
        if var_name.startswith(prefix + '.stages'):
            return int(var_name.split('.')[2]) + 1

    # Non-backbone parameters (e.g. decode head) take the last stage id.
    return max_stage_id - 1
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_layer_id_for_vit(var_name, max_layer_id):
    """Map a ViT parameter name to a layer id for layer-wise LR decay.

    Args:
        var_name (str): The key of the model parameter.
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: Returns the layer id of the key.
    """
    embed_tokens = ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed', 'backbone2.cls_token',
                    'backbone2.mask_token', 'backbone2.pos_embed')
    if var_name in embed_tokens:
        return 0

    for prefix in ('backbone.', 'backbone2.'):
        # The patch embedding is grouped with the embedding tokens.
        if var_name.startswith(prefix + 'patch_embed'):
            return 0
        # Transformer blocks: layer index + 1.
        if var_name.startswith(prefix + 'layers'):
            return int(var_name.split('.')[2]) + 1

    # Everything else (e.g. the decode head) takes the last layer id.
    return max_layer_id - 1
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for ConvNeXt,
    BEiT and MAE.
    """

    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """
        logger = get_root_logger()

        parameter_groups = {}
        logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
        # +2 accounts for the embedding/stem layer and the head layer.
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        logger.info('Build LearningRateDecayOptimizerConstructor  '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # 1-D parameters (norms, biases) and embedding tokens are
            # excluded from weight decay.
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
            # NOTE(review): if decay_type matches neither 'layer_wise' nor
            # 'stage_wise', layer_id is never assigned here and the f-string
            # below fails (or reuses a stale value) — confirm decay_type is
            # validated by the caller.
            if 'layer_wise' in decay_type:
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_convnext(
                        name, self.paramwise_cfg.get('num_layers'))
                    logger.info(f'set param {name} as id {layer_id}')
                elif 'BEiT' in module.backbone.__class__.__name__ or \
                        'MAE' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_vit(name, num_layers)
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_stage_id_for_convnext(name, num_layers)
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
                # Deeper layers get a scale closer to 1.0 (less decay).
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            # Log a readable summary (without the tensors themselves).
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for BEiT,
    and it will be deprecated.
    Please use ``LearningRateDecayOptimizerConstructor`` instead.
    """

    def __init__(self, optimizer_cfg, paramwise_cfg):
        warnings.warn('DeprecationWarning: Original '
                      'LayerDecayOptimizerConstructor of BEiT '
                      'will be deprecated. Please use '
                      'LearningRateDecayOptimizerConstructor instead, '
                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
        # Force ViT-style layer-wise decay for backward compatibility.
        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
        warnings.warn('DeprecationWarning: Layer_decay_rate will '
                      'be deleted, please use decay_rate instead.')
        # NOTE(review): pop() raises KeyError if the legacy
        # 'layer_decay_rate' key is absent — confirm callers always pass it.
        paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
        super(LayerDecayOptimizerConstructor,
              self).__init__(optimizer_cfg, paramwise_cfg)
|
modelsforCIML/mmseg/core/seg/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .builder import build_pixel_sampler
|
| 3 |
+
from .sampler import BasePixelSampler, OHEMPixelSampler
|
| 4 |
+
|
| 5 |
+
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
modelsforCIML/mmseg/core/seg/builder.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg

# Registry holding all pixel sampler classes (e.g. OHEMPixelSampler).
PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map.

    Args:
        cfg (dict): Config dict; must contain a ``type`` key naming a
            registered sampler.
        **default_args: Default kwargs forwarded to the sampler constructor.

    Returns:
        The instantiated pixel sampler.
    """
    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
|
modelsforCIML/mmseg/core/seg/sampler/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .base_pixel_sampler import BasePixelSampler
|
| 3 |
+
from .ohem_pixel_sampler import OHEMPixelSampler
|
| 4 |
+
|
| 5 |
+
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
|
modelsforCIML/mmseg/core/seg/sampler/base_pixel_sampler.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BasePixelSampler(metaclass=ABCMeta):
    """Abstract base class of pixel samplers.

    Concrete samplers must implement :meth:`sample`; the base class itself
    cannot be instantiated.
    """

    def __init__(self, **kwargs):
        # No shared state; **kwargs is accepted for subclass flexibility.
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
|
modelsforCIML/mmseg/core/seg/sampler/ohem_pixel_sampler.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import PIXEL_SAMPLERS
from .base_pixel_sampler import BasePixelSampler


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`.
        thresh (float, optional): The threshold for hard example selection.
            Below which, are prediction with low confidence. If not
            specified, the hard examples will be pixels of top ``min_kept``
            loss. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super(OHEMPixelSampler, self).__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or with low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W)
        """
        # Sampling is bookkeeping only; no gradients should flow through it.
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            batch_kept = self.min_kept * seg_label.size(0)
            # Pixels labelled ignore_index never receive weight.
            valid_mask = seg_label != self.context.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())
            valid_seg_weight = seg_weight[valid_mask]
            if self.thresh is not None:
                # Confidence branch: select pixels whose GT-class probability
                # falls below a threshold.
                seg_prob = F.softmax(seg_logit, dim=1)

                tmp_seg_label = seg_label.clone().unsqueeze(1)
                # Temporarily map ignore_index to class 0 so gather() has a
                # valid index; those pixels are masked out by valid_mask.
                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
                sort_prob, sort_indices = seg_prob[valid_mask].sort()

                if sort_prob.numel() > 0:
                    # Raise the threshold if needed so that at least
                    # batch_kept pixels are selected.
                    min_threshold = sort_prob[min(batch_kept,
                                                  sort_prob.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
            else:
                # Loss branch: keep the batch_kept highest-loss pixels
                # summed over all decode losses.
                if not isinstance(self.context.loss_decode, nn.ModuleList):
                    losses_decode = [self.context.loss_decode]
                else:
                    losses_decode = self.context.loss_decode
                losses = 0.0
                for loss_module in losses_decode:
                    losses += loss_module(
                        seg_logit,
                        seg_label,
                        weight=None,
                        ignore_index=self.context.ignore_index,
                        reduction_override='none')

                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                _, sort_indices = losses[valid_mask].sort(descending=True)
                valid_seg_weight[sort_indices[:batch_kept]] = 1.

            seg_weight[valid_mask] = valid_seg_weight

            return seg_weight
|
modelsforCIML/mmseg/core/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .dist_util import check_dist_init, sync_random_seed
|
| 3 |
+
from .misc import add_prefix
|
| 4 |
+
|
| 5 |
+
__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed']
|
modelsforCIML/mmseg/core/utils/dist_util.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from mmcv.runner import get_dist_info
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def check_dist_init():
    """Return True when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed. All workers must call
    this function, otherwise it will deadlock. This method is generally used in
    `DistributedSampler`, because the seed should be identical across all
    processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.
    Returns:
        int: Seed to be used.
    """

    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    # Single process: nothing to synchronize.
    if world_size == 1:
        return seed

    # Rank 0 chooses the seed and broadcasts it to every other rank.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
|
modelsforCIML/mmseg/core/utils/misc.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    # Build a new dict; the input mapping is left untouched.
    return {f'{prefix}.{key}': value for key, value in inputs.items()}
|
modelsforCIML/mmseg/models/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
|
| 3 |
+
build_head, build_loss, build_segmentor)
|
| 4 |
+
from .decode_heads import * # noqa: F401,F403
|
| 5 |
+
from .losses import *
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
|
| 9 |
+
'build_head', 'build_loss', 'build_segmentor'
|
| 10 |
+
]
|
modelsforCIML/mmseg/models/builder.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry

# Child registries of mmcv's registries so lookups fall back to mmcv.
MODELS = Registry('models', parent=MMCV_MODELS)
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)

# All component kinds share the single MODELS registry; the aliases only
# document intent at call sites.
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build segmentor.

    Args:
        cfg (dict): Model config with a registered segmentor ``type``.
        train_cfg (dict, optional): Deprecated; specify inside ``cfg``.
        test_cfg (dict, optional): Deprecated; specify inside ``cfg``.
    """
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # The train/test cfg may be supplied in only one of the two places.
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return SEGMENTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
|
modelsforCIML/mmseg/models/decode_heads/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .fcn_head import FCNHead
|
| 3 |
+
from .uper_lab import UPerLab
|
| 4 |
+
from .uper_head import UPerHead
|
| 5 |
+
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'FCNHead', 'UPerLab', 'UPerHead', 'DepthwiseSeparableASPPHead'
|
| 9 |
+
]
|
modelsforCIML/mmseg/models/decode_heads/aspp_head.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule
|
| 5 |
+
|
| 6 |
+
from mmseg.ops import resize
|
| 7 |
+
from ..builder import HEADS
|
| 8 |
+
from .decode_head import BaseDecodeHead
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    One ``ConvModule`` is created per dilation rate: a rate of 1 yields a
    plain 1x1 convolution, any other rate a 3x3 dilated convolution with
    matching padding (so spatial size is preserved).

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for rate in self.dilations:
            pointwise = rate == 1
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1 if pointwise else 3,
                    dilation=rate,
                    padding=0 if pointwise else rate,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Apply every ASPP branch to ``x``; return the list of outputs."""
        return [branch(x) for branch in self]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        # Global-context branch: image-level pooling followed by 1x1 conv.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses the pooled branch plus one branch per dilation rate.
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        # Upsample the pooled global feature back to the input resolution
        # so it can be concatenated with the ASPP branch outputs.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
|
modelsforCIML/mmseg/models/decode_heads/decode_head.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import warnings
|
| 3 |
+
from abc import ABCMeta, abstractmethod
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from mmcv.runner import BaseModule, auto_fp16, force_fp32
|
| 8 |
+
|
| 9 |
+
from mmseg.core import build_pixel_sampler
|
| 10 |
+
from mmseg.ops import resize
|
| 11 |
+
from ..builder import build_loss
|
| 12 |
+
from ..losses import accuracy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg.
        threshold (float): Threshold for binary segmentation in the case of
            `out_channels==1`. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resize to the
                same size as first one and than concat together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundle into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is property of corresponding loss function which
            could be shown in training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super(BaseDecodeHead, self).__init__(init_cfg)
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert seg_logits into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults '
                          'to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        # Accept either a single loss config or a sequence of them.
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # Final 1x1 classifier producing per-pixel logits.
        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resize to the
                    same size as first one and than concat together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundle into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            # Multi-level input: channel/index lists must align one-to-one.
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label, addstr=''):
        """Compute segmentation loss.

        Args:
            seg_logit (Tensor): Raw logits from the head; resized here to
                the label's spatial size before the loss is applied.
            seg_label (Tensor): Ground-truth segmentation map with a channel
                dim that is squeezed away before the loss.
            addstr (str): Suffix appended to every log key (e.g. to
                distinguish auxiliary heads). Default: ''.

        Returns:
            dict[str, Tensor]: Per-loss values plus 'acc_seg' accuracy,
            all keyed with the ``addstr`` suffix.
        """
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            # BUGFIX: the stored key carries the ``addstr`` suffix, so the
            # existence check must use the suffixed key too. Checking the
            # bare ``loss_name`` always missed when ``addstr`` was non-empty,
            # silently overwriting duplicate-named losses instead of
            # accumulating them.
            loss_name = loss_decode.loss_name + addstr
            if loss_name not in loss:
                loss[loss_name] = loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_name] += loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)

        loss['acc_seg'+addstr] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss
|
modelsforCIML/mmseg/models/decode_heads/fcn_head.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule
|
| 5 |
+
|
| 6 |
+
from ..builder import HEADS
|
| 7 |
+
from .decode_head import BaseDecodeHead
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        # Stash head-specific options before the base class builds conv_seg.
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            # With no convs the input is passed through unchanged, so the
            # channel counts must already agree.
            assert self.in_channels == self.channels

        pad = (kernel_size // 2) * dilation
        layers = [
            ConvModule(
                self.in_channels if idx == 0 else self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=pad,
                dilation=dilation,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for idx in range(num_convs)
        ]
        self.convs = nn.Sequential(*layers) if layers else nn.Identity()
        if self.concat_input:
            # Fuses the head input with the conv-stack output.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Feature map of shape (batch_size, self.channels, H, W)
            produced by the last decoder layer.
        """
        selected = self._transform_inputs(inputs)
        out = self.convs(selected)
        if self.concat_input:
            out = self.conv_cat(torch.cat([selected, out], dim=1))
        return out

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))
|
modelsforCIML/mmseg/models/decode_heads/psp_head.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule
|
| 5 |
+
|
| 6 |
+
from mmseg.ops import resize
|
| 7 |
+
from ..builder import HEADS
|
| 8 |
+
from .decode_head import BaseDecodeHead
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # One branch per scale: adaptive-pool to `scale` x `scale`, then a
        # 1x1 conv projecting to `channels`.
        self.extend(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    **kwargs)) for scale in pool_scales)

    def forward(self, x):
        """Run every pyramid branch and upsample back to ``x``'s size."""
        return [
            resize(
                branch(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners) for branch in self
        ]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        # Multi-scale pooled context branches.
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # Fuses the raw input with all pyramid outputs down to `channels`.
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: Feature map of shape (batch_size, self.channels, H, W)
            produced by the last decoder layer.
        """
        feat = self._transform_inputs(inputs)
        pyramid = [feat] + list(self.psp_modules(feat))
        return self.bottleneck(torch.cat(pyramid, dim=1))

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))
|
modelsforCIML/mmseg/models/decode_heads/sep_aspp_head.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
| 5 |
+
|
| 6 |
+
from mmseg.ops import resize
|
| 7 |
+
from ..builder import HEADS
|
| 8 |
+
from .aspp_head import ASPPHead, ASPPModule
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DepthwiseSeparableASPPModule(ASPPModule):
    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
    conv."""

    def __init__(self, **kwargs):
        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
        # Swap every dilated (rate > 1) branch built by the parent for a
        # depthwise separable conv; the 1x1 (rate == 1) branch is kept as-is.
        for idx, rate in enumerate(self.dilations):
            if rate <= 1:
                continue
            self[idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If is 0,
            the no decoder will be used.
        c1_channels (int): The intermediate channels of c1 decoder.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0
        # Replace the parent's plain-conv ASPP branches with the depthwise
        # separable variant (same dilations and channel layout).
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # self.cls_seg = nn.Conv2d(512,2,1,1,0)
        if c1_in_channels > 0:
            # 1x1 projection for the low-level (c1) feature used as the
            # decoder skip connection.
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            self.c1_bottleneck = None
        # Two separable 3x3 convs that fuse the upsampled ASPP output with
        # the c1 skip feature.
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + c1_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs, trans=True):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level features; ``inputs[0]`` feeds
                the c1 skip path and ``inputs[1]`` the ASPP path.
            trans (bool): If True, ``self._transform_inputs`` is also run.
                NOTE(review): its result is discarded — ``x`` is
                unconditionally rebound to ``inputs[1]`` on the next line, so
                ``trans`` currently has no effect on the output. This looks
                like a leftover from an edit in this fork; confirm intent
                before changing.

        Returns:
            Tensor: Per-pixel classification logits.
        """
        if trans:
            x = self._transform_inputs(inputs)
        x = inputs[1]
        # Global-context branch, upsampled to match the ASPP input size.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            # Decoder: upsample to the c1 resolution and fuse with the skip.
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
|
modelsforCIML/mmseg/models/decode_heads/uper_head.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule
|
| 5 |
+
|
| 6 |
+
from mmseg.ops import resize
|
| 7 |
+
from ..builder import HEADS
|
| 8 |
+
from .decode_head import BaseDecodeHead
|
| 9 |
+
from .psp_head import PPM
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@HEADS.register_module()
|
| 13 |
+
class UPerHead(BaseDecodeHead):
|
| 14 |
+
"""Unified Perceptual Parsing for Scene Understanding.
|
| 15 |
+
|
| 16 |
+
This head is the implementation of `UPerNet
|
| 17 |
+
<https://arxiv.org/abs/1807.10221>`_.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 21 |
+
Module applied on the last feature. Default: (1, 2, 3, 6).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
    """Build the PSP and FPN sub-modules of UPerHead.

    Args:
        pool_scales (tuple[int]): Pooling scales used in the Pooling
            Pyramid Module applied on the last (coarsest) feature.
            Default: (1, 2, 3, 6).

    NOTE(review): unlike upstream UPerHead, this fork builds no
    ``fpn_bottleneck`` — consistent with ``_forward_feature`` returning the
    list of FPN maps instead of one fused tensor; confirm this is intended.
    """
    super(UPerHead, self).__init__(
        input_transform='multiple_select', **kwargs)
    # PSP Module: pyramid pooling over the coarsest backbone level.
    self.psp_modules = PPM(
        pool_scales,
        self.in_channels[-1],
        self.channels,
        conv_cfg=self.conv_cfg,
        norm_cfg=self.norm_cfg,
        act_cfg=self.act_cfg,
        align_corners=self.align_corners)
    # Fuses the coarsest feature with all pyramid outputs.
    self.bottleneck = ConvModule(
        self.in_channels[-1] + len(pool_scales) * self.channels,
        self.channels,
        3,
        padding=1,
        conv_cfg=self.conv_cfg,
        norm_cfg=self.norm_cfg,
        act_cfg=self.act_cfg)
    # FPN Module: one 1x1 lateral conv and one 3x3 output conv per level.
    self.lateral_convs = nn.ModuleList()
    self.fpn_convs = nn.ModuleList()
    for in_channels in self.in_channels[:-1]:  # skip the top layer
        l_conv = ConvModule(
            in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            inplace=False)
        fpn_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            inplace=False)
        self.lateral_convs.append(l_conv)
        self.fpn_convs.append(fpn_conv)
|
| 67 |
+
|
| 68 |
+
def psp_forward(self, inputs):
|
| 69 |
+
"""Forward function of PSP module."""
|
| 70 |
+
x = inputs[-1]
|
| 71 |
+
psp_outs = [x]
|
| 72 |
+
psp_outs.extend(self.psp_modules(x))
|
| 73 |
+
psp_outs = torch.cat(psp_outs, dim=1)
|
| 74 |
+
output = self.bottleneck(psp_outs)
|
| 75 |
+
|
| 76 |
+
return output
|
| 77 |
+
|
| 78 |
+
def _forward_feature(self, inputs):
|
| 79 |
+
"""Forward function for feature maps before classifying each pixel with
|
| 80 |
+
``self.cls_seg`` fc.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
| 87 |
+
H, W) which is feature map for last layer of decoder head.
|
| 88 |
+
"""
|
| 89 |
+
inputs = self._transform_inputs(inputs)
|
| 90 |
+
|
| 91 |
+
# build laterals
|
| 92 |
+
laterals = [
|
| 93 |
+
lateral_conv(inputs[i])
|
| 94 |
+
for i, lateral_conv in enumerate(self.lateral_convs)
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
laterals.append(self.psp_forward(inputs))
|
| 98 |
+
|
| 99 |
+
# build top-down path
|
| 100 |
+
used_backbone_levels = len(laterals)
|
| 101 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
| 102 |
+
prev_shape = laterals[i - 1].shape[2:]
|
| 103 |
+
laterals[i - 1] = laterals[i - 1] + resize(
|
| 104 |
+
laterals[i],
|
| 105 |
+
size=prev_shape,
|
| 106 |
+
mode='bilinear',
|
| 107 |
+
align_corners=self.align_corners)
|
| 108 |
+
|
| 109 |
+
# build outputs
|
| 110 |
+
fpn_outs = [
|
| 111 |
+
self.fpn_convs[i](laterals[i])
|
| 112 |
+
for i in range(used_backbone_levels - 1)
|
| 113 |
+
]
|
| 114 |
+
# append psp feature
|
| 115 |
+
fpn_outs.append(laterals[-1])
|
| 116 |
+
|
| 117 |
+
for i in range(used_backbone_levels - 1, -1, -1):
|
| 118 |
+
fpn_outs[i] = resize(
|
| 119 |
+
fpn_outs[i],
|
| 120 |
+
size=fpn_outs[1].shape[2:],
|
| 121 |
+
mode='bilinear',
|
| 122 |
+
align_corners=self.align_corners)
|
| 123 |
+
return fpn_outs#[:3]
|
| 124 |
+
|
| 125 |
+
def forward(self, inputs):
|
| 126 |
+
"""Forward function."""
|
| 127 |
+
output = self._forward_feature(inputs)
|
| 128 |
+
return output
|
modelsforCIML/mmseg/models/decode_heads/uper_lab.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.cnn import ConvModule
|
| 5 |
+
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
| 6 |
+
from mmseg.ops import resize
|
| 7 |
+
from ..builder import HEADS
|
| 8 |
+
from .decode_head import BaseDecodeHead
|
| 9 |
+
from .psp_head import PPM
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@HEADS.register_module()
class UPerLab(BaseDecodeHead):
    """UPerNet decoder whose fused multi-scale features are refined by a
    DeepLabV3+ (DepthwiseSeparableASPP) head.

    The UPerNet part (PSP over the top level + FPN top-down fusion)
    produces per-level feature maps; these are resized, concatenated and
    handed — together with a 256-channel projection of the finest level —
    to ``self.deeplab``, which emits the final segmentation logits.
    During training an auxiliary prediction from the finest FPN level is
    returned as well.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerLab, self).__init__(
            input_transform='multiple_select', **kwargs)
        # Final refinement head. NOTE(review): config is hard-coded
        # (2048-channel top level, 2 classes, SyncBN) — assumes a
        # ResNet-like backbone and binary segmentation; confirm.
        self.deeplab = DepthwiseSeparableASPPHead(
            in_channels=2048,
            in_index=3,
            channels=512,
            dilations=(1, 12, 24, 36),
            c1_in_channels=256,
            c1_channels=48,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            align_corners=False)
        # Projects the finest FPN level to the 256 channels expected as
        # the low-level (c1) input of the DeepLab head.
        self.convert = nn.Conv2d(512, 256, 1, 1, 0)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function.

        Returns:
            Training mode: tuple ``(seg_logits, aux_logits)``.
            Eval mode: the segmentation logits alone.
        """
        inputs = self._transform_inputs(inputs)

        # build laterals (1x1 convs on every level but the top)
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs (3x3 smoothing convs; the PSP output is used as-is)
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        if self.training:
            # auxiliary logits from the finest (pre-resize) FPN level
            cls_aux = self.cls_seg(fpn_outs[0])
        # Bugfix: `feat0` is consumed by BOTH return paths below, but was
        # previously computed only inside the `self.training` branch,
        # which raised NameError at inference time. It must be taken from
        # the *pre-resize* finest level, so it stays here, before the
        # resizing loop.
        feat0 = self.convert(fpn_outs[0])

        # bring every level to the level-1 spatial size and concatenate
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[1].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs[0] = resize(
            fpn_outs[0],
            size=fpn_outs[1].shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        if self.training:
            return (self.deeplab([feat0, fpn_outs], trans=False), cls_aux)
        else:
            return self.deeplab([feat0, fpn_outs], trans=False)

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Compute losses for both the DeepLab output and the auxiliary
        UPerNet output (loss keys suffixed ``_uper``)."""
        seg_logits, aux_logits = self(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        losses_aux = self.losses(aux_logits, gt_semantic_seg, addstr='_uper')
        losses.update(losses_aux)
        return losses
|
modelsforCIML/mmseg/models/losses/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .accuracy import Accuracy, accuracy
|
| 3 |
+
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
| 4 |
+
cross_entropy, mask_cross_entropy)
|
| 5 |
+
from .dice_loss import DiceLoss
|
| 6 |
+
# from .focal_loss import FocalLoss
|
| 7 |
+
from .lovasz_loss import LovaszLoss
|
| 8 |
+
from .tversky_loss import TverskyLoss
|
| 9 |
+
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
| 13 |
+
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
| 14 |
+
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
|
| 15 |
+
'TverskyLoss'
|
| 16 |
+
]
|
modelsforCIML/mmseg/models/losses/accuracy.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def accuracy(pred, target, topk=1, thresh=None, ignore_index=-100):
    """Calculate top-k accuracy of ``pred`` against ``target``.

    Args:
        pred (torch.Tensor): Model prediction, shape (N, num_class, ...).
        target (torch.Tensor): Target of each prediction, shape (N, ...).
        topk (int | tuple[int], optional): A prediction counts as correct
            if the target appears within its top-k scores. Defaults to 1.
        thresh (float, optional): If given, predictions whose score does
            not exceed this threshold are counted as incorrect.
            Defaults to None.
        ignore_index (int | None): Label value excluded from the
            computation. Defaults to -100.

    Returns:
        float | tuple[float]: A single accuracy if ``topk`` is an int,
            otherwise one accuracy per entry of ``topk``.
    """
    assert isinstance(topk, (int, tuple))
    single = isinstance(topk, int)
    if single:
        topk = (topk, )

    maxk = max(topk)
    if pred.size(0) == 0:
        # Empty batch: report zero accuracy for every requested k.
        zeros = [pred.new_tensor(0.) for _ in topk]
        return zeros[0] if single else zeros
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
    scores, labels = pred.topk(maxk, dim=1)
    # shape (maxk, N, ...) so hits can be sliced per k below
    labels = labels.transpose(0, 1)
    hits = labels.eq(target.unsqueeze(0).expand_as(labels))
    if thresh is not None:
        # Low-confidence predictions never count as correct.
        hits = hits & (scores > thresh).t()
    if ignore_index is not None:
        hits = hits[:, target != ignore_index]

    # eps keeps the ratio well-defined when every pixel is ignored.
    eps = torch.finfo(torch.float32).eps
    if ignore_index is not None:
        denom = target[target != ignore_index].numel() + eps
    else:
        denom = target.numel() + eps

    results = []
    for k in topk:
        num_hits = hits[:k].reshape(-1).float().sum(0, keepdim=True) + eps
        results.append(num_hits.mul_(100.0 / denom))
    return results[0] if single else results
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Accuracy(nn.Module):
    """``nn.Module`` wrapper that applies :func:`accuracy` with settings
    fixed at construction time."""

    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
        """Configure the accuracy computation.

        Args:
            topk (tuple, optional): k values used for the top-k criterion.
                Defaults to (1,).
            thresh (float, optional): If given, predictions scoring at or
                below this threshold count as incorrect. Defaults to None.
            ignore_index (int | None): Label value excluded from the
                computation. Defaults to None.
        """
        super().__init__()
        self.topk, self.thresh, self.ignore_index = topk, thresh, ignore_index

    def forward(self, pred, target):
        """Compute the configured accuracies.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            tuple[float]: The accuracies under different topk criterions.
        """
        return accuracy(
            pred, target, self.topk, self.thresh, self.ignore_index)
|
modelsforCIML/mmseg/models/losses/cross_entropy_loss.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from ..builder import LOSSES
|
| 9 |
+
from .utils import get_class_weight, weight_reduce_loss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Weighted wrapper around :func:`F.cross_entropy`.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): Reduction method — 'none', 'mean' or
            'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor used to average the
            loss. Default: None.
        ignore_index (int): Target value that is ignored and does not
            contribute to the gradients. When ``avg_non_ignore`` is True
            and ``reduction`` is 'mean', the loss is averaged over
            non-ignored targets only. Defaults: -100.
        avg_non_ignore (bool): Whether the loss is averaged only over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """
    # Compute per-element losses first (class_weight is the per-class
    # rescaling tensor); reduction is applied manually below so that
    # sample weights and avg_factor can be honoured.
    elementwise = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # Mirror PyTorch's official 'mean' semantics on request: average only
    # over non-ignored elements.
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
    if avg_non_ignore and reduction == 'mean' and avg_factor is None:
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(
        elementwise, weight=weight, reduction=reduction,
        avg_factor=avg_factor)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
| 67 |
+
"""Expand onehot labels to match the size of prediction."""
|
| 68 |
+
bin_labels = labels.new_zeros(target_shape)
|
| 69 |
+
valid_mask = (labels >= 0) & (labels != ignore_index)
|
| 70 |
+
inds = torch.nonzero(valid_mask, as_tuple=True)
|
| 71 |
+
|
| 72 |
+
if inds[0].numel() > 0:
|
| 73 |
+
if labels.dim() == 3:
|
| 74 |
+
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
| 75 |
+
else:
|
| 76 |
+
bin_labels[inds[0], labels[valid_mask]] = 1
|
| 77 |
+
|
| 78 |
+
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
| 79 |
+
|
| 80 |
+
if label_weights is None:
|
| 81 |
+
bin_label_weights = valid_mask
|
| 82 |
+
else:
|
| 83 |
+
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
| 84 |
+
bin_label_weights = bin_label_weights * valid_mask
|
| 85 |
+
|
| 86 |
+
return bin_labels, bin_label_weights, valid_mask
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): Reduction method — "none", "mean" or
            "sum".
        avg_factor (int, optional): Average factor used to average the
            loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): Whether the loss is averaged only over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.size(1) == 1:
        # Single-logit binary segmentation: pred is [N, 1, H, W] and label
        # is [N, H, W]. ignore_index (often 255) must be masked out before
        # checking that at most two classes are present.
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze(1)

    if pred.dim() == label.dim():
        # Label already matches pred (e.g. one-hot / soft labels): just
        # mask out the ignored positions.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        weight = weight * valid_mask if weight is not None else valid_mask
    else:
        assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels` is already zeroed
        # on ignored (invalid) pixels.
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)

    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    raw = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    return weight_reduce_loss(
        raw, weight, reduction=reduction, avg_factor=avg_factor)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object; it selects the per-class mask
            when the mask prediction is not class-agnostic.
        reduction (str, optional): Reduction method — "none", "mean" or
            "sum".
        avg_factor (int, optional): Average factor used to average the
            loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other
            losses. Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    # Pick, for every ROI, the mask channel of its ground-truth class.
    roi_ids = torch.arange(
        pred.size(0), dtype=torch.long, device=pred.device)
    class_masks = pred[roi_ids, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        class_masks, target, weight=class_weight, reduction='mean')
    # Keep a 1-element batch dimension, consistent with the other losses.
    return loss[None]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Dispatches to plain cross-entropy, sigmoid (binary) cross-entropy, or
    mask cross-entropy depending on the constructor flags.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): . Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False):
        super(CrossEntropyLoss, self).__init__()
        # The sigmoid (BCE) and mask variants are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if not self.avg_non_ignore and self.reduction == 'mean':
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Select the concrete criterion once, at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction (logits).
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor used to average
                the loss. Defaults to None.
            reduction_override (str, optional): Overrides
                ``self.reduction`` for this call. Defaults to None.
            ignore_index (int): The label index to be ignored.
                Default: -100.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # Move class weights onto the prediction's device/dtype.
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        # Note: for BCE loss, label < 0 is invalid.
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
modelsforCIML/mmseg/models/losses/dice_loss.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
|
| 3 |
+
segmentron/solver/loss.py (Apache-2.0 License)"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from ..builder import LOSSES
|
| 9 |
+
from .utils import get_class_weight, weighted_loss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@weighted_loss
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255):
    """Multi-class Dice loss: mean of the per-class binary Dice losses.

    ``pred`` is channel-first ((N, C, ...)) while ``target`` is the
    channel-last one-hot volume ((N, ..., C)) produced by the caller.
    """
    assert pred.shape[0] == target.shape[0]
    num_classes = pred.shape[1]
    loss_sum = 0
    for cls_idx in range(num_classes):
        if cls_idx == ignore_index:
            continue
        cls_loss = binary_dice_loss(
            pred[:, cls_idx],
            target[..., cls_idx],
            valid_mask=valid_mask,
            smooth=smooth,
            exponent=exponent)
        if class_weight is not None:
            cls_loss = cls_loss * class_weight[cls_idx]
        loss_sum = loss_sum + cls_loss
    # Average over ALL channels, matching the original implementation.
    return loss_sum / num_classes
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    """Soft Dice loss for a single class, computed per batch element."""
    assert pred.shape[0] == target.shape[0]
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)
    mask_flat = valid_mask.reshape(valid_mask.shape[0], -1)

    # Intersection counts only valid (non-ignored) positions.
    intersection = torch.sum(pred_flat * target_flat * mask_flat, dim=1)
    numerator = 2 * intersection + smooth
    denominator = torch.sum(
        pred_flat.pow(exponent) + target_flat.pow(exponent), dim=1) + smooth

    return 1 - numerator / denominator
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@LOSSES.register_module()
class DiceLoss(nn.Module):
    """DiceLoss.

    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
    Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1
        exponent (float): An float number to calculate denominator
            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_dice'.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_dice',
                 **kwargs):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        # class_weight may be a path; get_class_weight resolves it to a list.
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the dice loss between logits `pred` and labels `target`."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        class_weight = (pred.new_tensor(self.class_weight)
                        if self.class_weight is not None else None)

        probs = F.softmax(pred, dim=1)
        num_classes = probs.shape[1]
        # Clamp labels so F.one_hot never sees ignore_index; those pixels
        # are zeroed out through valid_mask below anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        return self.loss_weight * dice_loss(
            probs,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
modelsforCIML/mmseg/models/losses/focal_loss.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
# Modified from https://github.com/open-mmlab/mmdetection
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
|
| 7 |
+
|
| 8 |
+
from ..builder import LOSSES
|
| 9 |
+
from .utils import weight_reduce_loss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (torch.Tensor): The learning label of the prediction with
            shape (N, C)
        one_hot_target (None): Placeholder. It should be None.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    prob = pred.sigmoid()
    target = target.type_as(pred)
    # (1 - pt): the probability mass on the wrong side; raising it to gamma
    # down-weights easy examples.
    modulator = (1 - prob) * target + prob * (1 - target)
    focal_factor = (alpha * target
                    + (1 - alpha) * (1 - target)) * modulator.pow(gamma)

    loss = focal_factor * F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')

    combined_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # Typical case: weight is (N,) — reshape to (N, 1) so it
            # broadcasts over the class axis.
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        combined_weight = combined_weight * weight
    if class_weight is not None:
        combined_weight = combined_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        combined_weight = combined_weight * valid_mask
    return weight_reduce_loss(loss, combined_weight, reduction, avg_factor)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.
    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction. It's shape
            should be (N, )
        one_hot_target (torch.Tensor): The learning label with shape (N, C)
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable here.
    combined_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # The CUDA op only accepts a scalar alpha. Run it with alpha=0.5
        # (equal fg/bg weight) and multiply by 2 to undo that balancing,
        # then re-apply the per-class alpha via the weight tensor.
        loss = 2 * _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                       gamma, 0.5, None, 'none')
        alpha = pred.new_tensor(alpha)
        combined_weight = combined_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # weight is (N,): reshape to (N, 1) to broadcast over classes.
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        combined_weight = combined_weight * weight
    if class_weight is not None:
        combined_weight = combined_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        combined_weight = combined_weight * valid_mask
    return weight_reduce_loss(loss, combined_weight, reduction, avg_factor)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@LOSSES.register_module()
class FocalLoss(nn.Module):
    """Focal loss for dense prediction; dispatches to the CUDA op when
    available, otherwise to the pure-PyTorch implementation."""

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
        Args:
            use_sigmoid (bool, optional): Whether to the prediction is
                used for sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background, each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255
        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
               (pred.size(0) == target.size(0) and
                pred.shape[2:] == target.shape[1:]), \
               "The shape of pred doesn't match the shape of target"

        # Remember the input layout so a 'none' reduction can be reshaped
        # back to it at the end.
        original_shape = pred.shape

        # Flatten all spatial dims so the loss functions see (N, C) inputs.
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...] (class probabilities)
            # transform it's shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...] (class indices)
            # transform it's shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot(): ignored pixels are
            # remapped to class 0 and masked out by valid_mask instead.
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                # CUDA op wants index targets plus a one-hot copy for the
                # list-alpha weighting path.
                if target.dim() == 1:
                    one_hot_target = F.one_hot(target, num_classes=num_classes)
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    # recompute valid_mask: the probability-target branch
                    # above did not set it.
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                # Pure-PyTorch path wants one-hot/probability targets only.
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes)
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # Restore the caller's layout for unreduced losses.
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
modelsforCIML/mmseg/models/losses/lovasz_loss.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
|
| 3 |
+
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
|
| 4 |
+
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
|
| 5 |
+
|
| 6 |
+
import mmcv
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from ..builder import LOSSES
|
| 12 |
+
from .utils import get_class_weight, weight_reduce_loss
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def lovasz_grad(gt_sorted):
    """Compute the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the Lovasz-Softmax paper.

    Args:
        gt_sorted (torch.Tensor): [P], ground truth sorted in order of
            decreasing prediction error.

    Returns:
        torch.Tensor: [P], the gradient values.
    """
    n = gt_sorted.numel()
    total_fg = gt_sorted.sum()
    cum_fg = gt_sorted.float().cumsum(0)
    cum_bg = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (total_fg - cum_fg) / (total_fg + cum_bg)
    # Consecutive differences of the Jaccard values are the gradient; the
    # first entry stays as-is (nothing to diff in the 1-pixel case).
    if n > 1:
        jaccard[1:n] = jaccard[1:n] - jaccard[0:-1]
    return jaccard
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flatten binary predictions and labels, dropping ``ignore_index``
    pixels.

    Args:
        logits (torch.Tensor): Per-pixel logits, any shape.
        labels (torch.Tensor): Per-pixel binary labels, same element count.
        ignore_index (int | None): Label value to remove. Default: None.

    Returns:
        tuple[torch.Tensor]: 1-D (logits, labels), filtered when
        ``ignore_index`` is given.
    """
    flat_logits = logits.view(-1)
    flat_labels = labels.view(-1)
    if ignore_index is None:
        return flat_logits, flat_labels
    keep = flat_labels != ignore_index
    return flat_logits[keep], flat_labels[keep]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def flatten_probs(probs, labels, ignore_index=None):
    """Flatten per-pixel class probabilities in the batch.

    Args:
        probs (torch.Tensor): [B, C, H, W] class probabilities, or [B, H, W]
            for the sigmoid (single-class) case.
        labels (torch.Tensor): [B, H, W] ground truth labels.
        ignore_index (int | None): Label value to drop. Default: None.

    Returns:
        tuple[torch.Tensor]: ([P, C] probabilities, [P] labels), with ignored
        pixels removed when ``ignore_index`` is given.
    """
    if probs.dim() == 3:
        # assumes output of a sigmoid layer: add a channel dim -> [B, 1, H, W]
        B, H, W = probs.size()
        probs = probs.view(B, 1, H, W)
    B, C, H, W = probs.size()
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # [B*H*W, C]
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels
    valid = (labels != ignore_index)
    # Fix: index rows with the boolean mask directly. The previous
    # `probs[valid.nonzero().squeeze()]` collapsed to a 0-d index when
    # exactly one pixel was valid, yielding shape [C] instead of [1, C].
    vprobs = probs[valid]
    vlabels = labels[valid]
    return vprobs, vlabels
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on flattened predictions.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss.
    """
    if len(labels) == 0:
        # Only void pixels: keep the graph connected but contribute zero.
        return logits.sum() * 0.
    # Map {0, 1} labels to {-1, +1} signs, then take hinge-style errors.
    signs = labels.float() * 2. - 1.
    errors = 1. - logits * signs
    errors_sorted, order = torch.sort(errors, dim=0, descending=True)
    grad = lovasz_grad(labels[order.data])
    return torch.dot(F.relu(errors_sorted), grad)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, to be consistent
            with other losses. Default: None.
        per_image (bool, optional): If True, compute the loss per image
            instead of per batch. Default: False.
        class_weight (list[float], optional): Placeholder, to be consistent
            with other losses. Default: None.
        reduction (str, optional): "none", "mean" or "sum"; only used when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Averaging factor; only used when
            per_image is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch at once: flatten, filter ignored pixels, score.
        return lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))
    per_image_losses = [
        lovasz_hinge_flat(*flatten_binary_logits(
            logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
        for logit, label in zip(logits, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss on flattened predictions.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): 'all' for all classes, 'present'
            for classes present in labels, or a list of class indices to
            average. Default: 'present'.
        class_weight (list[float], optional): Per-class weights.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0
        return probs * 0.
    num_classes = probs.size(1)
    candidates = (
        list(range(num_classes)) if classes in ('all', 'present') else classes)
    losses = []
    for c in candidates:
        fg = (labels == c).float()  # foreground mask for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if num_classes == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, order = torch.sort(errors, 0, descending=True)
        class_loss = torch.dot(errors_sorted, lovasz_grad(fg[order.data]))
        if class_weight is not None:
            class_loss *= class_weight[c]
        losses.append(class_loss)
    return torch.stack(losses).mean()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): 'all' for all classes, 'present'
            for classes present in labels, or a list of class indices to
            average. Default: 'present'.
        per_image (bool, optional): If True, compute the loss per image
            instead of per batch. Default: False.
        class_weight (list[float], optional): Per-class weights.
            Default: None.
        reduction (str, optional): "none", "mean" or "sum"; only used when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Averaging factor; only used when
            per_image is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch at once: flatten, filter ignored pixels, score.
        return lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)
    per_image_losses = [
        lovasz_softmax_flat(
            *flatten_probs(prob.unsqueeze(0), label.unsqueeze(0),
                           ignore_index),
            classes=classes,
            class_weight=class_weight)
        for prob, label in zip(probs, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@LOSSES.register_module()
class LovaszLoss(nn.Module):
    """Lovasz-Softmax/Hinge loss.

    Implements the loss from `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes over which the loss is
            averaged: 'all' for all classes, 'present' for classes present in
            labels, or an explicit list of class indices. Default: 'present'.
        per_image (bool, optional): If True, compute the loss per image
            instead of per batch. Default: False.
        reduction (str, optional): How to reduce the loss; one of "none",
            "mean" and "sum". Only effective when ``per_image`` is True.
            Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class, or
            a file path to read the weights from. Defaults to None.
        loss_weight (float, optional): Global scale of the loss.
            Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. A ``loss_`` prefix
            makes the item contribute to the backward graph.
            Defaults to 'loss_lovasz'.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_lovasz'):
        super(LovaszLoss, self).__init__()
        assert loss_type in ('binary', 'multi_class'), (
            "loss_type should be 'binary' or 'multi_class'.")
        # Binary uses the hinge formulation; multi-class uses softmax.
        self.cls_criterion = (
            lovasz_hinge if loss_type == 'binary' else lovasz_softmax)
        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', (
                "reduction should be 'none' when per_image is False.")

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the (weighted) Lovasz loss between logits and labels."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        class_weight = (
            cls_score.new_tensor(self.class_weight)
            if self.class_weight is not None else None)

        # The multi-class Lovasz criterion expects probabilities, so turn
        # logits into probs first.
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        return self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)

    @property
    def loss_name(self):
        """Loss Name.

        This name is used to combine different loss items by simple sum
        operation. If you want this loss item to be included into the
        backward graph, `loss_` must be the prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
modelsforCIML/mmseg/models/losses/tversky_loss.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
"""Modified from
|
| 3 |
+
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
|
| 4 |
+
(Apache-2.0 License)"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from ..builder import LOSSES
|
| 10 |
+
from .utils import get_class_weight, weighted_loss
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@weighted_loss
def tversky_loss(pred,
                 target,
                 valid_mask,
                 alpha=0.3,
                 beta=0.7,
                 smooth=1,
                 class_weight=None,
                 ignore_index=255):
    """Multi-class Tversky loss averaged over all class channels.

    Each class channel of ``pred`` is compared against the matching one-hot
    channel of ``target`` via :func:`binary_tversky_loss`. The channel whose
    index equals ``ignore_index`` is skipped, but the mean is still taken
    over ``num_classes`` channels.
    """
    assert pred.shape[0] == target.shape[0]
    num_classes = pred.shape[1]
    loss_sum = 0
    for cls_idx in range(num_classes):
        if cls_idx == ignore_index:
            continue
        cls_loss = binary_tversky_loss(
            pred[:, cls_idx],
            target[..., cls_idx],
            valid_mask=valid_mask,
            alpha=alpha,
            beta=beta,
            smooth=smooth)
        if class_weight is not None:
            cls_loss *= class_weight[cls_idx]
        loss_sum += cls_loss
    return loss_sum / num_classes
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@weighted_loss
def binary_tversky_loss(pred,
                        target,
                        valid_mask,
                        alpha=0.3,
                        beta=0.7,
                        smooth=1):
    """Per-sample binary Tversky loss, i.e. ``1 - Tversky index``."""
    assert pred.shape[0] == target.shape[0]
    batch = pred.shape[0]
    flat_pred = pred.reshape(batch, -1)
    flat_target = target.reshape(batch, -1)
    flat_valid = valid_mask.reshape(valid_mask.shape[0], -1)

    # Soft confusion-matrix terms, restricted to valid (non-ignored) pixels.
    true_pos = torch.sum(flat_pred * flat_target * flat_valid, dim=1)
    false_pos = torch.sum(flat_pred * (1 - flat_target) * flat_valid, dim=1)
    false_neg = torch.sum((1 - flat_pred) * flat_target * flat_valid, dim=1)
    # alpha weights false positives, beta weights false negatives; smooth
    # keeps the ratio finite when a class is absent.
    tversky_index = (true_pos + smooth) / (
        true_pos + alpha * false_pos + beta * false_neg + smooth)

    return 1 - tversky_index
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@LOSSES.register_module()
class TverskyLoss(nn.Module):
    """TverskyLoss. This loss is proposed in `Tversky loss function for image
    segmentation using 3D fully convolutional deep networks.

    <https://arxiv.org/abs/1706.05721>`_.
    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        alpha(float, in [0, 1]):
            The coefficient of false positives. Default: 0.3.
        beta (float, in [0, 1]):
            The coefficient of false negatives. Default: 0.7.
            Note: alpha + beta = 1.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_tversky'.
    """

    def __init__(self,
                 smooth=1,
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 alpha=0.3,
                 beta=0.7,
                 loss_name='loss_tversky'):
        super(TverskyLoss, self).__init__()
        self.smooth = smooth
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        # Fixed typo in the assertion message ("but be" -> "must be").
        # alpha + beta == 1 keeps the Tversky index a generalization of
        # Dice (0.5/0.5) and Jaccard-style weightings.
        assert (alpha + beta == 1.0), 'Sum of alpha and beta must be 1.0!'
        self.alpha = alpha
        self.beta = beta
        self._loss_name = loss_name

    def forward(self, pred, target, **kwargs):
        """Compute the Tversky loss.

        Args:
            pred (Tensor): Raw logits of shape (N, C, ...).
            target (Tensor): Integer label map of shape (N, ...).

        Returns:
            Tensor: The weighted Tversky loss.
        """
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        # Clamp so that ignore labels (e.g. 255) can still be one-hot
        # encoded; those pixels are excluded from the loss by valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * tversky_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            alpha=self.alpha,
            beta=self.beta,
            smooth=self.smooth,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
modelsforCIML/mmseg/models/losses/utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import mmcv
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_class_weight(class_weight):
    """Resolve a class-weight specification for a loss function.

    Args:
        class_weight (list[float] | str | None): Either the weights
            themselves (returned unchanged), or a file name to read them
            from: '.npy' files are loaded with numpy, anything else with
            ``mmcv.load`` (pkl, json or yaml).
    """
    if not isinstance(class_weight, str):
        return class_weight
    # A string is treated as a file path.
    if class_weight.endswith('.npy'):
        return np.load(class_weight)
    # pkl, json or yaml
    return mmcv.load(class_weight)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def reduce_loss(loss, reduction):
    """Reduce an elementwise loss tensor as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    # Let PyTorch validate the reduction string:
    # none -> 0, elementwise_mean -> 1, sum -> 2.
    enum_value = F._Reduction.get_enum(reduction)
    if enum_value == 1:
        return loss.mean()
    if enum_value == 2:
        return loss.sum()
    return loss
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    if weight is not None:
        # Weight must broadcast cleanly against the loss.
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    if avg_factor is None:
        # No explicit normalizer: fall back to the plain reduction.
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # eps guards against division by zero when every label of an
        # image belongs to the ignore index (avg_factor == 0.0).
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can not be used with reduction="sum"')
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The decorated function must have the signature
    ``loss_func(pred, target, **kwargs)`` and return an element-wise loss
    with no reduction applied. The wrapper adds ``weight``, ``reduction``
    and ``avg_factor`` arguments and delegates the weighting/reduction to
    :func:`weight_reduce_loss`, yielding the signature
    ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # Element-wise loss first, then the shared weight/reduce step.
        raw_loss = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(raw_loss, weight, reduction, avg_factor)

    return wrapper
|
modelsforCIML/mmseg/ops/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .encoding import Encoding
|
| 3 |
+
from .wrappers import Upsample, resize
|
| 4 |
+
|
| 5 |
+
__all__ = ['Upsample', 'resize', 'Encoding']
|
modelsforCIML/mmseg/ops/encoding.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        self.channels, self.num_codes = channels, num_codes
        # Fan-in style init range for the codewords.
        std = 1. / ((num_codes * channels)**0.5)
        # Codewords: [num_codes, channels], uniform in [-std, std].
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # Smoothing factors: [num_codes], uniform in [-1, 0].
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Scaled squared L2 distance from every feature to every codeword."""
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        scale_view = scale.view((1, 1, num_codes))
        codeword_view = codewords.view((1, 1, num_codes, channels))
        x_expanded = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        return scale_view * (x_expanded - codeword_view).pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Soft-assignment-weighted sum of residuals for each codeword."""
        num_codes, channels = codewords.size()
        codeword_view = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)
        x_expanded = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        residuals = x_expanded - codeword_view
        return (assignment_weights.unsqueeze(3) * residuals).sum(dim=1)

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        batch_size = x.size(0)
        # (B, C, H, W) -> (B, H*W, C)
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # Soft assignment: softmax over codes of the scaled distances.
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        return self.aggregate(assignment_weights, x, self.codewords)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
|
modelsforCIML/mmseg/ops/wrappers.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Thin wrapper around ``F.interpolate`` that can warn when
    ``align_corners=True`` is used with sizes for which the output grid does
    not align with the input grid."""
    if warning and size is not None and align_corners:
        input_h, input_w = (int(v) for v in input.shape[2:])
        output_h, output_w = (int(v) for v in size)
        if output_h > input_h or output_w > input_w:
            # Aligned upsampling needs (out - 1) divisible by (in - 1).
            all_above_one = (output_h > 1 and output_w > 1 and input_h > 1
                             and input_w > 1)
            if (all_above_one and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Upsample(nn.Module):
    """``nn.Upsample``-like module that, when built from a scale factor,
    recomputes the target size from the input's spatial shape on every
    forward pass (so it works with dynamic input sizes)."""

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(f) for f in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if self.size:
            target_size = self.size
        else:
            # Derive the output size from the current spatial dims.
            target_size = [int(t * self.scale_factor) for t in x.shape[-2:]]
        return resize(x, target_size, None, self.mode, self.align_corners)
|
modelsforCIML/mmseg/utils/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
from .collect_env import collect_env
|
| 3 |
+
from .logger import get_root_logger
|
| 4 |
+
from .misc import find_latest_checkpoint
|
| 5 |
+
from .set_env import setup_multi_processes
|
| 6 |
+
from .util_distribution import build_ddp, build_dp, get_device
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
|
| 10 |
+
'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
|
| 11 |
+
]
|