Upload 2 files
Browse files- utils/data.py +103 -0
- utils/metric.py +206 -0
utils/data.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.utils.data as Data
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from PIL import Image, ImageOps, ImageFilter
|
| 8 |
+
import os.path as osp
|
| 9 |
+
import sys
|
| 10 |
+
import random
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class IRSTD_Dataset(Data.Dataset):
    """Infrared small-target detection segmentation dataset.

    Reads sample names from ``trainval.txt`` (train) or ``test.txt`` (val)
    under ``args.dataset_dir``; images are ``images/<name>.png`` and masks
    are ``masks/<name>.png``.
    """

    def __init__(self, args, mode='train'):
        """Build the dataset.

        Args:
            args: namespace providing ``dataset_dir``, ``crop_size`` and
                ``base_size``.
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        dataset_dir = args.dataset_dir

        if mode == 'train':
            txtfile = 'trainval.txt'
        elif mode == 'val':
            txtfile = 'test.txt'
        else:
            # Fail fast: previously `txtfile` was left unbound here and a
            # confusing NameError surfaced a few lines later.
            raise ValueError("Unknown mode: {}".format(mode))

        self.list_dir = osp.join(dataset_dir, txtfile)
        self.imgs_dir = osp.join(dataset_dir, 'images')
        self.label_dir = osp.join(dataset_dir, 'masks')

        # One sample name per line of the split file.
        with open(self.list_dir, 'r') as f:
            self.names = [line.strip() for line in f]

        self.mode = mode
        self.crop_size = args.crop_size
        self.base_size = args.base_size
        # ImageNet mean/std normalization for the RGB input image.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # Masks are only converted to tensors, never normalized; built once
        # here instead of on every __getitem__ call.
        self.mask_transform = transforms.ToTensor()

    def __getitem__(self, i):
        """Return the transformed ``(image, mask)`` tensor pair for sample *i*."""
        name = self.names[i]
        img_path = osp.join(self.imgs_dir, name + '.png')
        label_path = osp.join(self.label_dir, name + '.png')

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(label_path)

        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._testval_sync_transform(img, mask)
        else:
            raise ValueError("Unknown self.mode")

        img, mask = self.transform(img), self.mask_transform(mask)
        return img, mask

    def __len__(self):
        return len(self.names)

    def _sync_transform(self, img, mask):
        """Joint train-time augmentation: random flip, random scale, pad,
        random crop, random blur — applied identically to image and mask
        (blur applies to the image only)."""
        # Random horizontal mirror.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # Random scale: the LONG edge is drawn from [0.5, 2.0] * base_size;
        # the short edge keeps the aspect ratio.
        long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Pad right/bottom with zeros so both sides reach at least crop_size.
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # Random crop of crop_size x crop_size.
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # Random gaussian blur (image only), as in PSPNet.
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        return img, mask

    def _testval_sync_transform(self, img, mask):
        """Deterministic eval transform: resize both to base_size x base_size."""
        base_size = self.base_size
        img = img.resize((base_size, base_size), Image.BILINEAR)
        mask = mask.resize((base_size, base_size), Image.NEAREST)

        return img, mask
|
utils/metric.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
from skimage import measure
|
| 5 |
+
import numpy
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ROCMetric():
    """Accumulates confusion counts at ``bins + 1`` score thresholds.

    ``get()`` converts the running counts into ROC (TPR/FPR) and
    precision/recall curves.
    """

    def __init__(self, nclass, bins):
        # `bins` determines how many discrete thresholds the ROC curve is
        # sampled at: thresholds are i / bins for i in 0..bins.
        super(ROCMetric, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)

    def update(self, preds, labels):
        """Accumulate per-threshold confusion counts for one batch."""
        for iBin in range(self.bins + 1):
            score_thresh = (iBin + 0.0) / self.bins
            i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
            self.tp_arr[iBin] += i_tp
            self.pos_arr[iBin] += i_pos
            self.fp_arr[iBin] += i_fp
            self.neg_arr[iBin] += i_neg
            self.class_pos[iBin] += i_class_pos

    def get(self):
        """Return (tp_rates, fp_rates, recall, precision) per threshold.

        The +0.001 in each denominator avoids division by zero.
        """
        tp_rates = self.tp_arr / (self.pos_arr + 0.001)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)

        recall = self.tp_arr / (self.pos_arr + 0.001)
        precision = self.tp_arr / (self.class_pos + 0.001)

        return tp_rates, fp_rates, recall, precision

    def reset(self):
        """Zero all accumulators.

        Bug fix: the original hard-coded length 11 (``np.zeros([11])``),
        which silently corrupted the arrays for any ``bins != 10``.
        """
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)
        self.class_pos = np.zeros(self.bins + 1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PD_FA():
    """Probability-of-detection (PD) / false-alarm (FA) metric.

    Works at the connected-component level: a predicted region whose
    centroid lies within 3 px of a ground-truth centroid counts as a
    detection; unmatched predicted pixels count as false alarms.
    """

    def __init__(self, nclass, bins, size):
        """``size`` is the (square) image side length; thresholds sweep
        0..255 in ``bins`` steps."""
        super(PD_FA, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.image_area_total = []
        self.image_area_match = []
        self.FA = np.zeros(self.bins + 1)
        self.PD = np.zeros(self.bins + 1)
        self.target = np.zeros(self.bins + 1)
        self.size = size

    def update(self, preds, labels):
        """Accumulate detection / false-alarm counts for one image."""
        for iBin in range(self.bins + 1):
            # Thresholds sweep the raw 0..255 intensity range
            # (assumes preds are in that range — TODO confirm at call site).
            score_thresh = iBin * (255 / self.bins)
            predits = np.array((preds > score_thresh).cpu()).astype('int64')

            predits = np.reshape(predits, (self.size, self.size))
            labelss = np.array((labels).cpu()).astype('int64')
            labelss = np.reshape(labelss, (self.size, self.size))

            # Connected components (8-connectivity) of prediction and label.
            image = measure.label(predits, connectivity=2)
            coord_image = measure.regionprops(image)
            label = measure.label(labelss, connectivity=2)
            coord_label = measure.regionprops(label)

            self.target[iBin] += len(coord_label)
            self.image_area_total = []
            self.image_area_match = []
            self.distance_match = []
            self.dismatch = []

            for K in range(len(coord_image)):
                area_image = np.array(coord_image[K].area)
                self.image_area_total.append(area_image)

            # Greedy centroid matching: each GT region consumes at most one
            # predicted region within 3 px of its centroid.
            for i in range(len(coord_label)):
                centroid_label = np.array(list(coord_label[i].centroid))
                for m in range(len(coord_image)):
                    centroid_image = np.array(list(coord_image[m].centroid))
                    distance = np.linalg.norm(centroid_image - centroid_label)
                    area_image = np.array(coord_image[m].area)
                    if distance < 3:
                        self.distance_match.append(distance)
                        self.image_area_match.append(area_image)

                        del coord_image[m]
                        break

            # Total area of unmatched predicted regions = false-alarm pixels.
            self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
            self.FA[iBin] += np.sum(self.dismatch)
            self.PD[iBin] += len(self.distance_match)

    def get(self, img_num):
        """Return (per-pixel FA rate, PD rate) over ``img_num`` images."""
        Final_FA = self.FA / ((self.size * self.size) * img_num)
        Final_PD = self.PD / self.target

        return Final_FA, Final_PD

    def reset(self):
        """Zero the accumulators.

        Bug fix: the original left ``self.target`` untouched, so target
        counts leaked across evaluation runs and PD was under-reported.
        """
        self.FA = np.zeros([self.bins + 1])
        self.PD = np.zeros([self.bins + 1])
        self.target = np.zeros([self.bins + 1])
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class mIoU():
    """Pixel accuracy and mean IoU accumulated over batches."""

    def __init__(self, nclass):
        super(mIoU, self).__init__()
        self.nclass = nclass
        self.reset()

    def update(self, preds, labels):
        """Accumulate pixel-accuracy and intersection/union counts."""
        correct, labeled = batch_pix_accuracy(preds, labels)
        inter, union = batch_intersection_union(preds, labels, self.nclass)
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def get(self):
        """Return ``(pixAcc, mIoU)``; safe to call before any update().

        ``np.spacing(1)`` guards both divisions against zero denominators.
        """
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        # np.mean handles both the scalar "no updates yet" state and the
        # array state; the original `IoU.mean()` raised AttributeError on a
        # freshly-reset metric (0.0 has no .mean()).
        return pixAcc, np.mean(IoU)

    def reset(self):
        # Counters start as Python ints and become arrays/tensors after the
        # first update() (via `+=` with array operands).
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
    """Confusion-matrix counts of a sigmoid prediction at one threshold.

    Args:
        output: raw logits tensor, shape (B, 1, H, W).
        target: ground truth, shape (B, 1, H, W) or (B, H, W).
        nclass: unused; kept for interface compatibility.
        score_thresh: threshold applied to ``sigmoid(output)``.

    Returns:
        ``(tp, pos, fp, neg, class_pos)`` scalar tensors where
        pos = tp + fn, neg = fp + tn, class_pos = tp + fp.

    Raises:
        ValueError: if ``target`` is neither 3- nor 4-dimensional.
    """
    predict = (torch.sigmoid(output) > score_thresh).float()
    if len(target.shape) == 3:
        # Add the channel axis in torch. The original used
        # np.expand_dims(...), which converted the tensor to a numpy array
        # and mixed frameworks in the comparisons below.
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    # Foreground pixels predicted correctly.
    intersection = predict * ((predict == target).float())

    tp = intersection.sum()
    fp = (predict * ((predict != target).float())).sum()
    tn = ((1 - predict) * ((predict == target).float())).sum()
    fn = (((predict != target).float()) * (1 - predict)).sum()
    pos = tp + fn
    neg = fp + tn
    class_pos = tp + fp

    return tp, pos, fp, neg, class_pos
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def batch_pix_accuracy(output, target):
    """Pixel-accuracy counts for a batch of binary predictions.

    Args:
        output: raw scores, shape (B, 1, H, W); pixels > 0 count as positive.
        target: ground truth, shape (B, 1, H, W) or (B, H, W).

    Returns:
        ``(pixel_correct, pixel_labeled)``: correctly-classified pixels
        within the labeled (positive) region, and total labeled pixels.

    Raises:
        ValueError: if ``target`` is neither 3- nor 4-dimensional.
    """
    if len(target.shape) == 3:
        # Add the channel axis in torch. np.expand_dims would convert the
        # tensor to numpy and break the shape assert and torch ops below.
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    assert output.shape == target.shape, "Predict and Label Shape Don't Match"
    predict = (output > 0).float()
    pixel_labeled = (target > 0).float().sum()
    pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()

    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def batch_intersection_union(output, target, nclass):
    """Foreground intersection/union pixel counts for a binary batch.

    Args:
        output: raw scores, shape (B, 1, H, W); pixels > 0 count as positive.
        target: ground truth, shape (B, 1, H, W) or (B, H, W).
        nclass: unused; kept for interface compatibility.

    Returns:
        ``(area_inter, area_union)`` as 1-element numpy arrays.

    Raises:
        ValueError: if ``target`` is neither 3- nor 4-dimensional.
    """
    # Single-bin histogram over the foreground value 1: numpy expands the
    # degenerate (1, 1) range to (0.5, 1.5), so each call counts 1-pixels.
    mini = 1
    maxi = 1
    nbins = 1
    predict = (output > 0).float()
    if len(target.shape) == 3:
        # Add the channel axis in torch; np.expand_dims would silently
        # convert the tensor to a numpy array before the comparison below.
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")
    # Foreground pixels predicted correctly.
    intersection = predict * ((predict == target).float())

    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Error: Intersection area should be smaller than Union area"
    return area_inter, area_union
|