InPeerReview commited on
Commit
4e89a1c
·
verified ·
1 Parent(s): 4a9d419

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/data.py +103 -0
  2. utils/metric.py +206 -0
utils/data.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.data as Data
4
+ import torchvision.transforms as transforms
5
+
6
+ import os
7
+ from PIL import Image, ImageOps, ImageFilter
8
+ import os.path as osp
9
+ import sys
10
+ import random
11
+ import shutil
12
+
13
+
14
+ class IRSTD_Dataset(Data.Dataset):
15
+ def __init__(self, args, mode='train'):
16
+
17
+ dataset_dir = args.dataset_dir
18
+
19
+ if mode == 'train':
20
+ txtfile = 'trainval.txt'
21
+ elif mode == 'val':
22
+ txtfile = 'test.txt'
23
+
24
+ self.list_dir = osp.join(dataset_dir, txtfile)
25
+ self.imgs_dir = osp.join(dataset_dir, 'images')
26
+ self.label_dir = osp.join(dataset_dir, 'masks')
27
+
28
+ self.names = []
29
+ with open(self.list_dir, 'r') as f:
30
+ self.names += [line.strip() for line in f.readlines()]
31
+
32
+ self.mode = mode
33
+ self.crop_size = args.crop_size
34
+ self.base_size = args.base_size
35
+ self.transform = transforms.Compose([
36
+ transforms.ToTensor(),
37
+ transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
38
+ ])
39
+
40
+ def __getitem__(self, i):
41
+ name = self.names[i]
42
+ img_path = osp.join(self.imgs_dir, name + '.png')
43
+ label_path = osp.join(self.label_dir, name + '.png')
44
+
45
+ img = Image.open(img_path).convert('RGB')
46
+ mask = Image.open(label_path)
47
+
48
+ if self.mode == 'train':
49
+ img, mask = self._sync_transform(img, mask)
50
+ elif self.mode == 'val':
51
+ img, mask = self._testval_sync_transform(img, mask)
52
+ else:
53
+ raise ValueError("Unkown self.mode")
54
+
55
+ img, mask = self.transform(img), transforms.ToTensor()(mask)
56
+ return img, mask
57
+
58
+ def __len__(self):
59
+ return len(self.names)
60
+
61
+ def _sync_transform(self, img, mask):
62
+ # random mirror
63
+ if random.random() < 0.5:
64
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
65
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
66
+ crop_size = self.crop_size
67
+ # random scale (short edge)
68
+ long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
69
+ w, h = img.size
70
+ if h > w:
71
+ oh = long_size
72
+ ow = int(1.0 * w * long_size / h + 0.5)
73
+ short_size = ow
74
+ else:
75
+ ow = long_size
76
+ oh = int(1.0 * h * long_size / w + 0.5)
77
+ short_size = oh
78
+ img = img.resize((ow, oh), Image.BILINEAR)
79
+ mask = mask.resize((ow, oh), Image.NEAREST)
80
+ # pad crop
81
+ if short_size < crop_size:
82
+ padh = crop_size - oh if oh < crop_size else 0
83
+ padw = crop_size - ow if ow < crop_size else 0
84
+ img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
85
+ mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
86
+ # random crop crop_size
87
+ w, h = img.size
88
+ x1 = random.randint(0, w - crop_size)
89
+ y1 = random.randint(0, h - crop_size)
90
+ img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
91
+ mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
92
+ # gaussian blur as in PSP
93
+ if random.random() < 0.5:
94
+ img = img.filter(ImageFilter.GaussianBlur(
95
+ radius=random.random()))
96
+ return img, mask
97
+
98
+ def _testval_sync_transform(self, img, mask):
99
+ base_size = self.base_size
100
+ img = img.resize((base_size, base_size), Image.BILINEAR)
101
+ mask = mask.resize((base_size, base_size), Image.NEAREST)
102
+
103
+ return img, mask
utils/metric.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn as nn
3
+ import torch
4
+ from skimage import measure
5
+ import numpy
6
+
7
+
8
+ class ROCMetric():
9
+ """Computes pixAcc and mIoU metric scores
10
+ """
11
+
12
+ def __init__(self, nclass, bins): # bin的意义实际上是确定ROC曲线上的threshold取多少个离散值
13
+ super(ROCMetric, self).__init__()
14
+ self.nclass = nclass
15
+ self.bins = bins
16
+ self.tp_arr = np.zeros(self.bins + 1)
17
+ self.pos_arr = np.zeros(self.bins + 1)
18
+ self.fp_arr = np.zeros(self.bins + 1)
19
+ self.neg_arr = np.zeros(self.bins + 1)
20
+ self.class_pos = np.zeros(self.bins + 1)
21
+ # self.reset()
22
+
23
+ def update(self, preds, labels):
24
+ for iBin in range(self.bins + 1):
25
+ score_thresh = (iBin + 0.0) / self.bins
26
+ # print(iBin, "-th, score_thresh: ", score_thresh)
27
+ i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh)
28
+ self.tp_arr[iBin] += i_tp
29
+ self.pos_arr[iBin] += i_pos
30
+ self.fp_arr[iBin] += i_fp
31
+ self.neg_arr[iBin] += i_neg
32
+ self.class_pos[iBin] += i_class_pos
33
+
34
+ def get(self):
35
+ tp_rates = self.tp_arr / (self.pos_arr + 0.001)
36
+ fp_rates = self.fp_arr / (self.neg_arr + 0.001)
37
+
38
+ recall = self.tp_arr / (self.pos_arr + 0.001)
39
+ precision = self.tp_arr / (self.class_pos + 0.001)
40
+
41
+ return tp_rates, fp_rates, recall, precision
42
+
43
+ def reset(self):
44
+ self.tp_arr = np.zeros([11])
45
+ self.pos_arr = np.zeros([11])
46
+ self.fp_arr = np.zeros([11])
47
+ self.neg_arr = np.zeros([11])
48
+ self.class_pos = np.zeros([11])
49
+
50
+
51
+ class PD_FA():
52
+ def __init__(self, nclass, bins, size):
53
+ super(PD_FA, self).__init__()
54
+ self.nclass = nclass
55
+ self.bins = bins
56
+ self.image_area_total = []
57
+ self.image_area_match = []
58
+ self.FA = np.zeros(self.bins + 1)
59
+ self.PD = np.zeros(self.bins + 1)
60
+ self.target = np.zeros(self.bins + 1)
61
+ self.size = size
62
+
63
+ def update(self, preds, labels):
64
+
65
+ for iBin in range(self.bins + 1):
66
+ score_thresh = iBin * (255 / self.bins)
67
+ predits = np.array((preds > score_thresh).cpu()).astype('int64')
68
+
69
+ predits = np.reshape(predits, (self.size, self.size))
70
+ labelss = np.array((labels).cpu()).astype('int64')
71
+ labelss = np.reshape(labelss, (self.size, self.size))
72
+
73
+ image = measure.label(predits, connectivity=2)
74
+ coord_image = measure.regionprops(image)
75
+ label = measure.label(labelss, connectivity=2)
76
+ coord_label = measure.regionprops(label)
77
+
78
+ self.target[iBin] += len(coord_label)
79
+ self.image_area_total = []
80
+ self.image_area_match = []
81
+ self.distance_match = []
82
+ self.dismatch = []
83
+
84
+ for K in range(len(coord_image)):
85
+ area_image = np.array(coord_image[K].area)
86
+ self.image_area_total.append(area_image)
87
+
88
+ for i in range(len(coord_label)):
89
+ centroid_label = np.array(list(coord_label[i].centroid))
90
+ for m in range(len(coord_image)):
91
+ centroid_image = np.array(list(coord_image[m].centroid))
92
+ distance = np.linalg.norm(centroid_image - centroid_label)
93
+ area_image = np.array(coord_image[m].area)
94
+ if distance < 3:
95
+ self.distance_match.append(distance)
96
+ self.image_area_match.append(area_image)
97
+
98
+ del coord_image[m]
99
+ break
100
+
101
+ self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
102
+ self.FA[iBin] += np.sum(self.dismatch)
103
+ self.PD[iBin] += len(self.distance_match)
104
+
105
+ def get(self, img_num):
106
+
107
+ Final_FA = self.FA / ((self.size * self.size) * img_num)
108
+ Final_PD = self.PD / self.target
109
+
110
+ return Final_FA, Final_PD
111
+
112
+ def reset(self):
113
+ self.FA = np.zeros([self.bins + 1])
114
+ self.PD = np.zeros([self.bins + 1])
115
+
116
+
117
+ class mIoU():
118
+
119
+ def __init__(self, nclass):
120
+ super(mIoU, self).__init__()
121
+ self.nclass = nclass
122
+ self.reset()
123
+
124
+ def update(self, preds, labels):
125
+ # print('come_ininin')
126
+
127
+ correct, labeled = batch_pix_accuracy(preds, labels)
128
+ inter, union = batch_intersection_union(preds, labels, self.nclass)
129
+ self.total_correct += correct
130
+ self.total_label += labeled
131
+ self.total_inter += inter
132
+ self.total_union += union
133
+
134
+ def get(self):
135
+ pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
136
+ IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
137
+ mIoU = IoU.mean()
138
+ return pixAcc, mIoU
139
+
140
+ def reset(self):
141
+ self.total_inter = 0
142
+ self.total_union = 0
143
+ self.total_correct = 0
144
+ self.total_label = 0
145
+
146
+
147
+ def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
148
+ predict = (torch.sigmoid(output) > score_thresh).float()
149
+ if len(target.shape) == 3:
150
+ target = np.expand_dims(target.float(), axis=1)
151
+ elif len(target.shape) == 4:
152
+ target = target.float()
153
+ else:
154
+ raise ValueError("Unknown target dimension")
155
+
156
+ intersection = predict * ((predict == target).float())
157
+
158
+ tp = intersection.sum()
159
+ fp = (predict * ((predict != target).float())).sum()
160
+ tn = ((1 - predict) * ((predict == target).float())).sum()
161
+ fn = (((predict != target).float()) * (1 - predict)).sum()
162
+ pos = tp + fn
163
+ neg = fp + tn
164
+ class_pos = tp + fp
165
+
166
+ return tp, pos, fp, neg, class_pos
167
+
168
+
169
+ def batch_pix_accuracy(output, target):
170
+ if len(target.shape) == 3:
171
+ target = np.expand_dims(target.float(), axis=1)
172
+ elif len(target.shape) == 4:
173
+ target = target.float()
174
+ else:
175
+ raise ValueError("Unknown target dimension")
176
+
177
+ assert output.shape == target.shape, "Predict and Label Shape Don't Match"
178
+ predict = (output > 0).float()
179
+ pixel_labeled = (target > 0).float().sum()
180
+ pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum()
181
+
182
+ assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
183
+ return pixel_correct, pixel_labeled
184
+
185
+
186
+ def batch_intersection_union(output, target, nclass):
187
+ mini = 1
188
+ maxi = 1
189
+ nbins = 1
190
+ predict = (output > 0).float()
191
+ if len(target.shape) == 3:
192
+ target = np.expand_dims(target.float(), axis=1)
193
+ elif len(target.shape) == 4:
194
+ target = target.float()
195
+ else:
196
+ raise ValueError("Unknown target dimension")
197
+ intersection = predict * ((predict == target).float())
198
+
199
+ area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
200
+ area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
201
+ area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
202
+ area_union = area_pred + area_lab - area_inter
203
+
204
+ assert (area_inter <= area_union).all(), \
205
+ "Error: Intersection area should be smaller than Union area"
206
+ return area_inter, area_union