HirraA commited on
Commit
168ec29
·
verified ·
1 Parent(s): cf4d9b9

Upload 30 files

Browse files
NWRD_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ import os
4
+ from torchvision import transforms
5
+
6
class NWRD(Dataset):
    """Binary rust-classification dataset.

    Expects ``root_dir`` to contain two sub-directories:
    ``rust`` (label 1) and ``non_rust`` (label 0).
    """

    def __init__(self, root_dir, transform=None, train=True):
        # NOTE(review): `train` is accepted for API compatibility but is not
        # used to split the data -- confirm whether a train/val split was
        # intended. It is stored so callers can at least inspect it.
        self.root_dir = root_dir
        self.transform = transform
        self.train = train
        self.images = []
        self.labels = []
        self.load_data()

    def load_data(self):
        """Collect (image path, label) pairs from the rust/non_rust folders."""
        rust_dir = os.path.join(self.root_dir, "rust")
        non_rust_dir = os.path.join(self.root_dir, "non_rust")

        # Sort directory listings so sample order is deterministic across
        # filesystems and runs (os.listdir order is unspecified).
        for filename in sorted(os.listdir(rust_dir)):
            self.images.append(os.path.join(rust_dir, filename))
            self.labels.append(1)

        for filename in sorted(os.listdir(non_rust_dir)):
            self.images.append(os.path.join(non_rust_dir, filename))
            self.labels.append(0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert('RGB')

        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        return image, label
README.md CHANGED
@@ -1,3 +1,52 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DCFM
2
+ The official repo of the paper `Democracy Does Matter: Comprehensive Feature Mining for Co-Salient Object Detection`.
3
+
4
+ ## Environment Requirement
5
+ create an environment and install the dependencies as follows:
6
+ `pip install -r requirements.txt`
7
+
8
+ ## Data Format
9
+ trainset: CoCo-SEG
10
+
11
+ testset: CoCA, CoSOD3k, Cosal2015
12
+
13
+ Put the [CoCo-SEG](https://drive.google.com/file/d/1GbA_WKvJm04Z1tR8pTSzBdYVQ75avg4f/view), [CoCA](http://zhaozhang.net/coca.html), [CoSOD3k](http://dpfan.net/CoSOD3K/) and [Cosal2015](https://drive.google.com/u/0/uc?id=1mmYpGx17t8WocdPcw2WKeuFpz6VHoZ6K&export=download) datasets to `DCFM/data` as the following structure:
14
+ ```
15
+ DCFM
16
+ ├── other codes
17
+ ├── ...
18
+
19
+ └── data
20
+
21
+ ├── CoCo-SEG (CoCo-SEG's image files)
22
+ ├── CoCA (CoCA's image files)
23
+ ├── CoSOD3k (CoSOD3k's image files)
24
+ └── Cosal2015 (Cosal2015's image files)
25
+ ```
26
+
27
+ ## Trained model
28
+
29
+ trained model can be downloaded from [papermodel](https://drive.google.com/file/d/1cfuq4eJoCwvFR9W1XOJX7Y0ttd8TGjlp/view?usp=sharing).
30
+
31
+ Run `test.py` for inference.
32
+
33
+ The evaluation tool please follow: https://github.com/zzhanghub/eval-co-sod
34
+
35
+
36
+ <!-- USAGE EXAMPLES -->
37
+ ## Usage
38
+ Download the pretrained backbone model [VGG](https://drive.google.com/file/d/1Z1aAYXMyJ6txQ1Z9N7gtxLOIai4dxrXd/view?usp=sharing).
39
+
40
+ Run `train.py` for training.
41
+
42
+ ## Prediction results
43
+ The co-saliency maps of DCFM can be found at [preds](https://drive.google.com/file/d/1wGeNHXFWVSyqvmL4NIUmEFdlHDovEtQR/view?usp=sharing).
44
+
45
+ ## Reproduction
46
+ reproductions by myself on 2080Ti can be found at [reproduction1](https://drive.google.com/file/d/1vovii0RtYR_EC0Y2zxjY_cTWKWM3WaxP/view?usp=sharing) and [reproduction2](https://drive.google.com/file/d/1YPOKZ5kBtmZrCDhHpP3-w1GMVR5BfDoU/view?usp=sharing).
47
+
48
+ reproduction by myself on TITAN X can be found at [reproduction3](https://drive.google.com/file/d/1bnGFtRTYkVXqI2dcjeWFRDXnqqbUUBJr/view?usp=sharing).
49
+
50
+ ## Others
51
+ The code is based on [GCoNet](https://github.com/fanq15/GCoNet).
52
+ I've added a validation part to help select the model for closer results. This validation part is based on [GCoNet_plus](https://github.com/ZhengPeng7/GCoNet_plus). You can try different evaluation metrics to select the model.
config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
class Config():
    """Static configuration container for training/validation settings."""

    def __init__(self) -> None:
        # Published performance of GCoNet, kept as a baseline reference
        # when selecting checkpoints during validation.
        emax = {'CoCA': 0.783, 'CoSOD3k': 0.874, 'CoSal2015': 0.892}
        smeasure = {'CoCA': 0.710, 'CoSOD3k': 0.810, 'CoSal2015': 0.838}
        fmax = {'CoCA': 0.598, 'CoSOD3k': 0.805, 'CoSal2015': 0.856}
        self.val_measures = {
            'Emax': emax,
            'Smeasure': smeasure,
            'Fmax': fmax,
        }

        # Toggle the validation pass during training.
        self.validation = True
dataloader.cpython-37.pyc ADDED
Binary file (1.82 kB). View file
 
dataloader.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils import data
2
+ import os
3
+ from PIL import Image, ImageFile
4
+
5
+
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+
9
class EvalDataset(data.Dataset):
    """Pairs predicted saliency maps with ground-truth masks for evaluation.

    Only files present under *both* roots (same class sub-directory and
    same file name) are kept.  Ground-truth masks are pre-loaded as
    grayscale PIL images.
    """

    def __init__(self, pred_root, label_root, return_predpath=False, return_gtpath=False):
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        # Keep only class-dir/file-name combinations that exist on both sides.
        shared = []
        label_dirs = os.listdir(label_root)
        for class_dir in os.listdir(pred_root):
            if class_dir not in label_dirs:
                continue
            gt_names = os.listdir(os.path.join(label_root, class_dir))
            for name in os.listdir(os.path.join(pred_root, class_dir)):
                if name in gt_names:
                    shared.append(os.path.join(class_dir, name))

        self.image_path = [os.path.join(pred_root, rel) for rel in shared]
        self.label_path = [os.path.join(label_root, rel) for rel in shared]

        # NOTE(review): every GT mask is held in memory for the lifetime of
        # the dataset -- acceptable for evaluation-sized sets, but worth
        # confirming for very large benchmarks.
        self.labels = [Image.open(p).convert('L') for p in self.label_path]

    def __getitem__(self, item):
        predpath = self.image_path[item]
        gtpath = self.label_path[item]
        pred = Image.open(predpath).convert('L')
        gt = self.labels[item]
        # Bring the prediction to the GT resolution when they disagree.
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        return len(self.image_path)
dataset.cpython-38.pyc ADDED
Binary file (8.83 kB). View file
 
dataset.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, ImageOps, ImageFilter#, PILLOW_VERSION
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ from torch.utils import data
7
+ from torchvision import transforms
8
+ from torchvision.transforms import functional as F
9
+ import numbers
10
+ import random
11
+ import pandas as pd
12
+
13
+
14
class CoData(data.Dataset):
    """Group-wise co-saliency dataset: one item == one class folder.

    Each item returns every image of a class (or a random subset of at
    most ``max_num`` during training) stacked into a single tensor,
    together with the matching GT masks, their relative sub-paths and
    original sizes.
    """

    def __init__(self, img_root, gt_root, img_size, transform, max_num, is_train):
        class_list = os.listdir(img_root)
        self.size = [img_size, img_size]
        self.img_dirs = [os.path.join(img_root, c) for c in class_list]
        self.gt_dirs = [os.path.join(gt_root, c) for c in class_list]
        self.transform = transform
        self.max_num = max_num
        self.is_train = is_train

    def __getitem__(self, item):
        names = os.listdir(self.img_dirs[item])
        num = len(names)
        img_paths = [os.path.join(self.img_dirs[item], n) for n in names]
        # GT masks share the image's stem but are stored as .png.
        # (splitext replaces the original brittle x[:-4], which assumed a
        # three-character extension.)
        gt_paths = [os.path.join(self.gt_dirs[item], os.path.splitext(n)[0] + '.png')
                    for n in names]

        if self.is_train:
            # Randomly subsample the group to at most `max_num` images.
            final_num = min(num, self.max_num)
            sampled = random.sample(range(num), final_num)
            img_paths = [img_paths[i] for i in sampled]
            gt_paths = [gt_paths[i] for i in sampled]
        else:
            final_num = num

        imgs = torch.Tensor(final_num, 3, self.size[0], self.size[1])
        gts = torch.Tensor(final_num, 1, self.size[0], self.size[1])

        subpaths = []
        ori_sizes = []
        for idx in range(final_num):
            img = Image.open(img_paths[idx]).convert('RGB')
            gt = Image.open(gt_paths[idx]).convert('L')

            # Record "<class>/<stem>.png"; os.path is used instead of the
            # original split('/') so this also works on Windows paths.
            class_name = os.path.basename(os.path.dirname(img_paths[idx]))
            stem = os.path.splitext(os.path.basename(img_paths[idx]))[0]
            subpaths.append(os.path.join(class_name, stem + '.png'))
            # PIL size is (w, h); store as (h, w).
            ori_sizes.append((img.size[1], img.size[0]))

            [img, gt] = self.transform(img, gt)

            imgs[idx] = img
            gts[idx] = gt

        if self.is_train:
            # Class index repeated per sampled image, used as a group label.
            cls_ls = [item] * int(final_num)
            return imgs, gts, subpaths, ori_sizes, cls_ls
        return imgs, gts, subpaths, ori_sizes

    def __len__(self):
        return len(self.img_dirs)
75
+
76
+
77
class FixedResize(object):
    """Resize an (image, mask) pair to a fixed square size.

    The image uses bilinear interpolation; the mask uses nearest-neighbour
    so that label values are never blended.
    """

    def __init__(self, size):
        # Stored as (h, w); both dimensions equal.
        self.size = (size, size)

    def __call__(self, img, gt):
        resized_img = img.resize(self.size, Image.BILINEAR)
        resized_gt = gt.resize(self.size, Image.NEAREST)
        return resized_img, resized_gt
89
+
90
+
91
class ToTensor(object):
    """Convert an (image, mask) PIL pair to float tensors in [0, 1]."""

    def __call__(self, img, gt):
        tensor_pair = (F.to_tensor(img), F.to_tensor(gt))
        return tensor_pair
95
+
96
+
97
class Normalize(object):
    """Normalize the image tensor channel-wise; the mask passes through.

    Args:
        mean (tuple): per-channel means.
        std (tuple): per-channel standard deviations.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, img, gt):
        # Only the image is normalised; GT masks keep their raw values.
        normalized_img = F.normalize(img, self.mean, self.std)
        return normalized_img, gt
112
+
113
+
114
class RandomHorizontalFlip(object):
    """Flip image and mask together, left-right, with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, gt):
        # A single random draw decides for both inputs so the pair always
        # stays spatially aligned.
        if random.random() < self.p:
            return (img.transpose(Image.FLIP_LEFT_RIGHT),
                    gt.transpose(Image.FLIP_LEFT_RIGHT))
        return img, gt
124
+
125
+
126
class RandomScaleCrop(object):
    """Random scale jitter followed by a random square crop.

    The short edge is jittered to [0.8, 1.2] * ``base_size``, then a
    ``crop_size`` x ``crop_size`` window is cropped at a random position,
    applying identical geometry to image and mask.

    Args:
        base_size: nominal short-edge length before jitter.
        crop_size: side length of the square crop.
        fill: padding value for the *mask* when the scaled image is
            smaller than the crop (the image is always padded with 0).
    """

    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, img, mask):
        # Jitter the short edge within [0.8, 1.2] * base_size.
        short_size = random.randint(int(self.base_size * 0.8), int(self.base_size * 1.2))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)

        # Pad on the right/bottom if the scaled result is too small to crop.
        if short_size < self.crop_size:
            padh = max(self.crop_size - oh, 0)
            padw = max(self.crop_size - ow, 0)
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)

        # Random top-left corner for the square crop.
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        box = (x1, y1, x1 + self.crop_size, y1 + self.crop_size)
        return img.crop(box), mask.crop(box)
160
+
161
+
162
class RandomRotation(object):
    """Rotate image and mask by one shared random angle.

    Args:
        degrees: a single non-negative number means (-degrees, +degrees);
            otherwise a 2-sequence giving the (min, max) range.
        resample: stored but NOT forwarded -- __call__ hard-codes BILINEAR
            for the image and NEAREST for the mask.
            # NOTE(review): confirm whether `resample` was meant to be used.
        expand: forwarded to ``F.rotate``.
        center: forwarded to ``F.rotate``.
    """

    def __init__(self, degrees, resample=False, expand=False, center=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center

    @staticmethod
    def get_params(degrees):
        """Draw a rotation angle uniformly from the ``degrees`` range."""
        return random.uniform(degrees[0], degrees[1])

    def __call__(self, img, gt):
        """Rotate both inputs by the same random angle.

        img (PIL Image): image to be rotated.
        gt (PIL Image): mask to be rotated identically.

        Returns:
            (PIL Image, PIL Image): the rotated pair.
        """
        angle = self.get_params(self.degrees)
        rotated_img = F.rotate(img, angle, Image.BILINEAR, self.expand, self.center)
        rotated_gt = F.rotate(gt, angle, Image.NEAREST, self.expand, self.center)
        return rotated_img, rotated_gt
194
+
195
+
196
+
197
class Compose(object):
    """Chain paired (img, gt) transforms, applying each in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, gt):
        # Each transform receives and returns the (img, gt) pair, so the
        # same geometric change is applied to both.
        for transform in self.transforms:
            img, gt = transform(img, gt)
        return img, gt

    def __repr__(self):
        parts = [self.__class__.__name__ + '(']
        for transform in self.transforms:
            parts.append('\n')
            parts.append('    {0}'.format(transform))
        parts.append('\n)')
        return ''.join(parts)
213
+
214
+
215
# get the dataloader (Note: without data augmentation)
def get_loader(img_root, gt_root, img_size, batch_size, max_num = float('inf'), istrain=True, shuffle=False, num_workers=0, pin=False):
    """Build a DataLoader over ``CoData``.

    Training mode applies scale/crop, flip and rotation augmentation;
    evaluation mode only resizes.  Both normalise with ImageNet stats.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    if istrain:
        transform = Compose([
            RandomScaleCrop(img_size * 2, img_size * 2),
            FixedResize(img_size),
            RandomHorizontalFlip(),
            RandomRotation((-90, 90)),
            ToTensor(),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
    else:
        transform = Compose([
            FixedResize(img_size),
            ToTensor(),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    dataset = CoData(img_root, gt_root, img_size, transform, max_num, is_train=istrain)
    return data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
                           num_workers=num_workers, pin_memory=pin)
239
+
240
+
241
if __name__ == '__main__':
    # Visual sanity check for the data pipeline: iterate a small loader and
    # display each de-normalised image next to its mask.
    import matplotlib.pyplot as plt

    # ImageNet statistics used by get_loader's Normalize step; needed here
    # to invert the normalisation for display.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img_root = './data/testtrain/img/'
    gt_root = './data/testtrain/gt/'
    loader = get_loader(img_root, gt_root, 20, 1, 16, istrain=False)
    for batch in loader:
        # batch[0] is (1, group, 3, H, W); take the per-group dimensions.
        b, c, h, w = batch[0][0].shape
        for i in range(b):
            # Undo normalisation (x * std + mean) and rescale to [0, 255].
            img = batch[0].squeeze(0)[i].permute(1, 2, 0).cpu().numpy() * std + mean
            image = img * 255
            mask = batch[1].squeeze(0)[i].squeeze().cpu().numpy()
            plt.subplot(121)
            plt.imshow(np.uint8(image))
            plt.subplot(122)
            plt.imshow(mask)
            plt.show(block=True)
dataset_preprocessing.ipynb ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/wej36how/.conda/envs/dmt/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import glob\n",
19
+ "import os\n",
20
+ "import cv2\n",
21
+ "import numpy as np\n",
22
+ "from PIL import Image\n",
23
+ "import torch \n",
24
+ "from PIL import Image\n",
25
+ "from pathlib import Path\n",
26
+ "import torchvision.transforms as T\n",
27
+ "import torchvision.transforms.functional as TF\n",
28
+ "import numpy as np\n",
29
+ "from torchvision import transforms\n",
30
+ "import os\n",
31
+ "import cv2\n",
32
+ "import matplotlib.pyplot as plt\n",
33
+ "import shutil"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 2,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "source = \"/scratch/wej36how/Datasets/NWRD/val\"\n",
43
+ "dest = \"/scratch/wej36how/Datasets/NWRDProcessed/val\"\n",
44
+ "patch_size = 224\n",
45
+ "rust_threshold = 150\n",
46
+ "max_number_of_images_per_group = 12"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "metadata": {},
52
+ "source": [
53
+ "This snippet will make patches of the images in the destination/patches directory."
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "patches_path = os.path.join(dest, \"patches\")\n",
63
+ "images_dir = os.path.join(patches_path, \"images\")\n",
64
+ "masks_dir = os.path.join(patches_path, \"masks\")"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 3,
70
+ "metadata": {},
71
+ "outputs": [
72
+ {
73
+ "name": "stdout",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "/scratch/wej36how/Datasets/NWRD/train/images/10.jpg\n",
77
+ "image shape: (4000, 6016, 3)\n",
78
+ "total patches: 442\n",
79
+ "/scratch/wej36how/Datasets/NWRD/train/images/100.jpg\n",
80
+ "image shape: (4608, 3456, 3)\n",
81
+ "total patches: 300\n",
82
+ "/scratch/wej36how/Datasets/NWRD/train/images/101.jpg\n",
83
+ "image shape: (3456, 3612, 3)\n",
84
+ "total patches: 240\n",
85
+ "/scratch/wej36how/Datasets/NWRD/train/images/102.jpg\n",
86
+ "image shape: (2984, 4248, 3)\n",
87
+ "total patches: 234\n",
88
+ "/scratch/wej36how/Datasets/NWRD/train/images/103.jpg\n",
89
+ "image shape: (3584, 3456, 3)\n",
90
+ "total patches: 226\n",
91
+ "/scratch/wej36how/Datasets/NWRD/train/images/104.jpg\n",
92
+ "image shape: (3456, 4608, 3)\n",
93
+ "total patches: 300\n",
94
+ "/scratch/wej36how/Datasets/NWRD/train/images/109.jpg\n",
95
+ "image shape: (4608, 3456, 3)\n",
96
+ "total patches: 300\n",
97
+ "/scratch/wej36how/Datasets/NWRD/train/images/11.jpg\n",
98
+ "image shape: (4000, 6016, 3)\n",
99
+ "total patches: 442\n",
100
+ "/scratch/wej36how/Datasets/NWRD/train/images/110.jpg\n",
101
+ "image shape: (4600, 2536, 3)\n",
102
+ "total patches: 220\n",
103
+ "/scratch/wej36how/Datasets/NWRD/train/images/111.jpg\n",
104
+ "image shape: (2909, 4608, 3)\n",
105
+ "total patches: 240\n",
106
+ "/scratch/wej36how/Datasets/NWRD/train/images/113.jpg\n",
107
+ "image shape: (4608, 3456, 3)\n",
108
+ "total patches: 300\n",
109
+ "/scratch/wej36how/Datasets/NWRD/train/images/114.jpg\n",
110
+ "image shape: (3456, 4608, 3)\n",
111
+ "total patches: 300\n",
112
+ "/scratch/wej36how/Datasets/NWRD/train/images/117.jpg\n",
113
+ "image shape: (3456, 4608, 3)\n",
114
+ "total patches: 300\n",
115
+ "/scratch/wej36how/Datasets/NWRD/train/images/118.jpg\n",
116
+ "image shape: (3456, 4608, 3)\n",
117
+ "total patches: 300\n",
118
+ "/scratch/wej36how/Datasets/NWRD/train/images/119.jpg\n",
119
+ "image shape: (3456, 4608, 3)\n",
120
+ "total patches: 300\n",
121
+ "/scratch/wej36how/Datasets/NWRD/train/images/12.jpg\n",
122
+ "image shape: (4000, 6016, 3)\n",
123
+ "total patches: 442\n",
124
+ "/scratch/wej36how/Datasets/NWRD/train/images/121.jpg\n",
125
+ "image shape: (3456, 4608, 3)\n",
126
+ "total patches: 300\n",
127
+ "/scratch/wej36how/Datasets/NWRD/train/images/122.jpg\n",
128
+ "image shape: (3456, 4608, 3)\n",
129
+ "total patches: 300\n",
130
+ "/scratch/wej36how/Datasets/NWRD/train/images/124.jpg\n",
131
+ "image shape: (4608, 3456, 3)\n",
132
+ "total patches: 300\n",
133
+ "/scratch/wej36how/Datasets/NWRD/train/images/125.jpg\n",
134
+ "image shape: (4608, 3456, 3)\n",
135
+ "total patches: 300\n",
136
+ "/scratch/wej36how/Datasets/NWRD/train/images/128.jpg\n",
137
+ "image shape: (3456, 4608, 3)\n",
138
+ "total patches: 300\n",
139
+ "/scratch/wej36how/Datasets/NWRD/train/images/129.jpg\n",
140
+ "image shape: (3456, 4608, 3)\n",
141
+ "total patches: 300\n",
142
+ "/scratch/wej36how/Datasets/NWRD/train/images/13.jpg\n",
143
+ "image shape: (4000, 6016, 3)\n",
144
+ "total patches: 442\n",
145
+ "/scratch/wej36how/Datasets/NWRD/train/images/130.jpg\n",
146
+ "image shape: (3456, 4608, 3)\n",
147
+ "total patches: 300\n",
148
+ "/scratch/wej36how/Datasets/NWRD/train/images/131.jpg\n",
149
+ "image shape: (3456, 4608, 3)\n",
150
+ "total patches: 300\n",
151
+ "/scratch/wej36how/Datasets/NWRD/train/images/132.jpg\n",
152
+ "image shape: (4608, 3456, 3)\n",
153
+ "total patches: 300\n",
154
+ "/scratch/wej36how/Datasets/NWRD/train/images/133.jpg\n",
155
+ "image shape: (4608, 3456, 3)\n",
156
+ "total patches: 300\n",
157
+ "/scratch/wej36how/Datasets/NWRD/train/images/134.jpg\n",
158
+ "image shape: (4608, 3456, 3)\n",
159
+ "total patches: 300\n",
160
+ "/scratch/wej36how/Datasets/NWRD/train/images/135.jpg\n",
161
+ "image shape: (4608, 3456, 3)\n",
162
+ "total patches: 300\n",
163
+ "/scratch/wej36how/Datasets/NWRD/train/images/136.jpg\n",
164
+ "image shape: (3456, 4608, 3)\n",
165
+ "total patches: 300\n",
166
+ "/scratch/wej36how/Datasets/NWRD/train/images/137.jpg\n",
167
+ "image shape: (4608, 3456, 3)\n",
168
+ "total patches: 300\n",
169
+ "/scratch/wej36how/Datasets/NWRD/train/images/138.jpg\n",
170
+ "image shape: (4608, 3456, 3)\n",
171
+ "total patches: 300\n",
172
+ "/scratch/wej36how/Datasets/NWRD/train/images/139.jpg\n",
173
+ "image shape: (3456, 4608, 3)\n",
174
+ "total patches: 300\n",
175
+ "/scratch/wej36how/Datasets/NWRD/train/images/14.jpg\n",
176
+ "image shape: (4000, 6016, 3)\n",
177
+ "total patches: 442\n",
178
+ "/scratch/wej36how/Datasets/NWRD/train/images/140.jpg\n",
179
+ "image shape: (3456, 4608, 3)\n",
180
+ "total patches: 300\n",
181
+ "/scratch/wej36how/Datasets/NWRD/train/images/141.jpg\n",
182
+ "image shape: (4608, 3456, 3)\n",
183
+ "total patches: 300\n",
184
+ "/scratch/wej36how/Datasets/NWRD/train/images/142.jpg\n",
185
+ "image shape: (4608, 3456, 3)\n",
186
+ "total patches: 300\n",
187
+ "/scratch/wej36how/Datasets/NWRD/train/images/144.jpg\n",
188
+ "image shape: (4608, 3456, 3)\n",
189
+ "total patches: 300\n",
190
+ "/scratch/wej36how/Datasets/NWRD/train/images/145.jpg\n",
191
+ "image shape: (4608, 3456, 3)\n",
192
+ "total patches: 300\n",
193
+ "/scratch/wej36how/Datasets/NWRD/train/images/146.jpg\n",
194
+ "image shape: (4608, 3456, 3)\n",
195
+ "total patches: 300\n",
196
+ "/scratch/wej36how/Datasets/NWRD/train/images/149.jpg\n",
197
+ "image shape: (3968, 3424, 3)\n",
198
+ "total patches: 255\n",
199
+ "/scratch/wej36how/Datasets/NWRD/train/images/15.jpg\n",
200
+ "image shape: (4000, 6016, 3)\n",
201
+ "total patches: 442\n",
202
+ "/scratch/wej36how/Datasets/NWRD/train/images/150.jpg\n",
203
+ "image shape: (3456, 4608, 3)\n",
204
+ "total patches: 300\n",
205
+ "/scratch/wej36how/Datasets/NWRD/train/images/19.jpg\n",
206
+ "image shape: (4000, 6016, 3)\n",
207
+ "total patches: 442\n",
208
+ "/scratch/wej36how/Datasets/NWRD/train/images/2.jpg\n",
209
+ "image shape: (4000, 6016, 3)\n",
210
+ "total patches: 442\n",
211
+ "/scratch/wej36how/Datasets/NWRD/train/images/21.jpg\n",
212
+ "image shape: (4000, 6016, 3)\n",
213
+ "total patches: 442\n",
214
+ "/scratch/wej36how/Datasets/NWRD/train/images/24.jpg\n",
215
+ "image shape: (4000, 6016, 3)\n",
216
+ "total patches: 442\n",
217
+ "/scratch/wej36how/Datasets/NWRD/train/images/26.jpg\n",
218
+ "image shape: (4000, 6016, 3)\n",
219
+ "total patches: 442\n",
220
+ "/scratch/wej36how/Datasets/NWRD/train/images/27.jpg\n",
221
+ "image shape: (4000, 6016, 3)\n",
222
+ "total patches: 442\n",
223
+ "/scratch/wej36how/Datasets/NWRD/train/images/28.jpg\n",
224
+ "image shape: (4000, 6016, 3)\n",
225
+ "total patches: 442\n",
226
+ "/scratch/wej36how/Datasets/NWRD/train/images/3.jpg\n",
227
+ "image shape: (4000, 6016, 3)\n",
228
+ "total patches: 442\n",
229
+ "/scratch/wej36how/Datasets/NWRD/train/images/30.jpg\n",
230
+ "image shape: (4000, 6016, 3)\n",
231
+ "total patches: 442\n",
232
+ "/scratch/wej36how/Datasets/NWRD/train/images/31.jpg\n",
233
+ "image shape: (4000, 6016, 3)\n",
234
+ "total patches: 442\n",
235
+ "/scratch/wej36how/Datasets/NWRD/train/images/32.jpg\n",
236
+ "image shape: (4000, 6016, 3)\n",
237
+ "total patches: 442\n",
238
+ "/scratch/wej36how/Datasets/NWRD/train/images/33.jpg\n",
239
+ "image shape: (4000, 6016, 3)\n",
240
+ "total patches: 442\n",
241
+ "/scratch/wej36how/Datasets/NWRD/train/images/5.jpg\n",
242
+ "image shape: (4000, 6016, 3)\n",
243
+ "total patches: 442\n",
244
+ "/scratch/wej36how/Datasets/NWRD/train/images/57.jpg\n",
245
+ "image shape: (4000, 6016, 3)\n",
246
+ "total patches: 442\n",
247
+ "/scratch/wej36how/Datasets/NWRD/train/images/58.jpg\n",
248
+ "image shape: (4000, 6016, 3)\n",
249
+ "total patches: 442\n",
250
+ "/scratch/wej36how/Datasets/NWRD/train/images/60.jpg\n",
251
+ "image shape: (4000, 6016, 3)\n",
252
+ "total patches: 442\n",
253
+ "/scratch/wej36how/Datasets/NWRD/train/images/64.jpg\n",
254
+ "image shape: (4000, 6016, 3)\n",
255
+ "total patches: 442\n",
256
+ "/scratch/wej36how/Datasets/NWRD/train/images/69.jpg\n",
257
+ "image shape: (4608, 3456, 3)\n",
258
+ "total patches: 300\n",
259
+ "/scratch/wej36how/Datasets/NWRD/train/images/71.jpg\n",
260
+ "image shape: (3456, 4608, 3)\n",
261
+ "total patches: 300\n",
262
+ "/scratch/wej36how/Datasets/NWRD/train/images/72.jpg\n",
263
+ "image shape: (3456, 4608, 3)\n",
264
+ "total patches: 300\n",
265
+ "/scratch/wej36how/Datasets/NWRD/train/images/73.jpg\n",
266
+ "image shape: (3456, 4608, 3)\n",
267
+ "total patches: 300\n",
268
+ "/scratch/wej36how/Datasets/NWRD/train/images/74.jpg\n",
269
+ "image shape: (3456, 4608, 3)\n",
270
+ "total patches: 300\n",
271
+ "/scratch/wej36how/Datasets/NWRD/train/images/75.jpg\n",
272
+ "image shape: (3456, 4608, 3)\n",
273
+ "total patches: 300\n",
274
+ "/scratch/wej36how/Datasets/NWRD/train/images/76.jpg\n",
275
+ "image shape: (3456, 4608, 3)\n",
276
+ "total patches: 300\n",
277
+ "/scratch/wej36how/Datasets/NWRD/train/images/78.jpg\n",
278
+ "image shape: (3456, 4608, 3)\n",
279
+ "total patches: 300\n",
280
+ "/scratch/wej36how/Datasets/NWRD/train/images/79.jpg\n",
281
+ "image shape: (3456, 4608, 3)\n",
282
+ "total patches: 300\n",
283
+ "/scratch/wej36how/Datasets/NWRD/train/images/81.jpg\n",
284
+ "image shape: (3456, 4608, 3)\n",
285
+ "total patches: 300\n",
286
+ "/scratch/wej36how/Datasets/NWRD/train/images/83.jpg\n",
287
+ "image shape: (3456, 4608, 3)\n",
288
+ "total patches: 300\n",
289
+ "/scratch/wej36how/Datasets/NWRD/train/images/85.jpg\n",
290
+ "image shape: (4608, 3456, 3)\n",
291
+ "total patches: 300\n",
292
+ "/scratch/wej36how/Datasets/NWRD/train/images/86.jpg\n",
293
+ "image shape: (4608, 3456, 3)\n",
294
+ "total patches: 300\n",
295
+ "/scratch/wej36how/Datasets/NWRD/train/images/87.jpg\n",
296
+ "image shape: (4608, 3456, 3)\n",
297
+ "total patches: 300\n",
298
+ "/scratch/wej36how/Datasets/NWRD/train/images/88.jpg\n",
299
+ "image shape: (4608, 3456, 3)\n",
300
+ "total patches: 300\n",
301
+ "/scratch/wej36how/Datasets/NWRD/train/images/9.jpg\n",
302
+ "image shape: (4000, 6016, 3)\n",
303
+ "total patches: 442\n",
304
+ "/scratch/wej36how/Datasets/NWRD/train/images/90.jpg\n",
305
+ "image shape: (4608, 3456, 3)\n",
306
+ "total patches: 300\n",
307
+ "/scratch/wej36how/Datasets/NWRD/train/images/91.jpg\n",
308
+ "image shape: (4608, 3456, 3)\n",
309
+ "total patches: 300\n",
310
+ "/scratch/wej36how/Datasets/NWRD/train/images/92.jpg\n",
311
+ "image shape: (4608, 3456, 3)\n",
312
+ "total patches: 300\n",
313
+ "/scratch/wej36how/Datasets/NWRD/train/images/93.jpg\n",
314
+ "image shape: (4608, 3456, 3)\n",
315
+ "total patches: 300\n",
316
+ "/scratch/wej36how/Datasets/NWRD/train/images/94.jpg\n",
317
+ "image shape: (4608, 3456, 3)\n",
318
+ "total patches: 300\n",
319
+ "/scratch/wej36how/Datasets/NWRD/train/images/95.jpg\n",
320
+ "image shape: (4608, 3456, 3)\n",
321
+ "total patches: 300\n",
322
+ "/scratch/wej36how/Datasets/NWRD/train/images/96.jpg\n",
323
+ "image shape: (4608, 3456, 3)\n",
324
+ "total patches: 300\n",
325
+ "/scratch/wej36how/Datasets/NWRD/train/images/98.jpg\n",
326
+ "image shape: (4608, 3456, 3)\n",
327
+ "total patches: 300\n",
328
+ "/scratch/wej36how/Datasets/NWRD/train/images/99.jpg\n",
329
+ "image shape: (4608, 3456, 3)\n",
330
+ "total patches: 300\n",
331
+ "total image count: 28523\n",
332
+ "/scratch/wej36how/Datasets/NWRD/train/masks/10.png\n",
333
+ "image shape: (4000, 6016, 3)\n",
334
+ "total patches: 442\n",
335
+ "/scratch/wej36how/Datasets/NWRD/train/masks/100.png\n",
336
+ "image shape: (4608, 3456, 3)\n",
337
+ "total patches: 300\n",
338
+ "/scratch/wej36how/Datasets/NWRD/train/masks/101.png\n",
339
+ "image shape: (3435, 3593, 3)\n",
340
+ "total patches: 240\n",
341
+ "/scratch/wej36how/Datasets/NWRD/train/masks/102.png\n",
342
+ "image shape: (2984, 4248, 3)\n",
343
+ "total patches: 234\n",
344
+ "/scratch/wej36how/Datasets/NWRD/train/masks/103.png\n",
345
+ "image shape: (3584, 3456, 3)\n",
346
+ "total patches: 226\n",
347
+ "/scratch/wej36how/Datasets/NWRD/train/masks/104.png\n",
348
+ "image shape: (3456, 4608, 3)\n",
349
+ "total patches: 300\n",
350
+ "/scratch/wej36how/Datasets/NWRD/train/masks/109.png\n",
351
+ "image shape: (4608, 3456, 3)\n",
352
+ "total patches: 300\n",
353
+ "/scratch/wej36how/Datasets/NWRD/train/masks/11.png\n",
354
+ "image shape: (4000, 6016, 3)\n",
355
+ "total patches: 442\n",
356
+ "/scratch/wej36how/Datasets/NWRD/train/masks/110.png\n",
357
+ "image shape: (4600, 2536, 3)\n",
358
+ "total patches: 220\n",
359
+ "/scratch/wej36how/Datasets/NWRD/train/masks/111.png\n",
360
+ "image shape: (2909, 4608, 3)\n",
361
+ "total patches: 240\n",
362
+ "/scratch/wej36how/Datasets/NWRD/train/masks/113.png\n",
363
+ "image shape: (4608, 3456, 3)\n",
364
+ "total patches: 300\n",
365
+ "/scratch/wej36how/Datasets/NWRD/train/masks/114.png\n",
366
+ "image shape: (3456, 4608, 3)\n",
367
+ "total patches: 300\n",
368
+ "/scratch/wej36how/Datasets/NWRD/train/masks/117.png\n",
369
+ "image shape: (3456, 4608, 3)\n",
370
+ "total patches: 300\n",
371
+ "/scratch/wej36how/Datasets/NWRD/train/masks/118.png\n",
372
+ "image shape: (3456, 4608, 3)\n",
373
+ "total patches: 300\n",
374
+ "/scratch/wej36how/Datasets/NWRD/train/masks/119.png\n",
375
+ "image shape: (3456, 4608, 3)\n",
376
+ "total patches: 300\n",
377
+ "/scratch/wej36how/Datasets/NWRD/train/masks/12.png\n",
378
+ "image shape: (4000, 6016, 3)\n",
379
+ "total patches: 442\n",
380
+ "/scratch/wej36how/Datasets/NWRD/train/masks/121.png\n",
381
+ "image shape: (3456, 4608, 3)\n",
382
+ "total patches: 300\n",
383
+ "/scratch/wej36how/Datasets/NWRD/train/masks/122.png\n",
384
+ "image shape: (3456, 4608, 3)\n",
385
+ "total patches: 300\n",
386
+ "/scratch/wej36how/Datasets/NWRD/train/masks/124.png\n",
387
+ "image shape: (4608, 3456, 3)\n",
388
+ "total patches: 300\n",
389
+ "/scratch/wej36how/Datasets/NWRD/train/masks/125.png\n",
390
+ "image shape: (4608, 3456, 3)\n",
391
+ "total patches: 300\n",
392
+ "/scratch/wej36how/Datasets/NWRD/train/masks/128.png\n",
393
+ "image shape: (3456, 4608, 3)\n",
394
+ "total patches: 300\n",
395
+ "/scratch/wej36how/Datasets/NWRD/train/masks/129.png\n",
396
+ "image shape: (3456, 4608, 3)\n",
397
+ "total patches: 300\n",
398
+ "/scratch/wej36how/Datasets/NWRD/train/masks/13.png\n",
399
+ "image shape: (4000, 6016, 3)\n",
400
+ "total patches: 442\n",
401
+ "/scratch/wej36how/Datasets/NWRD/train/masks/130.png\n",
402
+ "image shape: (3456, 4608, 3)\n",
403
+ "total patches: 300\n",
404
+ "/scratch/wej36how/Datasets/NWRD/train/masks/131.png\n",
405
+ "image shape: (3456, 4608, 3)\n",
406
+ "total patches: 300\n",
407
+ "/scratch/wej36how/Datasets/NWRD/train/masks/132.png\n",
408
+ "image shape: (4608, 3456, 3)\n",
409
+ "total patches: 300\n",
410
+ "/scratch/wej36how/Datasets/NWRD/train/masks/133.png\n",
411
+ "image shape: (4608, 3456, 3)\n",
412
+ "total patches: 300\n",
413
+ "/scratch/wej36how/Datasets/NWRD/train/masks/134.png\n",
414
+ "image shape: (4608, 3456, 3)\n",
415
+ "total patches: 300\n",
416
+ "/scratch/wej36how/Datasets/NWRD/train/masks/135.png\n",
417
+ "image shape: (4608, 3456, 3)\n",
418
+ "total patches: 300\n",
419
+ "/scratch/wej36how/Datasets/NWRD/train/masks/136.png\n",
420
+ "image shape: (3456, 4608, 3)\n",
421
+ "total patches: 300\n",
422
+ "/scratch/wej36how/Datasets/NWRD/train/masks/137.png\n",
423
+ "image shape: (4608, 3456, 3)\n",
424
+ "total patches: 300\n",
425
+ "/scratch/wej36how/Datasets/NWRD/train/masks/138.png\n",
426
+ "image shape: (4608, 3456, 3)\n",
427
+ "total patches: 300\n",
428
+ "/scratch/wej36how/Datasets/NWRD/train/masks/139.png\n",
429
+ "image shape: (3456, 4608, 3)\n",
430
+ "total patches: 300\n",
431
+ "/scratch/wej36how/Datasets/NWRD/train/masks/14.png\n",
432
+ "image shape: (4000, 6016, 3)\n",
433
+ "total patches: 442\n",
434
+ "/scratch/wej36how/Datasets/NWRD/train/masks/140.png\n",
435
+ "image shape: (3456, 4608, 3)\n",
436
+ "total patches: 300\n",
437
+ "/scratch/wej36how/Datasets/NWRD/train/masks/141.png\n",
438
+ "image shape: (4608, 3456, 3)\n",
439
+ "total patches: 300\n",
440
+ "/scratch/wej36how/Datasets/NWRD/train/masks/142.png\n",
441
+ "image shape: (4608, 3456, 3)\n",
442
+ "total patches: 300\n",
443
+ "/scratch/wej36how/Datasets/NWRD/train/masks/144.png\n",
444
+ "image shape: (4608, 3456, 3)\n",
445
+ "total patches: 300\n",
446
+ "/scratch/wej36how/Datasets/NWRD/train/masks/145.png\n",
447
+ "image shape: (4608, 3456, 3)\n",
448
+ "total patches: 300\n",
449
+ "/scratch/wej36how/Datasets/NWRD/train/masks/146.png\n",
450
+ "image shape: (4608, 3456, 3)\n",
451
+ "total patches: 300\n",
452
+ "/scratch/wej36how/Datasets/NWRD/train/masks/149.png\n",
453
+ "image shape: (3968, 3424, 3)\n",
454
+ "total patches: 255\n",
455
+ "/scratch/wej36how/Datasets/NWRD/train/masks/15.png\n",
456
+ "image shape: (4000, 6016, 3)\n",
457
+ "total patches: 442\n",
458
+ "/scratch/wej36how/Datasets/NWRD/train/masks/150.png\n",
459
+ "image shape: (3456, 4608, 3)\n",
460
+ "total patches: 300\n",
461
+ "/scratch/wej36how/Datasets/NWRD/train/masks/19.png\n",
462
+ "image shape: (4000, 6016, 3)\n",
463
+ "total patches: 442\n",
464
+ "/scratch/wej36how/Datasets/NWRD/train/masks/2.png\n",
465
+ "image shape: (4000, 6016, 3)\n",
466
+ "total patches: 442\n",
467
+ "/scratch/wej36how/Datasets/NWRD/train/masks/21.png\n",
468
+ "image shape: (4000, 6016, 3)\n",
469
+ "total patches: 442\n",
470
+ "/scratch/wej36how/Datasets/NWRD/train/masks/24.png\n",
471
+ "image shape: (4000, 6016, 3)\n",
472
+ "total patches: 442\n",
473
+ "/scratch/wej36how/Datasets/NWRD/train/masks/26.png\n",
474
+ "image shape: (4000, 6016, 3)\n",
475
+ "total patches: 442\n",
476
+ "/scratch/wej36how/Datasets/NWRD/train/masks/27.png\n",
477
+ "image shape: (4000, 6016, 3)\n",
478
+ "total patches: 442\n",
479
+ "/scratch/wej36how/Datasets/NWRD/train/masks/28.png\n",
480
+ "image shape: (4000, 6016, 3)\n",
481
+ "total patches: 442\n",
482
+ "/scratch/wej36how/Datasets/NWRD/train/masks/3.png\n",
483
+ "image shape: (4000, 6016, 3)\n",
484
+ "total patches: 442\n",
485
+ "/scratch/wej36how/Datasets/NWRD/train/masks/30.png\n",
486
+ "image shape: (4000, 6016, 3)\n",
487
+ "total patches: 442\n",
488
+ "/scratch/wej36how/Datasets/NWRD/train/masks/31.png\n",
489
+ "image shape: (4000, 6016, 3)\n",
490
+ "total patches: 442\n",
491
+ "/scratch/wej36how/Datasets/NWRD/train/masks/32.png\n",
492
+ "image shape: (4000, 6016, 3)\n",
493
+ "total patches: 442\n",
494
+ "/scratch/wej36how/Datasets/NWRD/train/masks/33.png\n",
495
+ "image shape: (4000, 6016, 3)\n",
496
+ "total patches: 442\n",
497
+ "/scratch/wej36how/Datasets/NWRD/train/masks/5.png\n",
498
+ "image shape: (4000, 6016, 3)\n",
499
+ "total patches: 442\n",
500
+ "/scratch/wej36how/Datasets/NWRD/train/masks/57.png\n",
501
+ "image shape: (4000, 6016, 3)\n",
502
+ "total patches: 442\n",
503
+ "/scratch/wej36how/Datasets/NWRD/train/masks/58.png\n",
504
+ "image shape: (4000, 6016, 3)\n",
505
+ "total patches: 442\n",
506
+ "/scratch/wej36how/Datasets/NWRD/train/masks/60.png\n",
507
+ "image shape: (4000, 6016, 3)\n",
508
+ "total patches: 442\n",
509
+ "/scratch/wej36how/Datasets/NWRD/train/masks/64.png\n",
510
+ "image shape: (4000, 6016, 3)\n",
511
+ "total patches: 442\n",
512
+ "/scratch/wej36how/Datasets/NWRD/train/masks/69.png\n",
513
+ "image shape: (4608, 3456, 3)\n",
514
+ "total patches: 300\n",
515
+ "/scratch/wej36how/Datasets/NWRD/train/masks/71.png\n",
516
+ "image shape: (3456, 4608, 3)\n",
517
+ "total patches: 300\n",
518
+ "/scratch/wej36how/Datasets/NWRD/train/masks/72.png\n",
519
+ "image shape: (3456, 4608, 3)\n",
520
+ "total patches: 300\n",
521
+ "/scratch/wej36how/Datasets/NWRD/train/masks/73.png\n",
522
+ "image shape: (3456, 4608, 3)\n",
523
+ "total patches: 300\n",
524
+ "/scratch/wej36how/Datasets/NWRD/train/masks/74.png\n",
525
+ "image shape: (3456, 4608, 3)\n",
526
+ "total patches: 300\n",
527
+ "/scratch/wej36how/Datasets/NWRD/train/masks/75.png\n",
528
+ "image shape: (3456, 4608, 3)\n",
529
+ "total patches: 300\n",
530
+ "/scratch/wej36how/Datasets/NWRD/train/masks/76.png\n",
531
+ "image shape: (3456, 4608, 3)\n",
532
+ "total patches: 300\n",
533
+ "/scratch/wej36how/Datasets/NWRD/train/masks/78.png\n",
534
+ "image shape: (3456, 4608, 3)\n",
535
+ "total patches: 300\n",
536
+ "/scratch/wej36how/Datasets/NWRD/train/masks/79.png\n",
537
+ "image shape: (3456, 4608, 3)\n",
538
+ "total patches: 300\n",
539
+ "/scratch/wej36how/Datasets/NWRD/train/masks/81.png\n",
540
+ "image shape: (3456, 4608, 3)\n",
541
+ "total patches: 300\n",
542
+ "/scratch/wej36how/Datasets/NWRD/train/masks/83.png\n",
543
+ "image shape: (3456, 4608, 3)\n",
544
+ "total patches: 300\n",
545
+ "/scratch/wej36how/Datasets/NWRD/train/masks/85.png\n",
546
+ "image shape: (4608, 3456, 3)\n",
547
+ "total patches: 300\n",
548
+ "/scratch/wej36how/Datasets/NWRD/train/masks/86.png\n",
549
+ "image shape: (4608, 3456, 3)\n",
550
+ "total patches: 300\n",
551
+ "/scratch/wej36how/Datasets/NWRD/train/masks/87.png\n",
552
+ "image shape: (4608, 3456, 3)\n",
553
+ "total patches: 300\n",
554
+ "/scratch/wej36how/Datasets/NWRD/train/masks/88.png\n",
555
+ "image shape: (4608, 3456, 3)\n",
556
+ "total patches: 300\n",
557
+ "/scratch/wej36how/Datasets/NWRD/train/masks/9.png\n",
558
+ "image shape: (4000, 6016, 3)\n",
559
+ "total patches: 442\n",
560
+ "/scratch/wej36how/Datasets/NWRD/train/masks/90.png\n",
561
+ "image shape: (4608, 3456, 3)\n",
562
+ "total patches: 300\n",
563
+ "/scratch/wej36how/Datasets/NWRD/train/masks/91.png\n",
564
+ "image shape: (4608, 3456, 3)\n",
565
+ "total patches: 300\n",
566
+ "/scratch/wej36how/Datasets/NWRD/train/masks/92.png\n",
567
+ "image shape: (4608, 3456, 3)\n",
568
+ "total patches: 300\n",
569
+ "/scratch/wej36how/Datasets/NWRD/train/masks/93.png\n",
570
+ "image shape: (4608, 3456, 3)\n",
571
+ "total patches: 300\n",
572
+ "/scratch/wej36how/Datasets/NWRD/train/masks/94.png\n",
573
+ "image shape: (4608, 3456, 3)\n",
574
+ "total patches: 300\n",
575
+ "/scratch/wej36how/Datasets/NWRD/train/masks/95.png\n",
576
+ "image shape: (4608, 3456, 3)\n",
577
+ "total patches: 300\n",
578
+ "/scratch/wej36how/Datasets/NWRD/train/masks/96.png\n",
579
+ "image shape: (4608, 3456, 3)\n",
580
+ "total patches: 300\n",
581
+ "/scratch/wej36how/Datasets/NWRD/train/masks/98.png\n",
582
+ "image shape: (4608, 3456, 3)\n",
583
+ "total patches: 300\n",
584
+ "/scratch/wej36how/Datasets/NWRD/train/masks/99.png\n",
585
+ "image shape: (4608, 3456, 3)\n",
586
+ "total patches: 300\n",
587
+ "total masks count: 28523\n"
588
+ ]
589
+ }
590
+ ],
591
+ "source": [
592
+ "masks_paths = glob.glob(f'{os.path.join(source, \"masks\", \"*\")}')\n",
593
+ "images_paths = glob.glob(f'{os.path.join(source, \"images\", \"*\")}')\n",
594
+ "images_paths.sort()\n",
595
+ "masks_paths.sort()\n",
596
+ "\n",
597
+ "\n",
598
+ "os.makedirs(patches_path)\n",
599
+ "os.makedirs(images_dir)\n",
600
+ "os.makedirs(masks_dir)\n",
601
+ "\n",
602
+ "def create_patches(fname):\n",
603
+ " x = 0\n",
604
+ " y = 0\n",
605
+ " patches = []\n",
606
+ " img = cv2.imread(fname)\n",
607
+ " print(\"image shape:\",img.shape)\n",
608
+ " p_num = 0\n",
609
+ " while (y + patch_size < img.shape[0]):\n",
610
+ " \n",
611
+ " if (x + patch_size > img.shape[1]):\n",
612
+ " x = 0\n",
613
+ " y += patch_size\n",
614
+ " if y + patch_size <= img.shape[0] and x + patch_size <= img.shape[1]:\n",
615
+ " patches.append([x, y])\n",
616
+ " x += patch_size\n",
617
+ " print(\"total patches: \", len(patches))\n",
618
+ " return patches\n",
619
+ "\n",
620
+ "total_count = 0\n",
621
+ "for u in images_paths:\n",
622
+ " print(u)\n",
623
+ " patches = create_patches(u)\n",
624
+ " bgr = cv2.imread(u)\n",
625
+ " image_name = u.split('/')[-1].split('.')[0]\n",
626
+ " total_count += len(patches)\n",
627
+ "\n",
628
+ " for count, P in enumerate(patches):\n",
629
+ " cv2.imwrite(os.path.join(images_dir,f\"{image_name}_{count}.png\"), bgr[P[1]:P[1]+patch_size,P[0]:P[0]+patch_size])\n",
630
+ " \n",
631
+ "print(\"total image count:\", total_count)\n",
632
+ "\n",
633
+ "total_count = 0\n",
634
+ "for u in masks_paths:\n",
635
+ " print(u)\n",
636
+ " patches = create_patches(u)\n",
637
+ " bgr = cv2.imread(u)\n",
638
+ " image_name = u.split('/')[-1].split('.')[0]\n",
639
+ "\n",
640
+ " total_count += len(patches)\n",
641
+ "\n",
642
+ " for count, P in enumerate(patches):\n",
643
+ " cv2.imwrite(os.path.join(masks_dir,f\"{image_name}_{count}.png\"), bgr[P[1]:P[1]+patch_size,P[0]:P[0]+patch_size])\n",
644
+ " \n",
645
+ "print(\"total masks count:\", total_count)"
646
+ ]
647
+ },
648
+ {
649
+ "cell_type": "markdown",
650
+ "metadata": {},
651
+ "source": [
652
+ "This will saperate the rust and non rust patches and put them in and put them in directory destination/RustNonRustSplit"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 4,
658
+ "metadata": {},
659
+ "outputs": [],
660
+ "source": [
661
+ "destination = os.path.join(dest, \"RustNonRustSplit\")\n",
662
+ "root = patches_path\n",
663
+ "\n",
664
+ "os.makedirs(destination)\n",
665
+ "os.makedirs(os.path.join(destination,\"non_rust\",\"images\"))\n",
666
+ "os.makedirs(os.path.join(destination,\"non_rust\",\"masks\"))\n",
667
+ "os.makedirs(os.path.join(destination,\"rust\",\"images\"))\n",
668
+ "os.makedirs(os.path.join(destination,\"rust\",\"masks\"))\n",
669
+ "\n",
670
+ "masks_path = os.path.join(root, \"masks\", \"*.png\")\n",
671
+ "masks_paths = glob.glob(masks_path)\n",
672
+ "minimum=1000\n",
673
+ "min_patch=0\n",
674
+ "rust_count=0\n",
675
+ "non_rust_count=0\n",
676
+ "\n",
677
+ "for mask_path in masks_paths:\n",
678
+ " patch_name = mask_path.split(\"/\")[-1].split(\".\")[0]\n",
679
+ " \n",
680
+ " patch_mask = cv2.imread(mask_path, 0)\n",
681
+ " patch_img = cv2.imread(os.path.join(root, \"images\",patch_name+\".png\"))\n",
682
+ "\n",
683
+ " condition = (patch_mask > 150)\n",
684
+ " count = np.sum(condition)\n",
685
+ " \n",
686
+ " if count<=rust_threshold:\n",
687
+ " cv2.imwrite(os.path.join(destination,\"non_rust\",\"images\",f\"{patch_name}.png\"), patch_img)\n",
688
+ " cv2.imwrite(os.path.join(destination,\"non_rust\",\"masks\",f\"{patch_name}.png\"), patch_mask)\n",
689
+ " non_rust_count+=1\n",
690
+ " else:\n",
691
+ " if (count<=minimum):\n",
692
+ " minimum=count\n",
693
+ " min_patch = patch_name\n",
694
+ " cv2.imwrite(os.path.join(destination,\"rust\",\"images\",f\"{patch_name}.png\"), patch_img)\n",
695
+ " cv2.imwrite(os.path.join(destination,\"rust\",\"masks\",f\"{patch_name}.png\"), patch_mask)\n",
696
+ " rust_count+=1\n",
697
+ "\n",
698
+ "print(\"minimum rust patch:\",min_patch)\n",
699
+ "print(\"minimum rust patch white pixels:\",minimum)\n",
700
+ "print(\"rust count=\", rust_count)\n",
701
+ "print(\"non rust count=\", non_rust_count)"
702
+ ]
703
+ },
704
+ {
705
+ "cell_type": "markdown",
706
+ "metadata": {},
707
+ "source": [
708
+ "Run the next two code snippets for training only. The following code will augment the images in the destination/RustNonRustSplit/images and destination/RustNonRustSplit/masks folder. "
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": null,
714
+ "metadata": {},
715
+ "outputs": [],
716
+ "source": [
717
+ "#flip images horizontally\n",
718
+ "def flip_images_hor(input_image):\n",
719
+ " # Iterate over the images in the input directory\n",
720
+ " transform_hflip = T.RandomHorizontalFlip(p=1.0) # Set probability to 1.0 to always flip\n",
721
+ " return transform_hflip(input_image)\n",
722
+ "\n",
723
+ "#flip images vertically\n",
724
+ "def flip_images_ver(input_image):\n",
725
+ " # Iterate over the images in the input directory\n",
726
+ " transform_vflip = T.RandomVerticalFlip(p=1.0) # Set probability to 1.0 to always flip\n",
727
+ " return transform_vflip(input_image) \n",
728
+ " \n",
729
+ "def shear_vertical(input_image, shear_factor=45):\n",
730
+ " # Apply vertical shear\n",
731
+ " sheared_image = TF.affine(input_image, angle=0, translate=(0, 0), scale=1, shear=(0, shear_factor))\n",
732
+ " return sheared_image\n",
733
+ "\n",
734
+ "def shear_horizontal(input_image, shear_factor=45): # Increased shear for testing\n",
735
+ " sheared_image = TF.affine(input_image, angle=0, translate=(0, 0), scale=1, shear=(shear_factor, 0))\n",
736
+ " return sheared_image\n",
737
+ "\n",
738
+ "def rotate_images(input_image, angle=45):\n",
739
+ " # Convert PIL Image to NumPy array\n",
740
+ " input_array = np.array(input_image)\n",
741
+ " # Rotate the image\n",
742
+ " height, width = input_array.shape[:2]\n",
743
+ " rotation_matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)\n",
744
+ " rotated_array = cv2.warpAffine(input_array, rotation_matrix, (width, height))\n",
745
+ " # Convert NumPy array back to PIL Image\n",
746
+ " rotated_image = Image.fromarray(rotated_array)\n",
747
+ " return rotated_image\n",
748
+ "\n",
749
+ "def dark(input_image,gamma):\n",
750
+ " dark_image= TF.adjust_gamma(input_image, gamma)\n",
751
+ " return dark_image\n",
752
+ "\n",
753
+ "def augment_image(img_path):\n",
754
+ "\n",
755
+ " # Apply the transformations\n",
756
+ " \n",
757
+ " #orig_image\n",
758
+ " orig_img = Image.open(Path(img_path))\n",
759
+ " \n",
760
+ " #flip images\n",
761
+ " img_hflipped = flip_images_hor(orig_img)\n",
762
+ " img_vflipped = flip_images_ver(orig_img)\n",
763
+ " \n",
764
+ " \n",
765
+ " #shear images\n",
766
+ " hor_shear = shear_horizontal(orig_img)\n",
767
+ " ver_shear = shear_vertical(orig_img)\n",
768
+ " \n",
769
+ " #dark\n",
770
+ " img_dark = dark(img_hflipped, 2)\n",
771
+ " img_rot = rotate_images(orig_img, angle=45)\n",
772
+ " \n",
773
+ " return [img_dark,img_hflipped,img_vflipped,hor_shear,ver_shear, img_rot]\n",
774
+ "\n",
775
+ "def creating_file_with_augmented_images(file_path_master_dataset, file_path_augmented_images):\n",
776
+ " master_dataset_folder = file_path_master_dataset\n",
777
+ " files_in_master_dataset = os.listdir(file_path_master_dataset)\n",
778
+ " augmented_images_folder = file_path_augmented_images\n",
779
+ " \n",
780
+ " for image_name in files_in_master_dataset:\n",
781
+ " image_path = os.path.join(master_dataset_folder, image_name)\n",
782
+ " required_images = augment_image(image_path) # Assuming augment_image is defined elsewhere\n",
783
+ " i = 0\n",
784
+ " for augmented_image in required_images:\n",
785
+ " # Convert RGBA to RGB if necessary\n",
786
+ " if augmented_image.mode == 'RGBA':\n",
787
+ " augmented_image = augmented_image.convert('RGB')\n",
788
+ " \n",
789
+ " # Save as png\n",
790
+ " augmented_image_path = os.path.join(augmented_images_folder, f\"aug{i}_{image_name}\")\n",
791
+ " augmented_image.save(augmented_image_path, format='png')\n",
792
+ " i += 1\n",
793
+ "\n",
794
+ "master_dataset = os.path.join(destination,\"rust\",\"images\")\n",
795
+ "augmented_dataset = os.path.join(destination,\"rust\",\"images\")\n",
796
+ "creating_file_with_augmented_images(master_dataset,augmented_dataset)\n",
797
+ "\n",
798
+ "master_dataset = os.path.join(destination,\"rust\",\"masks\")\n",
799
+ "augmented_dataset = os.path.join(destination,\"rust\",\"masks\")\n",
800
+ "creating_file_with_augmented_images(master_dataset,augmented_dataset)"
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "markdown",
805
+ "metadata": {},
806
+ "source": [
807
+ "Run next snippet only for training dataset. To remove patches that have their rust removed becuase of their augmentations"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "execution_count": null,
813
+ "metadata": {},
814
+ "outputs": [
815
+ {
816
+ "name": "stdout",
817
+ "output_type": "stream",
818
+ "text": [
819
+ "minimum rust patch: 0\n",
820
+ "minimum rust patch white pixels: 1000\n",
821
+ "rust count= 0\n",
822
+ "non rust count= 0\n"
823
+ ]
824
+ }
825
+ ],
826
+ "source": [
827
+ "root = os.path.join(destination,\"rust\")\n",
828
+ "\n",
829
+ "non_rust_images_dir = os.path.join(destination,\"non_rust\",\"images\")\n",
830
+ "non_rust_masks_dir = os.path.join(destination,\"non_rust\",\"masks\")\n",
831
+ "\n",
832
+ "masks_path = os.path.join(root, \"masks\", \"*.png\")\n",
833
+ "masks_paths = glob.glob(masks_path)\n",
834
+ "minimum=1000\n",
835
+ "min_patch=0\n",
836
+ "rust_count=0\n",
837
+ "non_rust_count=0\n",
838
+ "\n",
839
+ "for mask_path in masks_paths:\n",
840
+ " patch_name = mask_path.split(\"/\")[-1].split(\".\")[0]\n",
841
+ " \n",
842
+ " patch_mask = cv2.imread(mask_path, 0)\n",
843
+ " patch_img = cv2.imread(os.path.join(root, \"images\",patch_name+\".png\"))\n",
844
+ "\n",
845
+ " condition = (patch_mask > 150)\n",
846
+ " count = np.sum(condition)\n",
847
+ " \n",
848
+ " if count<=rust_threshold:\n",
849
+ " os.remove(mask_path)\n",
850
+ " os.remove(os.path.join(root, \"images\",patch_name+\".png\"))\n",
851
+ "\n",
852
+ " cv2.imwrite(os.path.join(non_rust_images_dir,f\"{patch_name}.png\"), patch_img)\n",
853
+ " cv2.imwrite(os.path.join(non_rust_masks_dir,f\"{patch_name}.png\"), patch_mask)\n",
854
+ " non_rust_count+=1\n",
855
+ " else:\n",
856
+ " if (count<=minimum):\n",
857
+ " minimum=count\n",
858
+ " min_patch = patch_name\n",
859
+ " # cv2.imwrite(os.path.join(destination,\"rust\",\"images\",f\"{patch_name}.png\"), patch_img)\n",
860
+ " # cv2.imwrite(os.path.join(destination,\"rust\",\"masks\",f\"{patch_name}.png\"), patch_mask)\n",
861
+ " rust_count+=1\n",
862
+ "\n",
863
+ "print(\"minimum rust patch:\",min_patch)\n",
864
+ "print(\"minimum rust patch white pixels:\",minimum)\n",
865
+ "print(\"rust count=\", rust_count)\n",
866
+ "print(\"non rust count=\", non_rust_count)"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "markdown",
871
+ "metadata": {},
872
+ "source": [
873
+ "Create a dataset for classification model"
874
+ ]
875
+ },
876
+ {
877
+ "cell_type": "code",
878
+ "execution_count": null,
879
+ "metadata": {},
880
+ "outputs": [
881
+ {
882
+ "data": {
883
+ "text/plain": [
884
+ "'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDFprocessed\\\\train\\\\calssification\\\\non_rust'"
885
+ ]
886
+ },
887
+ "execution_count": 7,
888
+ "metadata": {},
889
+ "output_type": "execute_result"
890
+ }
891
+ ],
892
+ "source": [
893
+ "rust_images_dir = os.path.join(destination,\"rust\",\"images\")\n",
894
+ "non_rust_images_dir = os.path.join(destination,\"non_rust\",\"images\")\n",
895
+ "\n",
896
+ "rustClassificationDir = os.path.join(dest, \"calssification\", \"rust\")\n",
897
+ "nonRustClassificationDir = os.path.join(dest, \"calssification\", \"non_rust\")\n",
898
+ "os.makedirs(rustClassificationDir, exist_ok=True)\n",
899
+ "os.makedirs(nonRustClassificationDir, exist_ok=True)\n",
900
+ "\n",
901
+ "shutil.copytree(rust_images_dir,rustClassificationDir, dirs_exist_ok=True)\n",
902
+ "shutil.copytree(non_rust_images_dir,nonRustClassificationDir, dirs_exist_ok=True)\n"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "markdown",
907
+ "metadata": {},
908
+ "source": [
909
+ "Run the next code snippet for training dataset only. It deletes non-rust patches to match rust patches in the classification folder only."
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "code",
914
+ "execution_count": null,
915
+ "metadata": {},
916
+ "outputs": [],
917
+ "source": [
918
+ "import os\n",
919
+ "import glob\n",
920
+ "\n",
921
+ "def delete_extra_images(directory, target_count):\n",
922
+ " # Get a list of all image files in the directory\n",
923
+ " image_files = glob.glob(os.path.join(directory, '*.JPG')) + glob.glob(os.path.join(directory, '*.jpeg')) + glob.glob(os.path.join(directory, '*.png'))\n",
924
+ " \n",
925
+ " # Check if the number of images exceeds the target count\n",
926
+ " if len(image_files) > target_count:\n",
927
+ " # Calculate the number of images to delete\n",
928
+ " num_to_delete = len(image_files) - target_count\n",
929
+ " # Sort the images by modification time (oldest first)\n",
930
+ " image_files.sort(key=os.path.getmtime)\n",
931
+ " # Delete the extra images\n",
932
+ " for i in range(num_to_delete):\n",
933
+ " os.remove(image_files[i])\n",
934
+ " print(f\"{num_to_delete} images deleted.\")\n",
935
+ " elif len(image_files) < target_count:\n",
936
+ " print(\"Warning: Number of images in directory is less than the target count.\")\n",
937
+ "\n",
938
+ "if len(os.listdir(rustClassificationDir))< len(os.listdir(nonRustClassificationDir)):\n",
939
+ " delete_extra_images(nonRustClassificationDir, len(os.listdir(rustClassificationDir)))\n",
940
+ "else:\n",
941
+ " delete_extra_images(rustClassificationDir, len(os.listdir(nonRustClassificationDir)))\n"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "markdown",
946
+ "metadata": {},
947
+ "source": [
948
+ "The following code creates a coslaiency style structure for co-saliency models training"
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "code",
953
+ "execution_count": null,
954
+ "metadata": {},
955
+ "outputs": [],
956
+ "source": [
957
+ "rust_dir = os.path.join(destination,\"rust\")\n",
958
+ "rustCosaliencynDir = os.path.join(dest, \"cosaliency\")\n",
959
+ "shutil.copytree(rust_dir,rustCosaliencynDir, dirs_exist_ok=True)\n",
960
+ "\n",
961
+ "# Function to split images into folders based on image number\n",
962
+ "def split_images_into_folders(source_dir, destination_dir):\n",
963
+ " # Create destination directory if it doesn't exist\n",
964
+ " if not os.path.exists(destination_dir):\n",
965
+ " os.makedirs(destination_dir)\n",
966
+ " # Iterate through files in the source directory\n",
967
+ " for filename in os.listdir(source_dir):\n",
968
+ " if filename.endswith('.png'):\n",
969
+ " image_no = filename.split('_')[0] # Extract image number from filename\n",
970
+ " if not image_no.isdigit():\n",
971
+ " image_no = filename.split('_')[1]\n",
972
+ " destination_subdir = os.path.join(destination_dir, image_no)\n",
973
+ " # Create subdirectory if it doesn't exist\n",
974
+ " if not os.path.exists(destination_subdir):\n",
975
+ " os.makedirs(destination_subdir)\n",
976
+ " # Move the image file to the respective subdirectory\n",
977
+ " shutil.move(os.path.join(source_dir, filename), destination_subdir)\n",
978
+ "\n",
979
+ "def organize_images(main_directory):\n",
980
+ " # Ensure the main directory exists\n",
981
+ " if not os.path.exists(main_directory):\n",
982
+ " print(f\"The specified main directory '{main_directory}' does not exist.\")\n",
983
+ " return\n",
984
+ "\n",
985
+ " # Get a list of subdirectories in the main directory\n",
986
+ " subdirectories = [d for d in os.listdir(main_directory) if os.path.isdir(os.path.join(main_directory, d))]\n",
987
+ "\n",
988
+ " # Process each subdirectory\n",
989
+ " for subdir in subdirectories:\n",
990
+ " subdir_path = os.path.join(main_directory, subdir)\n",
991
+ "\n",
992
+ " # Get a list of images in the subdirectory\n",
993
+ " images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]\n",
994
+ " # Determine the number of images per subdirectory\n",
995
+ " images_per_subdir = 12\n",
996
+ " num_subdirectories = len(images) // images_per_subdir\n",
997
+ " n=0\n",
998
+ " # Create additional subdirectories if needed\n",
999
+ " for i in range(num_subdirectories - 1):\n",
1000
+ " new_subdir_name = f\"{subdir}_part{i + 1}\"\n",
1001
+ " new_subdir_path = os.path.join(main_directory, new_subdir_name)\n",
1002
+ "\n",
1003
+ " # Create the new subdirectory\n",
1004
+ " os.makedirs(new_subdir_path)\n",
1005
+ "\n",
1006
+ " # Move images to the new subdirectory\n",
1007
+ " for j in range(images_per_subdir):\n",
1008
+ " old_image_path = os.path.join(subdir_path, images[n])\n",
1009
+ " new_image_path = os.path.join(new_subdir_path, images[n])\n",
1010
+ " shutil.move(old_image_path, new_image_path)\n",
1011
+ " n+=1\n",
1012
+ "\n",
1013
+ "source_directory = os.path.join(dest, \"cosaliency\", \"images\")\n",
1014
+ "destination_directory = os.path.join(dest, \"cosaliency\", \"images\")\n",
1015
+ "split_images_into_folders(source_directory, destination_directory)\n",
1016
+ "organize_images(destination_directory)\n",
1017
+ "\n",
1018
+ "source_directory = os.path.join(dest, \"cosaliency\", \"masks\")\n",
1019
+ "destination_directory = os.path.join(dest, \"cosaliency\", \"masks\")\n",
1020
+ "split_images_into_folders(source_directory, destination_directory)\n",
1021
+ "organize_images(destination_directory)"
1022
+ ]
1023
+ }
1024
+ ],
1025
+ "metadata": {
1026
+ "kernelspec": {
1027
+ "display_name": "segformer",
1028
+ "language": "python",
1029
+ "name": "python3"
1030
+ },
1031
+ "language_info": {
1032
+ "codemirror_mode": {
1033
+ "name": "ipython",
1034
+ "version": 3
1035
+ },
1036
+ "file_extension": ".py",
1037
+ "mimetype": "text/x-python",
1038
+ "name": "python",
1039
+ "nbconvert_exporter": "python",
1040
+ "pygments_lexer": "ipython3",
1041
+ "version": "3.8.0"
1042
+ }
1043
+ },
1044
+ "nbformat": 4,
1045
+ "nbformat_minor": 2
1046
+ }
evaluator.cpython-37.pyc ADDED
Binary file (12.3 kB). View file
 
evaluator.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+
5
+ import numpy as np
6
+ from scipy.io import savemat
7
+ import torch
8
+ from torchvision import transforms
9
+
10
+ from PIL import ImageFile
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
12
+
13
+
14
+ class Eval_thread():
15
def __init__(self, loader, method='', dataset='', output_dir='', epoch='', cuda=True):
    """Bind the prediction/ground-truth loader and bookkeeping fields for one evaluation run."""
    self.loader = loader
    self.method = method
    self.dataset = dataset
    self.cuda = cuda
    self.output_dir = output_dir
    # Keep only the numeric part of an 'ep<N>' tag, e.g. 'ep12' -> '12'.
    self.epoch = epoch.split('ep')[-1]
    # Every run appends its summary line to this shared result file.
    self.logfile = os.path.join(output_dir, 'result.txt')
    # S_measures of GCoNet
    self.dataset2smeasure_bottom_bound = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
24
+
25
def run(self, AP=False, AUC=False, save_metrics=False, continue_eval=True):
    """Evaluate the loaded predictions, log a one-line summary, optionally save a .mat dump.

    Args:
        AP: also derive average precision from the precision/recall curves.
        AUC: also compute TPR/FPR and the area under the ROC curve.
        save_metrics: write all metrics to <output_dir>/<method>/<epoch>/<dataset>.mat.
        continue_eval: when False, skip evaluation entirely and report placeholder values.

    Returns:
        (info, continue_eval): the formatted summary string prefixed with the elapsed
        time, and the possibly-updated continue_eval flag — it turns False once the
        S-measure falls below this dataset's bottom bound.
    """
    Res = {}
    start_time = time.time()

    # Placeholder values used when evaluation is skipped or the S-measure is
    # below the dataset's bottom bound (identical to the originals' defaults).
    s = 0
    mae = 1
    Em = torch.zeros(255).cpu().numpy()
    max_e = 0
    mean_e = 0
    Fm, prec, recall = 0, 0, 0
    max_f = 0
    mean_f = 0

    if continue_eval:
        s = self.Eval_Smeasure()
        if s > self.dataset2smeasure_bottom_bound[self.dataset]:
            # Model is promising enough: run the remaining, more expensive metrics.
            mae = self.Eval_mae()
            Em = self.Eval_Emeasure()
            max_e = Em.max().item()
            mean_e = Em.mean().item()
            Em = Em.cpu().numpy()
            Fm, prec, recall = self.Eval_fmeasure()
            max_f = Fm.max().item()
            mean_f = Fm.mean().item()
            Fm = Fm.cpu().numpy()
        else:
            # Below the bound: keep placeholders and stop evaluating later epochs.
            continue_eval = False

    if AP:
        # BUG FIX: prec/recall are plain ints (0) when evaluation was skipped;
        # the original unconditionally called `.cpu()` here and crashed with
        # AttributeError. Only tensors carry real PR curves.
        if torch.is_tensor(prec):
            prec = prec.cpu().numpy()
            recall = recall.cpu().numpy()
            avg_p = self.Eval_AP(prec, recall)
        else:
            avg_p = 0  # evaluation skipped — report a placeholder AP

    if AUC:
        auc, TPR, FPR = self.Eval_auc()
        TPR = TPR.cpu().numpy()
        FPR = FPR.cpu().numpy()

    if save_metrics:
        save_dir = os.path.join(self.output_dir, self.method, self.epoch)
        os.makedirs(save_dir, exist_ok=True)
        # The placeholder defaults above equal the literal values the original
        # stored in its skipped branch, so the variables can be saved directly.
        Res['Sm'] = s
        Res['MAE'] = mae
        Res['MaxEm'] = max_e
        Res['MeanEm'] = mean_e
        Res['Em'] = Em
        Res['Fm'] = Fm

        if AP:
            Res['MaxFm'] = max_f
            Res['MeanFm'] = mean_f
            Res['AP'] = avg_p
            Res['Prec'] = prec
            Res['Recall'] = recall

        if AUC:
            Res['AUC'] = auc
            Res['TPR'] = TPR
            Res['FPR'] = FPR

        savemat(os.path.join(save_dir, self.dataset + '.mat'), Res)

    info = '{} ({}): {:.4f} max-Emeasure || {:.4f} S-measure || {:.4f} max-fm || {:.4f} mae || {:.4f} mean-Emeasure || {:.4f} mean-fm'.format(
        self.dataset, self.method + '-ep{}'.format(self.epoch), max_e, s, max_f, mae, mean_e, mean_f
    )
    if AP:
        info += ' || {:.4f} AP'.format(avg_p)
    if AUC:
        info += ' || {:.4f} AUC'.format(auc)
    info += '.'
    self.LOG(info + '\n')

    return '[cost:{:.4f}s] '.format(time.time() - start_time) + info, continue_eval
114
+
115
def Eval_mae(self):
    """Mean absolute error between predictions and ground truths.

    Averages |pred - gt| per sample over ``self.loader`` (which yields
    (pred, gt) image pairs), skipping NaN per-image results. Returns 1.0
    (the worst-case MAE used elsewhere in this evaluator) when no valid
    sample was seen — the original code raised ZeroDivisionError on an
    empty loader.
    """
    if self.epoch:
        print('Evaluating MAE...')
    avg_mae, img_num = 0.0, 0.0
    with torch.no_grad():
        trans = transforms.Compose([transforms.ToTensor()])
        for pred, gt in self.loader:
            pred, gt = trans(pred), trans(gt)
            if self.cuda:
                pred, gt = pred.cuda(), gt.cuda()
            mea = torch.abs(pred - gt).mean()
            if mea == mea:  # NaN != NaN: skip NaN per-image errors
                avg_mae += mea
                img_num += 1.0
    # Fix: guard against division by zero (empty loader / all-NaN case).
    if img_num == 0:
        return 1.0
    return (avg_mae / img_num).item()
134
+
135
def Eval_fmeasure(self):
    """Thresholded F-measure plus precision/recall curves.

    Returns (Fm, avg_prec, avg_recall), each a 255-point curve averaged
    over the loader; predictions are min-max normalised into [0, 1] first.
    """
    print('Evaluating FMeasure...')
    beta2 = 0.3
    f_sum, p_sum, r_sum, count = 0.0, 0.0, 0.0, 0.0
    to_tensor = transforms.Compose([transforms.ToTensor()])
    with torch.no_grad():
        for pred, gt in self.loader:
            pred, gt = to_tensor(pred), to_tensor(gt)
            if self.cuda:
                pred, gt = pred.cuda(), gt.cuda()
            # min-max normalise the prediction map into [0, 1]
            pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred) + 1e-20)
            prec, recall = self._eval_pr(pred, gt, 255)
            f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
            f_score[f_score != f_score] = 0  # 0/0 produces NaN: treat as 0
            f_sum += f_score
            p_sum += prec
            r_sum += recall
            count += 1.0
    return f_sum / count, p_sum / count, r_sum / count
164
+
165
def Eval_auc(self):
    """ROC AUC plus the averaged TPR/FPR curves over the loader."""
    print('Evaluating AUC...')
    tpr_sum, fpr_sum, count = 0.0, 0.0, 0.0
    to_tensor = transforms.Compose([transforms.ToTensor()])
    with torch.no_grad():
        for pred, gt in self.loader:
            pred, gt = to_tensor(pred), to_tensor(gt)
            if self.cuda:
                pred, gt = pred.cuda(), gt.cuda()
            pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred) + 1e-20)
            tpr, fpr = self._eval_roc(pred, gt, 255)
            tpr_sum += tpr
            fpr_sum += fpr
            count += 1.0
        avg_tpr = tpr_sum / count
        avg_fpr = fpr_sum / count
        # Sort by FPR so the trapezoidal integration is well defined.
        order = torch.argsort(avg_fpr)
        avg_tpr = avg_tpr[order]
        avg_fpr = avg_fpr[order]
        avg_auc = torch.trapz(avg_tpr, avg_fpr)
    return avg_auc.item(), avg_tpr, avg_fpr
196
+
197
def Eval_Emeasure(self):
    """E-measure curve (255 thresholds) averaged over the loader."""
    print('Evaluating EMeasure...')
    to_tensor = transforms.Compose([transforms.ToTensor()])
    with torch.no_grad():
        Em = torch.zeros(255).cuda() if self.cuda else torch.zeros(255)
        count = 0.0
        for pred, gt in self.loader:
            pred, gt = to_tensor(pred), to_tensor(gt)
            if self.cuda:
                pred, gt = pred.cuda(), gt.cuda()
            pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred) + 1e-20)
            Em += self._eval_e(pred, gt, 255)
            count += 1.0
        Em /= count
    return Em
221
+
222
def select_by_Smeasure(self, bar=0.9, loader_comp=None, bar_comp=0.1):
    """Select samples whose S-measure clears *bar* and beats a comparison
    method by at least *bar_comp*.

    Iterates ``self.loader`` and ``loader_comp`` in lockstep (they must
    yield the same samples in the same order), scores each prediction
    against its ground truth with the S-measure, and collects the paths of
    the predictions where Q > bar and Q - Q_comp > bar_comp.

    Returns (avg_q, good_pred_paths, good_comp_pred_paths, good_gt_paths).
    """
    print('Evaluating SMeasure...')
    good_ones = []
    good_ones_comp = []
    good_ones_gt = []
    alpha, avg_q, img_num = 0.5, 0.0, 0.0
    with torch.no_grad():
        trans = transforms.Compose([transforms.ToTensor()])
        for (pred, gt, predpath, gtpath), (pred_comp, gt_comp, predpath_comp) in zip(self.loader, loader_comp):
            # pred X gt
            if self.cuda:
                pred = trans(pred).cuda()
                # min-max normalise the prediction into [0, 1]
                pred = (pred - torch.min(pred)) / (torch.max(pred) -
                                                   torch.min(pred) + 1e-20)
                gt = trans(gt).cuda()
            else:
                pred = trans(pred)
                pred = (pred - torch.min(pred)) / (torch.max(pred) -
                                                   torch.min(pred) + 1e-20)
                gt = trans(gt)
            y = gt.mean()
            if y == 0:
                # all-background GT: reward low predictions
                x = pred.mean()
                Q = 1.0 - x
            elif y == 1:
                # all-foreground GT: reward high predictions
                x = pred.mean()
                Q = x
            else:
                # binarise the GT, then combine object- and region-aware scores
                gt[gt >= 0.5] = 1
                gt[gt < 0.5] = 0
                Q = alpha * self._S_object(
                    pred, gt) + (1 - alpha) * self._S_region(pred, gt)
                if Q.item() < 0:
                    Q = torch.FloatTensor([0.0])
            img_num += 1.0
            avg_q += Q.item()
            # pred_comp X gt — identical scoring for the comparison method
            if self.cuda:
                pred_comp = trans(pred_comp).cuda()
                pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) -
                                                                  torch.min(pred_comp) + 1e-20)
                gt_comp = trans(gt_comp).cuda()
            else:
                pred_comp = trans(pred_comp)
                pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) -
                                                                  torch.min(pred_comp) + 1e-20)
                gt_comp = trans(gt_comp)
            y = gt_comp.mean()
            if y == 0:
                x = pred_comp.mean()
                Q_comp = 1.0 - x
            elif y == 1:
                x = pred_comp.mean()
                Q_comp = x
            else:
                gt_comp[gt_comp >= 0.5] = 1
                gt_comp[gt_comp < 0.5] = 0
                Q_comp = alpha * self._S_object(
                    pred_comp, gt_comp) + (1 - alpha) * self._S_region(pred_comp, gt_comp)
                if Q_comp.item() < 0:
                    Q_comp = torch.FloatTensor([0.0])
            # keep the sample if our score clears the bar AND beats the
            # comparison method by the required margin
            if Q.item() > bar and (Q.item() - Q_comp.item()) > bar_comp:
                good_ones.append(predpath)
                good_ones_comp.append(predpath_comp)
                good_ones_gt.append(gtpath)
        avg_q /= img_num
    return avg_q, good_ones, good_ones_comp, good_ones_gt
289
+
290
def Eval_Smeasure(self):
    """Structure measure (S-measure) averaged over the loader.

    S = alpha * S_object + (1 - alpha) * S_region; the degenerate
    all-background / all-foreground ground truths are scored directly from
    the mean prediction.
    """
    print('Evaluating SMeasure...')
    alpha = 0.5
    total, count = 0.0, 0.0
    to_tensor = transforms.Compose([transforms.ToTensor()])
    with torch.no_grad():
        for pred, gt in self.loader:
            pred, gt = to_tensor(pred), to_tensor(gt)
            if self.cuda:
                pred, gt = pred.cuda(), gt.cuda()
            pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred) + 1e-20)
            y = gt.mean()
            if y == 0:
                # GT is all background: reward low predictions
                Q = 1.0 - pred.mean()
            elif y == 1:
                # GT is all foreground: reward high predictions
                Q = pred.mean()
            else:
                gt[gt >= 0.5] = 1
                gt[gt < 0.5] = 0
                Q = alpha * self._S_object(pred, gt) + (1 - alpha) * self._S_region(pred, gt)
                if Q.item() < 0:
                    Q = torch.FloatTensor([0.0])
            count += 1.0
            total += Q.item()
    return total / count
324
+
325
def LOG(self, output):
    """Append *output* to the run's log file, creating the output directory first."""
    os.makedirs(self.output_dir, exist_ok=True)
    with open(self.logfile, 'a') as log:
        log.write(output)
329
+
330
+ def _eval_e(self, y_pred, y, num):
331
+ if self.cuda:
332
+ score = torch.zeros(num).cuda()
333
+ thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
334
+ else:
335
+ score = torch.zeros(num)
336
+ thlist = torch.linspace(0, 1 - 1e-10, num)
337
+ for i in range(num):
338
+ y_pred_th = (y_pred >= thlist[i]).float()
339
+ fm = y_pred_th - y_pred_th.mean()
340
+ gt = y - y.mean()
341
+ align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
342
+ enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
343
+ score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
344
+ return score
345
+
346
+ def _eval_pr(self, y_pred, y, num):
347
+ if self.cuda:
348
+ prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
349
+ thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
350
+ else:
351
+ prec, recall = torch.zeros(num), torch.zeros(num)
352
+ thlist = torch.linspace(0, 1 - 1e-10, num)
353
+ for i in range(num):
354
+ y_temp = (y_pred >= thlist[i]).float()
355
+ tp = (y_temp * y).sum()
356
+ prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
357
+ return prec, recall
358
+
359
+ def _eval_roc(self, y_pred, y, num):
360
+ if self.cuda:
361
+ TPR, FPR = torch.zeros(num).cuda(), torch.zeros(num).cuda()
362
+ thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
363
+ else:
364
+ TPR, FPR = torch.zeros(num), torch.zeros(num)
365
+ thlist = torch.linspace(0, 1 - 1e-10, num)
366
+ for i in range(num):
367
+ y_temp = (y_pred >= thlist[i]).float()
368
+ tp = (y_temp * y).sum()
369
+ fp = (y_temp * (1 - y)).sum()
370
+ tn = ((1 - y_temp) * (1 - y)).sum()
371
+ fn = ((1 - y_temp) * y).sum()
372
+
373
+ TPR[i] = tp / (tp + fn + 1e-20)
374
+ FPR[i] = fp / (fp + tn + 1e-20)
375
+
376
+ return TPR, FPR
377
+
378
def _S_object(self, pred, gt):
    """Object-aware structural similarity.

    Scores the prediction inside the GT foreground and the inverted
    prediction inside the GT background, weighted by the foreground ratio.
    """
    fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
    bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
    u = gt.mean()
    return u * self._object(fg, gt) + (1 - u) * self._object(bg, 1 - gt)
386
+
387
+ def _object(self, pred, gt):
388
+ temp = pred[gt == 1]
389
+ x = temp.mean()
390
+ sigma_x = temp.std()
391
+ score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
392
+
393
+ return score
394
+
395
def _S_region(self, pred, gt):
    """Region-aware structural similarity.

    SSIM over the four quadrants around the GT centroid, each weighted by
    its relative area.
    """
    X, Y = self._centroid(gt)
    gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
    p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
    return (w1 * self._ssim(p1, gt1) + w2 * self._ssim(p2, gt2)
            + w3 * self._ssim(p3, gt3) + w4 * self._ssim(p4, gt4))
405
+
406
+ def _centroid(self, gt):
407
+ rows, cols = gt.size()[-2:]
408
+ gt = gt.view(rows, cols)
409
+ if gt.sum() == 0:
410
+ if self.cuda:
411
+ X = torch.eye(1).cuda() * round(cols / 2)
412
+ Y = torch.eye(1).cuda() * round(rows / 2)
413
+ else:
414
+ X = torch.eye(1) * round(cols / 2)
415
+ Y = torch.eye(1) * round(rows / 2)
416
+ else:
417
+ total = gt.sum()
418
+ if self.cuda:
419
+ i = torch.from_numpy(np.arange(0, cols)).cuda().float()
420
+ j = torch.from_numpy(np.arange(0, rows)).cuda().float()
421
+ else:
422
+ i = torch.from_numpy(np.arange(0, cols)).float()
423
+ j = torch.from_numpy(np.arange(0, rows)).float()
424
+ X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20)
425
+ Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20)
426
+ return X.long(), Y.long()
427
+
428
+ def _divideGT(self, gt, X, Y):
429
+ h, w = gt.size()[-2:]
430
+ area = h * w
431
+ gt = gt.view(h, w)
432
+ LT = gt[:Y, :X]
433
+ RT = gt[:Y, X:w]
434
+ LB = gt[Y:h, :X]
435
+ RB = gt[Y:h, X:w]
436
+ X = X.float()
437
+ Y = Y.float()
438
+ w1 = X * Y / area
439
+ w2 = (w - X) * Y / area
440
+ w3 = X * (h - Y) / area
441
+ w4 = 1 - w1 - w2 - w3
442
+ return LT, RT, LB, RB, w1, w2, w3, w4
443
+
444
+ def _dividePrediction(self, pred, X, Y):
445
+ h, w = pred.size()[-2:]
446
+ pred = pred.view(h, w)
447
+ LT = pred[:Y, :X]
448
+ RT = pred[:Y, X:w]
449
+ LB = pred[Y:h, :X]
450
+ RB = pred[Y:h, X:w]
451
+ return LT, RT, LB, RB
452
+
453
+ def _ssim(self, pred, gt):
454
+ gt = gt.float()
455
+ h, w = pred.size()[-2:]
456
+ N = h * w
457
+ x = pred.mean()
458
+ y = gt.mean()
459
+ sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
460
+ sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
461
+ sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
462
+
463
+ aplha = 4 * x * y * sigma_xy
464
+ beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
465
+
466
+ if aplha != 0:
467
+ Q = aplha / (beta + 1e-20)
468
+ elif aplha == 0 and beta == 0:
469
+ Q = 1.0
470
+ else:
471
+ Q = 0
472
+ return Q
473
+
474
def Eval_AP(self, prec, recall):
    """Average precision from the interpolated PR curve (VOC-style).

    Ref:
    https://github.com/facebookresearch/Detectron/blob/05d04d3a024f0991339de45872d02f2f50669b3d/lib/datasets/voc_eval.py#L54
    """
    print('Evaluating AP...')
    ap_r = np.concatenate(([0.], recall, [1.]))
    ap_p = np.concatenate(([0.], prec, [0.]))
    order = np.argsort(ap_r)
    ap_r, ap_p = ap_r[order], ap_p[order]
    # Make precision monotonically non-increasing from right to left.
    for i in range(ap_r.shape[0] - 1, 0, -1):
        ap_p[i - 1] = max(ap_p[i], ap_p[i - 1])
    # Integrate the step-wise curve where recall actually changes.
    idx = np.where(ap_r[1:] != ap_r[:-1])[0]
    return np.sum((ap_r[idx + 1] - ap_r[idx]) * ap_p[idx + 1])
hist_of_pixel_values.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt


def _plot_margin_hist(values, low_inclusive, title_prefix, out_name):
    """Histogram *values* and report the share of 'margin' pixels.

    Margin pixels are those > 230 plus those at the low end (<= 0 when
    *low_inclusive*, strictly < 0 otherwise). Saves the figure to *out_name*.
    """
    low_count = np.sum(values <= 0) if low_inclusive else np.sum(values < 0)
    margin_values_percent = (np.sum(values > 230) + low_count) / values.shape[0] * 100
    print('histing...')
    plt.figure()
    plt.hist(x=values)
    plt.title('{}, {:.1f} % are margin values'.format(title_prefix, margin_values_percent))
    plt.savefig(out_name)
    plt.show()


# Collect every pixel value of the first 'gconet_*' run's CoCA/Accordion maps.
root_dir = os.path.join([rd for rd in os.listdir('.') if 'gconet_' in rd][0], 'CoCA/Accordion')
image_paths = [os.path.join(root_dir, p) for p in os.listdir(root_dir)]
pixel_values = []
for image_path in image_paths:
    image = cv2.imread(image_path)
    pixel_values += image.flatten().squeeze().tolist()

pixel_values = np.array(pixel_values)

# Fix: the original triplicated this plotting logic; it is factored into one
# helper. Filenames and titles are preserved.
# All pixels (uint8 values are always >= 0), counting 0 as a margin value.
_plot_margin_hist(pixel_values[pixel_values >= 0], True,
                  '(0+>230)/all', 'hist_(0+>230)|all.png')
# All pixels, counting only values > 230 as margin values.
_plot_margin_hist(pixel_values[pixel_values >= 0], False,
                  '(230)/all', 'hist_(230)|all.png')
# Non-zero pixels only.
_plot_margin_hist(pixel_values[pixel_values > 0], True,
                  '(0+>230)/(all-0)', 'hist_(0+>230)|(all-0).png')
loss.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+ from torch.autograd import Variable
7
+
8
+
9
class IoU_loss(torch.nn.Module):
    """Soft IoU loss: mean over the batch of (1 - IoU(pred_i, target_i))."""

    def __init__(self):
        super(IoU_loss, self).__init__()

    def forward(self, pred, target):
        batch = pred.shape[0]
        total = 0.0
        for i in range(batch):
            # Soft intersection / union of the foreground for sample i.
            intersection = torch.sum(target[i] * pred[i])
            union = torch.sum(target[i]) + torch.sum(pred[i]) - intersection
            # eps keeps the ratio finite when both maps are empty.
            total = total + (1 - intersection / (union + 1e-5))
        return total / batch
26
+
27
+
28
class Scale_IoU(nn.Module):
    """Symmetric multi-scale IoU loss.

    For every prediction scale, adds the IoU loss of the sigmoid foreground
    against gt and of the complementary background against 1 - gt.
    """

    def __init__(self):
        super(Scale_IoU, self).__init__()
        self.iou = IoU_loss()

    def forward(self, scaled_preds, gt):
        total = 0
        for level in scaled_preds:
            prob = torch.sigmoid(level)
            # foreground IoU + background IoU so both regions are supervised
            total += self.iou(prob, gt) + self.iou(1 - prob, 1 - gt)
        return total
38
+
39
+
40
def compute_cos_dis(x_sup, x_que):
    """Cosine-similarity map between support and query feature maps.

    Both inputs are (B, C, H, W); the result is (B, HW_que, HW_sup),
    holding the cosine similarity of every query location with every
    support location.
    """
    B, C = x_sup.size(0), x_sup.size(1)
    sup = x_sup.view(B, C, -1)                                             # B, C, HWs
    que = x_que.view(B, x_que.size(1), -1)                                 # B, C, HWq
    que_norm = torch.norm(que, p=2, dim=1, keepdim=True).permute(0, 2, 1)  # B, HWq, 1
    sup_norm = torch.norm(sup, p=2, dim=1, keepdim=True)                   # B, 1, HWs
    norm_prod = torch.matmul(que_norm, sup_norm)                           # B, HWq, HWs
    dots = torch.matmul(que.permute(0, 2, 1), sup)                         # B, HWq, HWs
    # eps avoids division by zero for all-zero feature vectors
    return dots / (norm_prod + 1e-5)
55
+
56
+
main.cpython-37.pyc ADDED
Binary file (10.5 kB). View file
 
main.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
main.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from models.vgg import VGG_Backbone
5
+ from util import *
6
+
7
+
8
def weights_init(module):
    """Kaiming-initialise conv/linear weights; unit-initialise norm layers.

    Biases, where present, are zeroed in every case.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
21
+
22
+
23
class EnLayer(nn.Module):
    """Feature-refinement block: two 3x3 convs with a ReLU in between."""

    def __init__(self, in_channel=64):
        super(EnLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.enlayer(x)
35
+
36
+
37
class LatLayer(nn.Module):
    """Lateral connection: projects a backbone level to 64 channels via two 3x3 convs."""

    def __init__(self, in_channel):
        super(LatLayer, self).__init__()
        self.convlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.convlayer(x)
49
+
50
+
51
class DSLayer(nn.Module):
    """Deep-supervision head: two conv+ReLU stages, then a 1x1 conv to one logit map."""

    def __init__(self, in_channel=64):
        super(DSLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Raw logits: the sigmoid is applied by the loss, not here.
        self.predlayer = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1))

    def forward(self, x):
        return self.predlayer(self.enlayer(x))
67
+
68
+
69
class half_DSLayer(nn.Module):
    """Prediction head that bottlenecks channels to 1/4 before a 1x1 logit conv."""

    def __init__(self, in_channel=512):
        super(half_DSLayer, self).__init__()
        mid = int(in_channel / 4)
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, mid, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Raw logits: no sigmoid here.
        self.predlayer = nn.Sequential(
            nn.Conv2d(mid, 1, kernel_size=1))

    def forward(self, x):
        return self.predlayer(self.enlayer(x))
83
+
84
+
85
class AugAttentionModule(nn.Module):
    """Self-attention block with rank-based positive re-weighting.

    The raw attention logits are ranked per row; entries with positive
    logits get a cubic rank weight so strongly-correlated positions
    dominate, while negative-logit entries keep their plain softmax weight.
    The output is residual: attention(x) + x.
    """

    def __init__(self, input_channels=512):
        super(AugAttentionModule, self).__init__()
        self.query_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.key_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.value_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        B, C, H, W = x.size()
        x = self.conv(x)
        x_query = self.query_transform(x).view(B, C, -1).permute(0, 2, 1)  # B, HW, C
        x_key = self.key_transform(x).view(B, C, -1)                       # B, C, HW
        x_value = self.value_transform(x).view(B, C, -1).permute(0, 2, 1)  # B, HW, C
        attention_bmm = torch.bmm(x_query, x_key) * self.scale             # B, HW, HW
        attention = F.softmax(attention_bmm, dim=-1)
        # Double argsort yields, per row, the descending rank of each entry.
        attention_sort = torch.sort(attention_bmm, dim=-1, descending=True)[1]
        attention_sort = torch.sort(attention_sort, dim=-1)[1]
        # Fix: build the mask with ones_like (inherits x's device) instead of
        # the original hard-coded .cuda(), which broke CPU execution.
        attention_positive_num = torch.ones_like(attention)
        attention_positive_num[attention_bmm < 0] = 0
        att_pos_mask = attention_positive_num.clone()
        attention_positive_num = torch.sum(attention_positive_num, dim=-1, keepdim=True).expand_as(attention_sort)
        attention_sort_pos = attention_sort.float().clone()
        apn = attention_positive_num - 1
        attention_sort_pos[attention_sort > apn] = 0
        # Cubic rank weight for positive entries, identity for negatives.
        attention_mask = ((attention_sort_pos + 1) ** 3) * att_pos_mask + (1 - att_pos_mask)
        out = torch.bmm(attention * attention_mask, x_value)
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)
        return out + x
130
+
131
+
132
class AttLayer(nn.Module):
    """Democratic fusion layer (DCFM).

    For each image, finds the location that correlates most with the whole
    group (max over each other image, mean over the group), uses the
    normalised features there as a seed, builds a correlation map against
    the seeds, and re-weights the features with the map and the group
    prototype.
    """

    def __init__(self, input_channels=512):
        super(AttLayer, self).__init__()
        self.query_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)

    def correlation(self, x5, seeds):
        """Min-max-normalised correlation map of every location with the seeds."""
        B, C, H5, W5 = x5.size()
        if self.training:
            correlation_maps = F.conv2d(x5, weight=seeds)  # B,B,H,W
        else:
            # At inference, clamp negative correlations to zero.
            correlation_maps = torch.relu(F.conv2d(x5, weight=seeds))  # B,B,H,W
        correlation_maps = correlation_maps.mean(1).view(B, -1)
        min_value = torch.min(correlation_maps, dim=1, keepdim=True)[0]
        max_value = torch.max(correlation_maps, dim=1, keepdim=True)[0]
        correlation_maps = (correlation_maps - min_value) / (max_value - min_value + 1e-12)  # [B, HW]
        correlation_maps = correlation_maps.view(B, 1, H5, W5)  # [B, 1, H, W]
        return correlation_maps

    def forward(self, x5):
        # x5: B,C,H,W
        x5 = self.conv(x5) + x5
        B, C, H5, W5 = x5.size()
        x_query = self.query_transform(x5).view(B, C, -1)
        x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C)  # BHW, C
        x_key = self.key_transform(x5).view(B, C, -1)
        x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1)  # C, BHW
        # W = Q^T K: similarity of every location with every location of
        # every image in the group.
        x_w1 = torch.matmul(x_query, x_key) * self.scale  # BHW, BHW
        x_w = x_w1.view(B * H5 * W5, B, H5 * W5)
        x_w = torch.max(x_w, -1).values  # BHW, B: best match inside each image
        x_w = x_w.mean(-1)               # averaged over the group
        x_w = x_w.view(B, -1)            # B, HW
        x_w = F.softmax(x_w, dim=-1)     # B, HW
        norm0 = F.normalize(x5, dim=1)
        # One-hot mask of the most group-consistent location per image.
        # Fix: zeros_like inherits x_w's device; the original hard-coded
        # .cuda() broke CPU execution.
        x_w = x_w.unsqueeze(1)
        x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w)
        mask = torch.zeros_like(x_w)
        mask[x_w == x_w_max] = 1
        mask = mask.view(B, 1, H5, W5)
        seeds = norm0 * mask
        seeds = seeds.sum(3).sum(2).unsqueeze(2).unsqueeze(3)
        cormap = self.correlation(norm0, seeds)
        x51 = x5 * cormap
        proto1 = torch.mean(x51, (0, 2, 3), True)  # group prototype, (1, C, 1, 1)
        return x5, proto1, x5 * proto1 + x51, mask
188
+
189
+
190
class Decoder(nn.Module):
    """FPN-style decoder.

    Fuses the fused top-level feature with four lateral backbone levels
    top-down; every level emits a full-resolution (H, W) prediction map,
    returned deepest-first. The original four copy-pasted stages are
    factored into one per-level helper driven by a loop.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.toplayer = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0))
        self.latlayer4 = LatLayer(in_channel=512)
        self.latlayer3 = LatLayer(in_channel=256)
        self.latlayer2 = LatLayer(in_channel=128)
        self.latlayer1 = LatLayer(in_channel=64)

        self.enlayer4 = EnLayer()
        self.enlayer3 = EnLayer()
        self.enlayer2 = EnLayer()
        self.enlayer1 = EnLayer()

        self.dslayer4 = DSLayer()
        self.dslayer3 = DSLayer()
        self.dslayer2 = DSLayer()
        self.dslayer1 = DSLayer()

    def _upsample_add(self, x, y):
        """Bilinearly upsample *x* to *y*'s spatial size and add them."""
        [_, _, H, W] = y.size()
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        return x + y

    def _decode_level(self, p_above, lateral, enlayer, dslayer, H, W):
        """One top-down step: fuse with the lateral feature, refine, and
        emit a (H, W) prediction map."""
        p = enlayer(self._upsample_add(p_above, lateral))
        pred = F.interpolate(dslayer(p), size=(H, W), mode='bilinear', align_corners=False)
        return p, pred

    def forward(self, weighted_x5, x4, x3, x2, x1, H, W):
        preds = []
        p = self.toplayer(weighted_x5)
        stages = (
            (x4, self.latlayer4, self.enlayer4, self.dslayer4),
            (x3, self.latlayer3, self.enlayer3, self.dslayer3),
            (x2, self.latlayer2, self.enlayer2, self.dslayer2),
            (x1, self.latlayer1, self.enlayer1, self.dslayer1),
        )
        for feat, latlayer, enlayer, dslayer in stages:
            p, pred = self._decode_level(p, latlayer(feat), enlayer, dslayer, H, W)
            preds.append(pred)
        return preds
252
+
253
+
254
class DCFMNet(nn.Module):
    """DCFM network: VGG features, democratic fusion, attention augmentation
    and an FPN decoder.

    In training mode the forward pass additionally returns the prototypes of
    the full / GT-foreground / GT-background fused features.
    """

    def __init__(self, mode='train'):
        super(DCFMNet, self).__init__()
        self.gradients = None
        self.backbone = VGG_Backbone()
        self.mode = mode
        self.aug = AugAttentionModule()
        self.fusion = AttLayer(512)
        self.decoder = Decoder()

    def set_mode(self, mode):
        """Switch between 'train' and evaluation behaviour."""
        self.mode = mode

    def forward(self, x, gt):
        """Run the network; gradients are disabled outside 'train' mode."""
        if self.mode == 'train':
            return self._forward(x, gt)
        with torch.no_grad():
            return self._forward(x, gt)

    def featextract(self, x):
        """Return the five VGG stages, deepest first."""
        x1 = self.backbone.conv1(x)
        x2 = self.backbone.conv2(x1)
        x3 = self.backbone.conv3(x2)
        x4 = self.backbone.conv4(x3)
        x5 = self.backbone.conv5(x4)
        return x5, x4, x3, x2, x1

    def _forward(self, x, gt):
        B, _, H, W = x.size()
        x5, x4, x3, x2, x1 = self.featextract(x)
        feat, proto, weighted_x5, cormap = self.fusion(x5)
        preds = self.decoder(self.aug(weighted_x5), x4, x3, x2, x1, H, W)
        if not self.training:
            return preds
        # Prototypes of the GT-masked foreground / background features are
        # returned alongside the full prototype during training.
        gt = F.interpolate(gt, size=weighted_x5.size()[2:], mode='bilinear', align_corners=False)
        _, proto_pos, _, _ = self.fusion(x5 * gt)
        _, proto_neg, _, _ = self.fusion(x5 * (1 - gt))
        return preds, proto, proto_pos, proto_neg
298
+
299
+
300
class DCFM(nn.Module):
    """Thin wrapper around DCFMNet that seeds the RNGs and forwards the mode."""

    def __init__(self, mode='train'):
        super(DCFM, self).__init__()
        set_seed(123)  # reproducible weight initialisation
        self.dcfmnet = DCFMNet()
        self.mode = mode

    def set_mode(self, mode):
        """Propagate the train/test mode to the wrapped network."""
        self.mode = mode
        self.dcfmnet.set_mode(self.mode)

    def forward(self, x, gt):
        # Co-SOD forward pass, delegated entirely to DCFMNet.
        return self.dcfmnet(x, gt)
315
+
preprocessing.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ import torchvision.transforms as T
10
+ import torchvision.transforms.functional as TF
11
+ import numpy as np
12
+ from torchvision import transforms
13
+ import os
14
+ import cv2
15
+ import matplotlib.pyplot as plt
16
+ import shutil
17
+
18
# --- Dataset layout configuration -------------------------------------------
# source: raw NWRD training split; dest: root of all processed outputs.
source = "/scratch/wej36how/Datasets/NWRD/train"
dest = "/scratch/wej36how/Datasets/NWRDProcessed/train"
patch_size = 224            # square patch edge, in pixels
rust_threshold = 150        # NOTE(review): unused in this visible chunk — presumably a mask pixel threshold; confirm downstream
max_number_of_images_per_group = 12

patches_path = os.path.join(dest, "patches")
images_dir = os.path.join(patches_path, "images")
masks_dir = os.path.join(patches_path, "masks")

destination = os.path.join(dest, "RustNonRustSplit")
root = patches_path

rust_images_dir = os.path.join(destination,"rust","images")
non_rust_images_dir = os.path.join(destination,"non_rust","images")

# Classification split: copy rust / non-rust patch images into per-class
# folders. ("calssification" is a pre-existing typo in the path, kept because
# other scripts may already reference it.)
rustClassificationDir = os.path.join(dest, "calssification", "rust")
nonRustClassificationDir = os.path.join(dest, "calssification", "non_rust")
os.makedirs(rustClassificationDir, exist_ok=True)
os.makedirs(nonRustClassificationDir, exist_ok=True)

shutil.copytree(rust_images_dir,rustClassificationDir, dirs_exist_ok=True)
shutil.copytree(non_rust_images_dir,nonRustClassificationDir, dirs_exist_ok=True)
41
+
42
+ import os
43
+ import glob
44
+
45
def delete_extra_images(directory, target_count):
    """Trim *directory* down to *target_count* images, oldest first.

    Only files matching *.JPG, *.jpeg or *.png are considered (pattern
    matching is case-sensitive). Prints how many files were removed, or a
    warning when the directory already holds fewer images than the target.
    """
    patterns = ('*.JPG', '*.jpeg', '*.png')
    image_files = []
    for pattern in patterns:
        image_files += glob.glob(os.path.join(directory, pattern))

    surplus = len(image_files) - target_count
    if surplus > 0:
        # Oldest files (by modification time) are deleted first.
        image_files.sort(key=os.path.getmtime)
        for path in image_files[:surplus]:
            os.remove(path)
        print(f"{surplus} images deleted.")
    elif surplus < 0:
        print("Warning: Number of images in directory is less than the target count.")
62
# Balance the classification split: trim the larger class folder down to the
# size of the smaller one, so both classes contribute equally.
if len(os.listdir(rustClassificationDir))< len(os.listdir(nonRustClassificationDir)):
    delete_extra_images(nonRustClassificationDir, len(os.listdir(rustClassificationDir)))
else:
    delete_extra_images(rustClassificationDir, len(os.listdir(nonRustClassificationDir)))

# Build the co-saliency split from a copy of the rust data (images + masks).
rust_dir = os.path.join(destination,"rust")
rustCosaliencynDir = os.path.join(dest, "cosaliency")
shutil.copytree(rust_dir,rustCosaliencynDir, dirs_exist_ok=True)
70
+
71
def split_images_into_folders(source_dir, destination_dir):
    """Group loose .png patches into per-image subfolders of *destination_dir*.

    The target folder name is the first '_'-separated token of the file
    name, or the second token when the first is not purely numeric. Files
    are moved, not copied.
    """
    if not os.path.exists(destination_dir):
        os.makedirs(destination_dir)
    for filename in os.listdir(source_dir):
        if not filename.endswith('.png'):
            continue
        tokens = filename.split('_')
        # Prefer the leading token; fall back to the second when it is not a number.
        image_no = tokens[0] if tokens[0].isdigit() else tokens[1]
        target_dir = os.path.join(destination_dir, image_no)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        shutil.move(os.path.join(source_dir, filename), target_dir)
88
+
89
def organize_images(main_directory):
    # Split each subfolder of *main_directory* into groups of 12 images,
    # moving the surplus into sibling "<name>_partN" folders.
    #
    # NOTE(review): images_per_subdir hard-codes 12 instead of reusing
    # max_number_of_images_per_group defined above — confirm they are meant
    # to stay in sync. The last full group plus any remainder
    # (len(images) % 12) stays in the original subfolder.

    # Ensure the main directory exists
    if not os.path.exists(main_directory):
        print(f"The specified main directory '{main_directory}' does not exist.")
        return

    # Get a list of subdirectories in the main directory
    subdirectories = [d for d in os.listdir(main_directory) if os.path.isdir(os.path.join(main_directory, d))]

    # Process each subdirectory
    for subdir in subdirectories:
        subdir_path = os.path.join(main_directory, subdir)

        # Get a list of images in the subdirectory
        images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
        # Determine the number of images per subdirectory
        images_per_subdir = 12
        num_subdirectories = len(images) // images_per_subdir
        n = 0  # running index into *images*, shared across all part folders
        # Create additional subdirectories if needed
        for i in range(num_subdirectories - 1):
            new_subdir_name = f"{subdir}_part{i + 1}"
            new_subdir_path = os.path.join(main_directory, new_subdir_name)

            # Create the new subdirectory
            os.makedirs(new_subdir_path)

            # Move images to the new subdirectory
            for j in range(images_per_subdir):
                old_image_path = os.path.join(subdir_path, images[n])
                new_image_path = os.path.join(new_subdir_path, images[n])
                shutil.move(old_image_path, new_image_path)
                n += 1
122
+
123
# Regroup the co-saliency images. Source and destination are intentionally
# the same directory: loose patch files are first moved into subfolders named
# after their image number, then oversized folders are split into
# "<no>_partN" groups of 12.
source_directory = os.path.join(dest, "cosaliency", "images")
destination_directory = os.path.join(dest, "cosaliency", "images")
split_images_into_folders(source_directory, destination_directory)
organize_images(destination_directory)

# Apply the identical regrouping to the masks so image/mask folders stay aligned.
source_directory = os.path.join(dest, "cosaliency", "masks")
destination_directory = os.path.join(dest, "cosaliency", "masks")
split_images_into_folders(source_directory, destination_directory)
organize_images(destination_directory)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib==3.4.1
2
+ numpy==1.19.2
3
+ opencv_python==4.5.1.48
4
+ pandas==1.2.4
5
+ Pillow==9.1.0
6
+ pytorch_toolbelt==0.4.3
7
+ scikit_image==0.18.1
8
# skimage==0.0 removed: the PyPI "skimage" package is an empty 0.0 placeholder that fails to install; scikit_image above already provides the `skimage` module
9
+ torch==1.7.1
10
+ torchvision==0.2.2
11
+ tqdm==4.60.0
12
+ transformers
segmentation_metrics.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import cv2
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import jaccard_score

# Compare predicted saliency maps against ground-truth masks file-by-file and
# report average precision, recall, F1 and IoU over all matching filenames.

# Directories containing the prediction maps and ground truth masks
dir1 = '/home/wej36how/codes/CoSOD-main/result/Predictions/NWRDFRust_concatenated'
dir2 = '/home/wej36how/datasets/NWRDF/test/masks'

# Per-image scores, averaged at the end.
precisions = []
recalls = []
f1_scores = []
iou_scores = []

# Loop through all files in the prediction directory; only filenames present
# in BOTH directories are scored, everything else is silently skipped.
for filename in os.listdir(dir1):
    pred_path = os.path.join(dir1, filename)
    gt_path = os.path.join(dir2, filename)
    print(pred_path)
    if os.path.exists(pred_path) and os.path.exists(gt_path):
        # Load both maps as single-channel grayscale.
        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None (instead of raising) for unreadable or
        # non-image files; skip such pairs rather than crashing on .flatten().
        if pred_img is None or gt_img is None:
            print('Skipping unreadable pair: {}'.format(filename))
            continue

        # Flatten to 1-D so sklearn treats every pixel as a sample.
        pred_flat = pred_img.flatten()
        gt_flat = gt_img.flatten()

        # Binarize at the grayscale midpoint (assuming binary segmentation
        # masks stored as 0/255).
        pred_flat = (pred_flat > 127).astype(np.uint8)
        gt_flat = (gt_flat > 127).astype(np.uint8)

        # zero_division=0 yields the same 0.0 score sklearn already returned
        # for all-negative images, but without the UndefinedMetricWarning spam.
        precision = precision_score(gt_flat, pred_flat, zero_division=0)
        recall = recall_score(gt_flat, pred_flat, zero_division=0)
        f1 = f1_score(gt_flat, pred_flat, zero_division=0)
        iou = jaccard_score(gt_flat, pred_flat, zero_division=0)

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        iou_scores.append(iou)


# Average over all scored images (unweighted by image size).
avg_precision = np.mean(precisions)
avg_recall = np.mean(recalls)
avg_f1_score = np.mean(f1_scores)
avg_iou = np.mean(iou_scores)

# Print the results
print(f'Average Precision: {avg_precision:.4f}')
print(f'Average Recall: {avg_recall:.4f}')
print(f'Average F1 Score: {avg_f1_score:.4f}')
print(f'Average iou Score: {avg_iou:.4f}')
select_results.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import cv2
5
+
6
+ from evaluator import Eval_thread
7
+ from dataloader import EvalDataset
8
+
9
+ import sys
10
+ sys.path.append('..')
11
+
12
+
13
def main(cfg):
    """Select high-quality predictions by S-measure and save comparison strips.

    For each dataset in ``cfg.datasets`` (``+``-separated), predictions whose
    S-measure clears a threshold (and whose competitor score is low) are copied
    into ``good_ones/<dataset>/<category>/``, alongside a side-by-side strip of
    image | GT | our prediction | competitor prediction.
    """
    dataset_names = cfg.datasets.split('+')
    # Candidate prediction roots are directories in the CWD named 'gconet_*'.
    # NOTE(review): only the first match is used; raises IndexError if none
    # exist — presumably the caller guarantees at least one.
    root_dir_predictions = [dr for dr in os.listdir('.') if 'gconet_' in dr]
    # Competitor predictions are assumed to live next to the GT root, under
    # '/gconet' instead of '/gts'.
    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
    print('root_dir_predictions:', root_dir_predictions)
    root_dir_prediction = root_dir_predictions[0]
    root_dir_good_ones = 'good_ones'
    for dataset in dataset_names:
        dir_prediction = os.path.join(root_dir_prediction, dataset)
        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
        dir_gt = os.path.join(cfg.gt_dir, dataset)
        # Loader over (our prediction, GT) pairs; paths are returned so the
        # selected files can be copied below.
        loader = EvalDataset(
            dir_prediction, # preds
            dir_gt, # GT
            return_predpath=True,
            return_gtpath=True
        )
        # Same GT paired with the competitor's predictions.
        loader_comp = EvalDataset(
            dir_prediction_comp, # preds
            dir_gt, # GT
            return_predpath=True
        )
        print('Selecting predictions from {}'.format(dir_prediction))
        thread = Eval_thread(loader, cuda=cfg.cuda)
        # Keep samples where our S-measure > 0.95 while the competitor's is
        # below 0.2 (i.e. we succeed where the competitor fails).
        s_measure, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(bar=0.95, loader_comp=loader_comp, bar_comp=0.2)
        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
        os.makedirs(dir_good_ones, exist_ok=True)
        print('have good_ones {}'.format(len(good_ones)))
        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
            # Paths look like .../<category>/<file>.png — mirror the category
            # folder under good_ones/.
            dir_category = os.path.join(dir_good_ones, good_one.split('/')[-2])
            os.makedirs(dir_category, exist_ok=True)
            save_path = os.path.join(dir_category, good_one.split('/')[-1])
            sal_map = cv2.imread(good_one)
            sal_map_gt = cv2.imread(good_one_gt)
            sal_map_comp = cv2.imread(good_one_comp)
            # The source image lives under '/images' with a .jpg extension.
            image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg')
            image = cv2.imread(image_path)
            cv2.imwrite(save_path, sal_map)
            # 10-px gray separator between panels of the comparison strip.
            split_line = np.zeros((sal_map.shape[0], 10, 3)).astype(sal_map.dtype) + 127
            # Horizontal strip: image | GT | ours | competitor.
            comp = cv2.hconcat([image, split_line, sal_map_gt, split_line, sal_map, split_line, sal_map_comp])
            # '<name>_comp.<ext>' next to the plain saved prediction.
            save_path_comp = ''.join((save_path[:-4], '_comp', save_path[-4:]))
            cv2.imwrite(save_path_comp, comp)
55
+
56
+
57
if __name__ == "__main__":
    # Command-line entry point for the prediction-selection script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
    # BUGFIX: `type=bool` is an argparse trap — bool("False") is True, so any
    # explicit value (including "--cuda False") enabled CUDA. Parse the string
    # ourselves; the default (True) and accepted syntax are unchanged, but
    # "False"/"0"/"no" now actually disable CUDA.
    parser.add_argument('--cuda',
                        type=lambda s: s.lower() not in ('false', '0', 'no'),
                        default=True)
    config = parser.parse_args()
    main(config)
+ main(config)
sort_results.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np


# When True, the best checkpoint's prediction folder is moved from the shared
# predictions root into the current directory at the end of the script.
move_best_results_here = False

# Column layout of each parsed result row; `measurement` selects the metric
# that drives the "best checkpoint" ranking below.
record = ['dataset', 'ckpt', 'Emax', 'Smeasure', 'Fmax', 'MAE', 'Emean', 'Fmean']
measurement = 'Emax'
score_idx = record.index(measurement)

# Raw evaluation log written by the evaluation step.
with open('output/details/result.txt', 'r') as f:
    res = f.read()

# Strip decorative separators so each line splits cleanly on whitespace.
res = res.replace('||', '').replace('(', '').replace(')', '')

score = []
for r in res.splitlines():
    ds = r.split()
    s = ds[:2]  # dataset name and checkpoint identifier
    # After the two name fields, tokens alternate value / label text; keep
    # every other token (the numeric values).
    for idx_d, d in enumerate(ds[2:]):
        if idx_d % 2 == 0:
            s.append(float(d))
    score.append(s)

# Sort descending by dataset name, then Emax, Smeasure, Fmax, ckpt — so the
# first row per dataset is that dataset's best-scoring checkpoint.
ss = sorted(score, key=lambda x: (x[record.index('dataset')], x[record.index('Emax')], x[record.index('Smeasure')], x[record.index('Fmax')], x[record.index('ckpt')]), reverse=True)
ss_ar = np.array(ss)
np.savetxt('score_sorted.txt', ss_ar, fmt='%s')
# Best checkpoint id per dataset (first row after the descending sort).
ckpt_coca = ss_ar[ss_ar[:, 0] == 'CoCA'][0][1]
ckpt_cosod = ss_ar[ss_ar[:, 0] == 'CoSOD3k'][0][1]
ckpt_cosal = ss_ar[ss_ar[:, 0] == 'CoSal2015'][0][1]

# All rows (i.e. all three datasets) for each per-dataset best checkpoint.
best_coca_scores = ss_ar[ss_ar[:, 1] == ckpt_coca]
best_cosod_scores = ss_ar[ss_ar[:, 1] == ckpt_cosod]
best_cosal_scores = ss_ar[ss_ar[:, 1] == ckpt_cosal]
print('Best (models may be different):')
print('CoCA:\n', best_coca_scores)
print('CoSOD3k:\n', best_cosod_scores)
print('CoSal2015:\n', best_cosal_scores)

# Overal relative Emax improvement on three datasets
# Baseline (GCoNet) scores to measure relative improvement against, keyed by
# dataset. Only the Emax branch also carries the Smeasure baselines; the
# other branches would leave `gco_scores_Smeasure` undefined (NameError in
# the loop below) — presumably only measurement == 'Emax' is ever used.
if measurement == 'Emax':
    gco_scores = {'CoCA': 0.760, 'CoSOD3k': 0.860, 'CoSal2015': 0.887}
    gco_scores_Smeasure = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
elif measurement == 'Smeasure':
    gco_scores = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
elif measurement == 'Fmax':
    gco_scores = {'CoCA': 0.544, 'CoSOD3k': 0.777, 'CoSal2015': 0.847}
elif measurement == 'Emean':
    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
elif measurement == 'Fmean':
    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
# Unique checkpoint ids present in the log.
ckpts = list(set(ss_ar[:, 1].squeeze().tolist()))
improvements_mean = []
improvements_lst = []
improvements_mean_Smeasure = []
improvements_lst_Smeasure = []
for ckpt in ckpts:
    scores = ss_ar[ss_ar[:, 1] == ckpt]
    # Checkpoints not evaluated on all three datasets get sentinel scores so
    # they can never win the argsort below.
    # NOTE(review): `[-1, -1, 1]` looks like a typo for `[-1, -1, -1]` —
    # confirm; it only affects the printed per-dataset numbers for a
    # sentinel entry, not the ranking.
    if scores.shape[0] != len(gco_scores):
        improvements_mean.append(-1)
        improvements_lst.append([-1, -1, 1])
        improvements_mean_Smeasure.append(-1)
        improvements_lst_Smeasure.append([-1, -1, 1])
        continue
    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][score_idx])
    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][score_idx])
    score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][score_idx])
    # Relative (not absolute) improvement over the baseline, per dataset.
    improvements = [
        (score_coca - gco_scores['CoCA']) / gco_scores['CoCA'],
        (score_cosod - gco_scores['CoSOD3k']) / gco_scores['CoSOD3k'],
        (score_cosal - gco_scores['CoSal2015']) / gco_scores['CoSal2015']
    ]
    improvement_mean = np.mean(improvements)
    improvements_mean.append(improvement_mean)
    improvements_lst.append(improvements)

    # Smeasure
    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][record.index('Smeasure')])
    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][record.index('Smeasure')])
    score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][record.index('Smeasure')])
    improvements_Smeasure = [
        (score_coca - gco_scores_Smeasure['CoCA']) / gco_scores_Smeasure['CoCA'],
        (score_cosod - gco_scores_Smeasure['CoSOD3k']) / gco_scores_Smeasure['CoSOD3k'],
        (score_cosal - gco_scores_Smeasure['CoSal2015']) / gco_scores_Smeasure['CoSal2015']
    ]
    improvement_mean_Smeasure = np.mean(improvements_Smeasure)
    improvements_mean_Smeasure.append(improvement_mean_Smeasure)
    improvements_lst_Smeasure.append(improvements_Smeasure)
# Pick the overall winner by mean relative improvement, ranked either by
# Emax or by Smeasure depending on `best_measurement`.
best_measurement = 'Emax'
if best_measurement == 'Emax':
    best_improvement_index = np.argsort(improvements_mean).tolist()[-1]
    best_ckpt = ckpts[best_improvement_index]
    best_improvement_mean = improvements_mean[best_improvement_index]
    best_improvements = improvements_lst[best_improvement_index]

    best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index]
    best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index]
elif best_measurement == 'Smeasure':
    best_improvement_index = np.argsort(improvements_mean_Smeasure).tolist()[-1]
    best_ckpt = ckpts[best_improvement_index]
    best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index]
    best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index]

    best_improvement_mean = improvements_mean[best_improvement_index]
    best_improvements = improvements_lst[best_improvement_index]

print('The overall best one:')
print(ss_ar[ss_ar[:, 1] == best_ckpt])
print('Got Emax improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format(
    best_improvements[0]*100, best_improvements[1]*100, best_improvements[2]*100, best_improvement_mean*100
))
print('Got Smes improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format(
    best_improvements_Smeasure[0]*100, best_improvements_Smeasure[1]*100, best_improvements_Smeasure[2]*100, best_improvement_mean_Smeasure*100
))
# Checkpoint ids look like 'gconet_<trial>-...ep<epoch>:...'; pull the trial
# and epoch numbers back out of the string.
trial = int(best_ckpt.split('_')[-1].split('-')[0])
ep = int(best_ckpt.split('ep')[-1].split(':')[0])
if move_best_results_here:
    trial, ep = 'gconet_{}'.format(trial), 'ep{}'.format(ep)
    dr = os.path.join(trial, ep)
    dst = '-'.join((trial, ep))
    shutil.move(os.path.join('/root/datasets/sod/preds', dr), dst)


# model_indices = sorted([fname.split('_')[-1] for fname in os.listdir('output/details') if 'gconet_' in fname])
# emax = {}
# for model_idx in model_indices:
#     m = 'gconet_{}-'.format(model_idx)
#     if m not in list(emax.keys()):
#         emax[m] = []
#     for s in score:
#         if m in s[1]:
#             ep = int(s[1].split('ep')[-1].rstrip('):'))
#             emax[m].append([ep, s[2], s[0]])

# for m, e in emax.items():
#     plot_name = m[:-1]
#     print('Saving {} ...'.format(plot_name))
#     e = np.array(e)
#     e_coca = e[e[:, -1] == 'CoCA']
#     e_cosod = e[e[:, -1] == 'CoSOD3k']
#     e_cosal = e[e[:, -1] == 'CoSal2015']
#     eps = sorted(list(set(e_coca[:, 0].astype(float))))

#     e_coca = np.array(sorted(e_coca, key=lambda x: int(x[0])))[:, 1].astype(float)
#     e_cosod = np.array(sorted(e_cosod, key=lambda x: int(x[0])))[:, 1].astype(float)
#     e_cosal = np.array(sorted(e_cosal, key=lambda x: int(x[0])))[:, 1].astype(float)

#     plt.figure()
#     plt.plot(eps, e_coca)
#     plt.plot(eps, e_cosod)
#     plt.plot(eps, e_cosal)
#     plt.legend(['CoCA', 'CoSOD3k', 'CoSal2015'])
#     plt.title(m)
#     plt.savefig('{}.png'.format(plot_name))
test.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from PIL import Image
2
+ from dataset import get_loader
3
+ import torch
4
+ from torchvision import transforms
5
+ # from util import save_tensor_img, Logger
6
+ from tqdm import tqdm
7
+ from torch import nn
8
+ import os
9
+ from models.main import *
10
+ import argparse
11
+ # import numpy as np
12
+ # import cv2
13
+ # from skimage import img_as_ubyte
14
+
15
+
16
def main(args):
    """Run DCFM inference on the configured test set(s) and save predictions.

    Loads a trained DCFM checkpoint, runs each image group through the model,
    upsamples each sigmoid prediction back to its original size and writes it
    under ``args.save_root/<testset>/<group>/``.
    """
    # BUGFIX: the module-level `from util import save_tensor_img, Logger` is
    # commented out, so `save_tensor_img` below raised NameError. Import it
    # locally here to keep this fix self-contained.
    from util import save_tensor_img

    # Init model
    device = torch.device("cuda")
    model = DCFM()
    model = model.to(device)
    try:
        # Preferred: a specific best-S-measure checkpoint.
        # modelname = os.path.join(args.param_root, 'best_ep198_Smeasure0.7019.pth')
        modelname = "/scratch/wej36how/codes/DCFM-master/best_ep12_Smeasure0.7256.pth"
        dcfmnet_dict = torch.load(modelname)
        print('loaded', modelname)
    except Exception:
        # Fall back to the default checkpoint name in the parameter folder.
        # (Narrowed from a bare `except:` so KeyboardInterrupt etc. still work.)
        dcfmnet_dict = torch.load(os.path.join(args.param_root, 'dcfm.pth'))

    model.to(device)
    model.dcfmnet.load_state_dict(dcfmnet_dict)
    model.eval()
    model.set_mode('test')

    tensor2pil = transforms.ToPILImage()
    for testset in ['NWRD']:
        # Per-testset image/GT roots and output folder.
        if testset == 'CoCA':
            test_img_path = './data/images/CoCA/'
            test_gt_path = './data/gts/CoCA/'
            saved_root = os.path.join(args.save_root, 'CoCA')
        elif testset == 'CoSOD3k':
            test_img_path = './data/images/CoSOD3k/'
            test_gt_path = './data/gts/CoSOD3k/'
            saved_root = os.path.join(args.save_root, 'CoSOD3k')
        elif testset == 'CoSal2015':
            test_img_path = './data/images/CoSal2015/'
            test_gt_path = './data/gts/CoSal2015/'
            saved_root = os.path.join(args.save_root, 'CoSal2015')
        elif testset == 'NWRD':
            test_img_path = '/home/wej36how/codes/crossvit/results/nwrd22/images/'
            test_gt_path = '/home/wej36how/codes/crossvit/results/nwrd22/images/'
            saved_root = os.path.join(args.save_root, 'NWRD')
        else:
            print('Unkonwn test dataset')
            # BUGFIX: `args` has no `dataset` attribute (the parser defines
            # --size/--param_root/--save_root only), so the original
            # `print(args.dataset)` raised AttributeError. Print the actual
            # unknown name instead.
            print(testset)

        # Batch size 1 group per step; istrain=False keeps original sizes.
        test_loader = get_loader(
            test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)

        for batch in tqdm(test_loader):
            inputs = batch[0].to(device).squeeze(0)
            gts = batch[1].to(device).squeeze(0)
            subpaths = batch[2]
            ori_sizes = batch[3]
            scaled_preds = model(inputs, gts)
            # Final-stage logits -> probabilities.
            scaled_preds = torch.sigmoid(scaled_preds[-1])
            os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
            num = gts.shape[0]
            for inum in range(num):
                subpath = subpaths[inum][0]
                ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
                # Resize each prediction back to the source image resolution.
                res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear', align_corners=True)
                save_tensor_img(res, os.path.join(saved_root, subpath))
74
+
75
+
76
if __name__ == '__main__':
    # Command-line interface for the inference script.
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--size', default=224, type=int, help='input size')
    cli.add_argument('--param_root', default='/data1/dcfm/temp', type=str, help='model folder')
    cli.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')

    cli_args = cli.parse_args()
    main(cli_args)
89
+
90
+
91
+
train.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from util import Logger, AverageMeter, save_checkpoint, save_tensor_img, set_seed
5
+ import os
6
+ import numpy as np
7
+ from matplotlib import pyplot as plt
8
+ import time
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from dataset import get_loader
12
+ from loss import *
13
+ from config import Config
14
+ from evaluation.dataloader import EvalDataset
15
+ from evaluation.evaluator import Eval_thread
16
+
17
+
18
+ from models.main import *
19
+
20
+ import torch.nn.functional as F
21
+ import pytorch_toolbelt.losses as PTL
22
+
23
# Module-level training setup: CLI parsing, data loaders, model, optimizer
# and loss are all built at import time and used as globals by main()/train()
# /validate() below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Parameter from command line
parser = argparse.ArgumentParser(description='')

parser.add_argument('--loss',
                    default='Scale_IoU',
                    type=str,
                    help="Options: '', ''")
parser.add_argument('--bs', '--batch_size', default=1, type=int)
parser.add_argument('--lr',
                    '--learning_rate',
                    default=1e-4,
                    type=float,
                    help='Initial learning rate')
parser.add_argument('--resume',
                    default=None,
                    type=str,
                    help='path to latest checkpoint')
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--start_epoch',
                    default=0,
                    type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--trainset',
                    default='CoCo',
                    type=str,
                    help="Options: 'CoCo'")
parser.add_argument('--testsets',
                    default='CoCA',
                    type=str,
                    help="Options: 'CoCA','CoSal2015','CoSOD3k','iCoseg','MSRC'")
parser.add_argument('--size',
                    default=224,
                    type=int,
                    help='input size')
parser.add_argument('--tmp', default='/data1/dcfm/temp', help='Temporary folder')
parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')

args = parser.parse_args()
config = Config()

# Prepare dataset
if args.trainset == 'CoCo':
    train_img_path = './data/CoCo/img/'
    train_gt_path = './data/CoCo/gt/'
    # max_num caps the group size per co-saliency batch.
    # NOTE(review): shuffle=False for a *training* loader is unusual —
    # presumably shuffling happens inside get_loader; confirm.
    train_loader = get_loader(train_img_path,
                              train_gt_path,
                              args.size,
                              args.bs,
                              max_num=16, #20,
                              istrain=True,
                              shuffle=False,
                              num_workers=8, #4,
                              pin=True)

else:
    print('Unkonwn train dataset')
    # NOTE(review): `args` has no `dataset` attribute — this line would raise
    # AttributeError if an unknown trainset were ever passed.
    print(args.dataset)

# Build the validation loader (only 'CoCA' is iterated here, so `test_loader`
# ends up bound to the CoCA loader used by validate()).
for testset in ['CoCA']:
    if testset == 'CoCA':
        test_img_path = './data/images/CoCA/'
        test_gt_path = './data/gts/CoCA/'

        saved_root = os.path.join(args.save_root, 'CoCA')
    elif testset == 'CoSOD3k':
        test_img_path = './data/images/CoSOD3k/'
        test_gt_path = './data/gts/CoSOD3k/'
        saved_root = os.path.join(args.save_root, 'CoSOD3k')
    elif testset == 'CoSal2015':
        test_img_path = './data/images/CoSal2015/'
        test_gt_path = './data/gts/CoSal2015/'
        saved_root = os.path.join(args.save_root, 'CoSal2015')
    elif testset == 'CoCo':
        test_img_path = './data/images/CoCo/'
        test_gt_path = './data/gts/CoCo/'
        saved_root = os.path.join(args.save_root, 'CoCo')
    else:
        print('Unkonwn test dataset')
        print(args.dataset)

    test_loader = get_loader(
        test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)

# make dir for tmp
os.makedirs(args.tmp, exist_ok=True)

# Init log file
logger = Logger(os.path.join(args.tmp, "log.txt"))
set_seed(123)

# Init model
device = torch.device("cuda")

model = DCFM()
model = model.to(device)
model.apply(weights_init)

# Load ImageNet-pretrained VGG16 weights into the backbone.
model.dcfmnet.backbone._initialize_weights(torch.load('./models/vgg16-397923af.pth'))

# Split parameters into backbone vs. the rest so the pretrained backbone can
# train with a 10x smaller learning rate.
backbone_params = list(map(id, model.dcfmnet.backbone.parameters()))
base_params = filter(lambda p: id(p) not in backbone_params,
                     model.dcfmnet.parameters())

all_params = [{'params': base_params}, {'params': model.dcfmnet.backbone.parameters(), 'lr': args.lr*0.1}]

# Setting optimizer
optimizer = optim.Adam(params=all_params,lr=args.lr, weight_decay=1e-4, betas=[0.9, 0.99])

# Freeze the whole backbone except its last conv block (conv5_3).
for key, value in model.named_parameters():
    if 'dcfmnet.backbone' in key and 'dcfmnet.backbone.conv5.conv5_3' not in key:
        value.requires_grad = False

# Print the trainable/frozen status of every parameter for inspection.
for key, value in model.named_parameters():
    print(key, value.requires_grad)

# log model and optimizer pars
logger.info("Model details:")
logger.info(model)
logger.info("Optimizer details:")
logger.info(optimizer)
logger.info("Scheduler details:")
# logger.info(scheduler)
logger.info("Other hyperparameters:")
logger.info(args)

# Setting Loss
# Instantiate the loss class named by --loss from the loss module.
# NOTE(review): exec/eval on a CLI string — fine for a research script, but
# any unexpected --loss value executes arbitrary code.
exec('from loss import ' + args.loss)
IOUloss = eval(args.loss+'()')
152
+
153
+
154
def main():
    """Top-level training loop: optional resume, per-epoch train + validate,
    checkpointing, and retention of the best-S-measure weights."""
    val_measures = []
    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.dcfmnet.load_state_dict(checkpoint['state_dict'])
            # NOTE(review): the checkpoints saved below contain only
            # 'epoch' and 'state_dict' — resuming from one of those would
            # KeyError on 'optimizer' here; confirm the intended source.
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    print(args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train(epoch)
        if config.validation:
            # Validate on the first test set; measures[0] is the S-measure.
            measures = validate(model, test_loader, args.testsets)
            val_measures.append(measures)
            print(
                'Validation: S_measure on CoCA for epoch-{} is {:.4f}. Best epoch is epoch-{} with S_measure {:.4f}'.format(
                    epoch, measures[0], np.argmax(np.array(val_measures)[:, 0].squeeze()),
                    np.max(np.array(val_measures)[:, 0]))
            )
        # Save checkpoint
        # Rolling "latest" checkpoint (epoch + weights only).
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.dcfmnet.state_dict(),
                #'scheduler': scheduler.state_dict(),
            },
            path=args.tmp)
        if config.validation:
            # If this epoch is the best so far, replace any previous
            # 'best_*' weight file with the new one.
            if np.max(np.array(val_measures)[:, 0].squeeze()) == measures[0]:
                best_weights_before = [os.path.join(args.tmp, weight_file) for weight_file in
                                       os.listdir(args.tmp) if 'best_' in weight_file]
                for best_weight_before in best_weights_before:
                    os.remove(best_weight_before)
                torch.save(model.dcfmnet.state_dict(),
                           os.path.join(args.tmp, 'best_ep{}_Smeasure{:.4f}.pth'.format(epoch, measures[0])))
        # Periodic snapshots every 10 epochs (and after the first epoch).
        if (epoch + 1) % 10 == 0 or epoch == 0:
            torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')

        # Keep every epoch near the end of the default 200-epoch schedule.
        if epoch > 188:
            torch.save(model.dcfmnet.state_dict(), args.tmp+'/model-' + str(epoch + 1) + '.pt')
    #dcfmnet_dict = model.dcfmnet.state_dict()
    #torch.save(dcfmnet_dict, os.path.join(args.tmp, 'final.pth'))
203
+
204
def sclloss(x, xt, xb):
    """Contrastive prototype loss: rewards cosine similarity between `x` and
    the target prototype `xt` while penalizing similarity to the background
    prototype `xb`. Returns the summed loss over all elements."""
    eps = 1e-5
    # Map cosine similarity from [-1, 1] into [0, 1] for both pairs.
    sim_target = (1 + compute_cos_dis(x, xt)) * 0.5
    sim_background = (1 + compute_cos_dis(x, xb)) * 0.5
    # High target similarity and low background similarity both lower the loss.
    per_elem = -torch.log(sim_target + eps) - torch.log(1 - sim_background + eps)
    return per_elem.sum()
209
+
210
def train(epoch):
    """Train the global `model` for one epoch over `train_loader`.

    Combines the configured IoU loss with 0.1x the prototype contrastive
    loss (sclloss). Returns the summed IoU loss over the epoch.
    """
    # Switch to train mode
    model.train()
    model.set_mode('train')
    loss_sum = 0.0
    loss_sumkl = 0.0  # NOTE(review): accumulated nowhere — appears unused.
    for batch_idx, batch in enumerate(train_loader):
        # Batch dim 0 is the group; squeeze it so the group acts as the batch.
        inputs = batch[0].to(device).squeeze(0)
        gts = batch[1].to(device).squeeze(0)
        # Forward pass returns the prediction plus three prototype tensors
        # (sample / ground-truth / background) for the contrastive term.
        pred, proto, protogt, protobg = model(inputs, gts)
        loss_iou = IOUloss(pred, gts)
        loss_scl = sclloss(proto, protogt, protobg)
        # Contrastive term weighted at 0.1 relative to the IoU loss.
        loss = loss_iou+0.1*loss_scl
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Only the IoU component is tracked for the epoch total.
        loss_sum = loss_sum + loss_iou.detach().item()

        if batch_idx % 20 == 0:
            logger.info('Epoch[{0}/{1}] Iter[{2}/{3}] '
                        'Train Loss: loss_iou: {4:.3f}, loss_scl: {5:.3f} '.format(
                            epoch,
                            args.epochs,
                            batch_idx,
                            len(train_loader),
                            loss_iou,
                            loss_scl,
                        ))
    # NOTE(review): loss_mean is computed but the *sum* is returned —
    # presumably intentional since callers only log it; confirm.
    loss_mean = loss_sum / len(train_loader)
    return loss_sum
240
+
241
+
242
def validate(model, test_loaders, testsets):
    """Run inference on the first validation set, save the predictions, and
    return a one-element list with the resulting S-measure.

    Args:
        model: The DCFM wrapper in eval mode capable of `model(inputs, gts)`.
        test_loaders: The validation data loader to iterate.
        testsets: '+'-separated test-set names; only the first is validated.

    Returns:
        list[float]: `[s_measure]` for the first test set.
    """
    model.eval()

    testsets = testsets.split('+')
    measures = []
    # Only the first test set is validated each epoch (cheap validation).
    for testset in testsets[:1]:
        print('Validating {}...'.format(testset))

        saved_root = os.path.join(args.save_root, testset)

        # BUGFIX: iterate the loader passed in by the caller — the original
        # ignored this parameter and read the module-level `test_loader`
        # global (which happened to be the same object for the one caller).
        for batch in test_loaders:
            inputs = batch[0].to(device).squeeze(0)
            gts = batch[1].to(device).squeeze(0)
            subpaths = batch[2]
            ori_sizes = batch[3]
            with torch.no_grad():
                # Final-stage prediction -> probabilities.
                scaled_preds = model(inputs, gts)[-1].sigmoid()

            os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)

            num = len(scaled_preds)
            for inum in range(num):
                subpath = subpaths[inum][0]
                ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
                # Resize each prediction back to its source resolution.
                res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear',
                                                align_corners=True)
                save_tensor_img(res, os.path.join(saved_root, subpath))

        # Evaluate the saved predictions against the ground truth on disk.
        eval_loader = EvalDataset(
            saved_root,  # preds
            os.path.join('./data/gts', testset)  # GT
        )
        evaler = Eval_thread(eval_loader, cuda=True)
        # Use S_measure for validation
        s_measure = evaler.Eval_Smeasure()
        # The trailing `and 0` deliberately disables the expensive extra
        # measures; remove it to re-enable them above the threshold.
        if s_measure > config.val_measures['Smeasure']['CoCA'] and 0:
            # TODO: evluate others measures if s_measure is very high.
            e_max = evaler.Eval_Emeasure().max().item()
            f_max = evaler.Eval_fmeasure().max().item()
            # BUGFIX: the original spec '{:4.f}' is invalid and would raise
            # ValueError the moment this branch ran; '{:.4f}' is intended.
            print('Emax: {:.4f}, Fmax: {:.4f}'.format(e_max, f_max))
        measures.append(s_measure)

    # Restore training mode for the caller's next epoch.
    model.train()
    return measures
287
+
288
# Entry point: all setup above runs at import time; main() runs the loop.
if __name__ == '__main__':
    main()
train_wandb.ipynb ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/wej36how/.conda/envs/vit/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import torch\n",
19
+ "from torch.utils.data import DataLoader\n",
20
+ "from transformers import AdamW, ViTImageProcessor, ViTForImageClassification\n",
21
+ "from NWRD_dataset import NWRD\n",
22
+ "from tqdm import tqdm\n",
23
+ "import numpy as np\n",
24
+ "import torch.nn.functional as F\n",
25
+ "import os\n",
26
+ "import torch.optim as optim\n",
27
+ "from torchvision import transforms\n"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "seed = 42\n",
37
+ "torch.manual_seed(seed)\n",
38
+ "np.random.seed(seed)\n",
39
+ "# If you are using CUDA, set this for further deterministic behavior\n",
40
+ "if torch.cuda.is_available():\n",
41
+ " torch.cuda.manual_seed(seed)\n",
42
+ " torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.\n",
43
+ " # Below settings are recommended for deterministic behavior when using specific convolution operations,\n",
44
+ " # but may reduce performance\n",
45
+ " torch.backends.cudnn.deterministic = True\n",
46
+ " torch.backends.cudnn.benchmark = False"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "cpu\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
64
+ "CUDA_LAUNCH_BLOCKING=1\n",
65
+ "TORCH_USE_CUDA_DSA=1\n",
66
+ "print(device)"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 4,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "transformations = transforms.Compose([\n",
76
+ " transforms.Resize((224, 224)), # Resize the image to 224x224\n",
77
+ " transforms.ToTensor() # Convert the image to a PyTorch tensor\n",
78
+ "])"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 5,
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "ename": "FileNotFoundError",
88
+ "evalue": "[Errno 2] No such file or directory: 'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification/rust'",
89
+ "output_type": "error",
90
+ "traceback": [
91
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
92
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
93
+ "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_ds \u001b[38;5;241m=\u001b[39m \u001b[43mNWRD\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mC:\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mUsers\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mhasee\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mDesktop\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mGermany_2024\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mDataset\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mNWRDprocessed\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mcalssification\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformations\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m val_ds \u001b[38;5;241m=\u001b[39m 
NWRD(root_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC:\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mUsers\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mhasee\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mDesktop\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mGermany_2024\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mNWRDprocessed\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mval\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mcalssification\u001b[39m\u001b[38;5;124m\"\u001b[39m, train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, transform\u001b[38;5;241m=\u001b[39mtransformations)\n\u001b[1;32m 4\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(train_ds, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m8\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
94
+ "File \u001b[0;32m~/codes/crossvit/NWRD_dataset.py:12\u001b[0m, in \u001b[0;36mNWRD.__init__\u001b[0;34m(self, root_dir, transform, train)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimages \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabels \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
95
+ "File \u001b[0;32m~/codes/crossvit/NWRD_dataset.py:19\u001b[0m, in \u001b[0;36mNWRD.load_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 16\u001b[0m non_rust_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mroot_dir, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon_rust\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# Load rust images\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m filename \u001b[38;5;129;01min\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlistdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrust_dir\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 20\u001b[0m filepath \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(rust_dir, filename)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimages\u001b[38;5;241m.\u001b[39mappend(filepath)\n",
96
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification/rust'"
97
+ ]
98
+ }
99
+ ],
100
+ "source": [
101
+ "train_ds = NWRD(root_dir=\"C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification\", train=True, transform=transformations)\n",
102
+ "val_ds = NWRD(root_dir=\"C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\val\\\\calssification\", train=False, transform=transformations)\n",
103
+ " \n",
104
+ "train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)\n",
105
+ "val_loader = DataLoader(val_ds, batch_size=8, shuffle=True)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 6,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "mean = [0.485, 0.456, 0.406] # Mean values for RGB channels\n",
115
+ "std = [0.229, 0.224, 0.225] # Standard deviation values for RGB channels\n",
116
+ "#processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224',transform={'mean': mean, 'std': std})\n",
117
+ "processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')\n",
118
+ "model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')\n",
119
+ "# processor.image_mean=mean\n",
120
+ "# processor.image_std=std\n",
121
+ "#print(processor)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 8,
127
+ "metadata": {},
128
+ "outputs": [
129
+ {
130
+ "data": {
131
+ "text/plain": [
132
+ "ViTForImageClassification(\n",
133
+ " (vit): ViTModel(\n",
134
+ " (embeddings): ViTEmbeddings(\n",
135
+ " (patch_embeddings): ViTPatchEmbeddings(\n",
136
+ " (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
137
+ " )\n",
138
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
139
+ " )\n",
140
+ " (encoder): ViTEncoder(\n",
141
+ " (layer): ModuleList(\n",
142
+ " (0-11): 12 x ViTLayer(\n",
143
+ " (attention): ViTSdpaAttention(\n",
144
+ " (attention): ViTSdpaSelfAttention(\n",
145
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
146
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
147
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
148
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
149
+ " )\n",
150
+ " (output): ViTSelfOutput(\n",
151
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
152
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
153
+ " )\n",
154
+ " )\n",
155
+ " (intermediate): ViTIntermediate(\n",
156
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
157
+ " (intermediate_act_fn): GELUActivation()\n",
158
+ " )\n",
159
+ " (output): ViTOutput(\n",
160
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
161
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
162
+ " )\n",
163
+ " (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
164
+ " (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
165
+ " )\n",
166
+ " )\n",
167
+ " )\n",
168
+ " (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
169
+ " )\n",
170
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
171
+ ")"
172
+ ]
173
+ },
174
+ "execution_count": 8,
175
+ "metadata": {},
176
+ "output_type": "execute_result"
177
+ }
178
+ ],
179
+ "source": [
180
+ "model.classifier = torch.nn.Linear(model.config.hidden_size, 2)\n",
181
+ "model.to(device)"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "markdown",
186
+ "metadata": {},
187
+ "source": [
188
+ "Finetuning of the model based on pretraining weights."
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 8,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "# model_weights = torch.load('/home/Hirra/coding_files/crossvit/weights/wandb_vit_base_final_med_val_NWRD_epoch_50_lr_0.000000001_wd_0.001_batch_size_8_unaugmented_unequlaized/49.pth')\n",
198
+ "# model.load_state_dict(model_weights.state_dict())"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 9,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "optimizer = optim.SGD(model.parameters(), lr=0.00000003, weight_decay=0.001)\n",
208
+ "criterion = torch.nn.CrossEntropyLoss()\n",
209
+ "weights_directory = 'wandb_vit_base_final_for_time_NWRD_epoch_50_lr_0.000000003_wd_0.001_batch_size_8_unaugmented_training'\n",
210
+ "weight_loc = f\"weights/{weights_directory}\"\n",
211
+ "os.makedirs(weight_loc, exist_ok=True)"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 10,
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "name": "stderr",
221
+ "output_type": "stream",
222
+ "text": [
223
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
224
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mgptautomated\u001b[0m (\u001b[33mtukl_labwork\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
225
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
226
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
227
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: C:\\Users\\hasee\\.netrc\n"
228
+ ]
229
+ },
230
+ {
231
+ "data": {
232
+ "text/plain": [
233
+ "True"
234
+ ]
235
+ },
236
+ "execution_count": 10,
237
+ "metadata": {},
238
+ "output_type": "execute_result"
239
+ }
240
+ ],
241
+ "source": [
242
+ "import wandb, os\n",
243
+ "#wandb.login()\n",
244
+    "wandb.login(key=os.environ[\"WANDB_API_KEY\"])  # SECURITY: API key read from env; never commit keys\n",
245
+    "# NOTE(review): a hardcoded wandb API key was previously committed here and is now public — revoke it."
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 11,
251
+ "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "name": "stdout",
255
+ "output_type": "stream",
256
+ "text": [
257
+ "env: WANDB_PROJECT=crossvit_rust_classifier_new\n"
258
+ ]
259
+ }
260
+ ],
261
+ "source": [
262
+ "%env WANDB_PROJECT=crossvit_rust_classifier_new\n",
263
+ "os.environ[\"WANDB_PROJECT\"] = \"<crossvit>\"\n",
264
+ "os.environ[\"WANDB_REPORT_TO\"] = \"wandb\""
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": 12,
270
+ "metadata": {},
271
+ "outputs": [
272
+ {
273
+ "data": {
274
+ "text/html": [
275
+ "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
276
+ ],
277
+ "text/plain": [
278
+ "<IPython.core.display.HTML object>"
279
+ ]
280
+ },
281
+ "metadata": {},
282
+ "output_type": "display_data"
283
+ },
284
+ {
285
+ "data": {
286
+ "text/html": [
287
+ "wandb version 0.17.3 is available! To upgrade, please run:\n",
288
+ " $ pip install wandb --upgrade"
289
+ ],
290
+ "text/plain": [
291
+ "<IPython.core.display.HTML object>"
292
+ ]
293
+ },
294
+ "metadata": {},
295
+ "output_type": "display_data"
296
+ },
297
+ {
298
+ "data": {
299
+ "text/html": [
300
+ "Tracking run with wandb version 0.17.2"
301
+ ],
302
+ "text/plain": [
303
+ "<IPython.core.display.HTML object>"
304
+ ]
305
+ },
306
+ "metadata": {},
307
+ "output_type": "display_data"
308
+ },
309
+ {
310
+ "data": {
311
+ "text/html": [
312
+ "Run data is saved locally in <code>c:\\Users\\hasee\\Desktop\\Germany_2024\\codes\\crossvit\\wandb\\run-20240626_161631-bgtm3oyt</code>"
313
+ ],
314
+ "text/plain": [
315
+ "<IPython.core.display.HTML object>"
316
+ ]
317
+ },
318
+ "metadata": {},
319
+ "output_type": "display_data"
320
+ },
321
+ {
322
+ "data": {
323
+ "text/html": [
324
+ "Syncing run <strong><a href='https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt' target=\"_blank\">glamorous-wood-74</a></strong> to <a href='https://wandb.ai/tukl_labwork/uncategorized' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
325
+ ],
326
+ "text/plain": [
327
+ "<IPython.core.display.HTML object>"
328
+ ]
329
+ },
330
+ "metadata": {},
331
+ "output_type": "display_data"
332
+ },
333
+ {
334
+ "data": {
335
+ "text/html": [
336
+ " View project at <a href='https://wandb.ai/tukl_labwork/uncategorized' target=\"_blank\">https://wandb.ai/tukl_labwork/uncategorized</a>"
337
+ ],
338
+ "text/plain": [
339
+ "<IPython.core.display.HTML object>"
340
+ ]
341
+ },
342
+ "metadata": {},
343
+ "output_type": "display_data"
344
+ },
345
+ {
346
+ "data": {
347
+ "text/html": [
348
+ " View run at <a href='https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt' target=\"_blank\">https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt</a>"
349
+ ],
350
+ "text/plain": [
351
+ "<IPython.core.display.HTML object>"
352
+ ]
353
+ },
354
+ "metadata": {},
355
+ "output_type": "display_data"
356
+ },
357
+ {
358
+ "name": "stderr",
359
+ "output_type": "stream",
360
+ "text": [
361
+ " 0%| | 0/241 [00:00<?, ?it/s]c:\\Users\\hasee\\miniconda3\\envs\\segformer\\Lib\\site-packages\\transformers\\models\\vit\\modeling_vit.py:253: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:455.)\n",
362
+ " context_layer = torch.nn.functional.scaled_dot_product_attention(\n",
363
+ "Epoch 0 train Loss 0.6551: 21%|██ | 51/241 [00:27<01:42, 1.85it/s]\n"
364
+ ]
365
+ },
366
+ {
367
+ "ename": "KeyboardInterrupt",
368
+ "evalue": "",
369
+ "output_type": "error",
370
+ "traceback": [
371
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
372
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
373
+ "Cell \u001b[1;32mIn[12], line 22\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[38;5;66;03m# print(\"logits\", logits)\u001b[39;00m\n\u001b[0;32m 18\u001b[0m \u001b[38;5;66;03m# print(\"prediction\", predication)\u001b[39;00m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;66;03m# print(\"labels\", labels)\u001b[39;00m\n\u001b[0;32m 21\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(logits, labels)\n\u001b[1;32m---> 22\u001b[0m train_losses\u001b[38;5;241m.\u001b[39mappend(\u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 23\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m 24\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
374
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
375
+ ]
376
+ }
377
+ ],
378
+ "source": [
379
+ "wandb.init()\n",
380
+ "\n",
381
+ "best_epoch = {}\n",
382
+ "train_losses = []\n",
383
+ "for epoch in range(50):\n",
384
+ " model.train\n",
385
+ " train_losses=[]\n",
386
+ " loop = tqdm(enumerate(train_loader), total=len(train_loader))\n",
387
+ " for batch_idx, (images, labels) in loop:\n",
388
+ " inputs = processor(images=images, return_tensors=\"pt\", do_rescale=False).to(device)\n",
389
+ " labels = labels.to(device)\n",
390
+ "\n",
391
+ " outputs = model(**inputs)\n",
392
+ " logits = outputs.logits\n",
393
+ " predication = logits.argmax(axis=1)\n",
394
+ " \n",
395
+ " # print(\"logits\", logits)\n",
396
+ " # print(\"prediction\", predication)\n",
397
+ " # print(\"labels\", labels)\n",
398
+ " \n",
399
+ " loss = criterion(logits, labels)\n",
400
+ " train_losses.append(loss.item())\n",
401
+ " loss.backward()\n",
402
+ " optimizer.step()\n",
403
+ " loop.set_description(f\"Epoch {epoch} train Loss {np.mean(train_losses):.4f}\")\n",
404
+ "\n",
405
+ "\n",
406
+ " print(\"Epoch \"+str(epoch)+\" Train Loss \"+str(np.mean(train_losses)))\n",
407
+ " torch.save(model, weight_loc+'/{}.pth'.format(epoch))\n",
408
+ " wandb.log({\"train_loss\": np.mean(train_losses), \"epoch\": epoch})\n",
409
+ "\n",
410
+ " #validation\n",
411
+ " optimizer.zero_grad()\n",
412
+ " model.eval\n",
413
+ " val_losses=[]\n",
414
+ "\n",
415
+ " loop = tqdm(enumerate(val_loader), total=len(val_loader))\n",
416
+ " with torch.no_grad():\n",
417
+ " for batch_idx, (images, labels) in loop:\n",
418
+ " inputs = processor(images=images, return_tensors=\"pt\", do_rescale=False).to(device)\n",
419
+ " labels = labels.to(device)\n",
420
+ "\n",
421
+ " outputs = model(**inputs)\n",
422
+ " logits = outputs.logits\n",
423
+ " \n",
424
+ " loss = criterion(logits, labels)\n",
425
+ " val_losses.append(loss.item())\n",
426
+ "\n",
427
+ " predication = logits.argmax(axis=1)\n",
428
+ "\n",
429
+ " loss = criterion(logits, labels)\n",
430
+ " val_losses.append(loss.item())\n",
431
+ " \n",
432
+ " loop.set_description(f\"Epoch {epoch} Val Loss {np.mean(val_losses):.4f}\")\n",
433
+ " wandb.log({\"val_loss\": np.mean(val_losses), \"epoch\": epoch})\n",
434
+ "torch.cuda.empty_cache()\n"
435
+ ]
436
+ }
437
+ ],
438
+ "metadata": {
439
+ "kernelspec": {
440
+ "display_name": "crossvit",
441
+ "language": "python",
442
+ "name": "python3"
443
+ },
444
+ "language_info": {
445
+ "codemirror_mode": {
446
+ "name": "ipython",
447
+ "version": 3
448
+ },
449
+ "file_extension": ".py",
450
+ "mimetype": "text/x-python",
451
+ "name": "python",
452
+ "nbconvert_exporter": "python",
453
+ "pygments_lexer": "ipython3",
454
+ "version": "3.8.19"
455
+ }
456
+ },
457
+ "nbformat": 4,
458
+ "nbformat_minor": 2
459
+ }
util.cpython-38.pyc ADDED
Binary file (3.5 kB). View file
 
util.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import torch
4
+ import shutil
5
+ from torchvision import transforms
6
+ import numpy as np
7
+ import random
8
+ import cv2
9
+
10
+
11
class Logger():
    """Thin wrapper around the shared 'DCFM' logger that mirrors
    INFO-level messages to both stdout and a log file."""

    _FORMAT = '%(asctime)s %(levelname)s %(message)s'

    def __init__(self, path="log.txt"):
        formatter = logging.Formatter(self._FORMAT)
        self.file_handler = logging.FileHandler(path, "w")
        self.file_handler.setFormatter(formatter)
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(formatter)
        self.logger = logging.getLogger('DCFM')
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False  # keep messages out of the root logger

    def info(self, txt):
        """Log *txt* at INFO level to both sinks."""
        self.logger.info(txt)

    def close(self):
        """Close both handlers (they remain attached to the logger)."""
        self.file_handler.close()
        self.stdout_handler.close()
29
+
30
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
46
+
47
+
48
def save_checkpoint(state, path, filename="checkpoint.pth"):
    """Serialize *state* with ``torch.save`` into ``path``/``filename``."""
    target = os.path.join(path, filename)
    torch.save(state, target)
50
+
51
+
52
def save_tensor_img(tenor_im, path):
    """Write an image tensor to *path* as an image file.

    The tensor is moved to CPU, its leading (batch) dimension is dropped,
    and the result goes through torchvision's ToPILImage before saving.
    Assumes a (1, C, H, W) input — TODO confirm with callers.
    """
    cpu_copy = tenor_im.cpu().clone().squeeze(0)
    pil_image = transforms.ToPILImage()(cpu_copy)
    pil_image.save(path)
58
+
59
+
60
def save_tensor_merge(tenor_im, tensor_mask, path, colormap='HOT'):
    # Overlay a colormapped mask on top of an image tensor and write the
    # blend to ``path`` via OpenCV.
    #
    # tenor_im: image tensor, presumably (1, 3, H, W) RGB — TODO confirm
    # tensor_mask: mask tensor, presumably (1, 1, H, W) — TODO confirm
    # colormap: 'HOT', 'PINK', or 'BONE'. NOTE(review): any other value
    #   skips colormapping, leaving the mask single-channel, which would
    #   make cv2.addWeighted below fail on mismatched shapes.
    im = tenor_im.cpu().detach().clone()
    im = im.squeeze(0).numpy()
    # Min-max normalize to [0, 255]; 1e-20 guards against a constant image.
    im = ((im - np.min(im)) / (np.max(im) - np.min(im) + 1e-20)) * 255
    im = np.array(im,np.uint8)
    mask = tensor_mask.cpu().detach().clone()
    mask = mask.squeeze(0).numpy()
    mask = ((mask - np.min(mask)) / (np.max(mask) - np.min(mask) + 1e-20)) * 255
    # Redundant after the normalization above, kept for safety.
    mask = np.clip(mask, 0, 255)
    mask = np.array(mask, np.uint8)
    if colormap == 'HOT':
        mask = cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_HOT)
    elif colormap == 'PINK':
        mask = cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_PINK)
    elif colormap == 'BONE':
        mask = cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_BONE)
    # exec('cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_' + colormap+')')
    # CHW -> HWC for OpenCV, then RGB -> BGR channel order.
    im = im.transpose((1, 2, 0))
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    # Blend: 30% image + 70% colorized mask.
    mix = cv2.addWeighted(im, 0.3, mask, 0.7, 0)
    cv2.imwrite(path, mix)
81
+
82
+
83
def set_seed(seed):
    """Seed every RNG the project touches (torch CPU/CUDA, numpy, random)
    and force deterministic cuDNN behavior for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade convolution speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
91
+
92
+
utils.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import utils.utils as gen_utils
5
+ import numpy as np
6
+
7
def adjust_rate_poly(cur_iter, max_iter, power=0.9):
    """Polynomial decay factor: (1 - cur_iter/max_iter) ** power."""
    fraction_left = 1.0 - 1.0 * cur_iter / max_iter
    return fraction_left ** power
9
+
10
def adjust_learning_rate_exp(lr, optimizer, iters, decay_rate=0.1, decay_step=25):
    """Step-exponential decay: lr * decay_rate ** (iters // decay_step).

    Each param group is expected to carry an 'lr_mult' multiplier.
    """
    decayed = lr * (decay_rate ** (iters // decay_step))
    for group in optimizer.param_groups:
        group['lr'] = decayed * group['lr_mult']
14
+
15
def adjust_learning_rate_RevGrad(lr, optimizer, max_iter, cur_iter,
                                 alpha=10, beta=0.75):
    """RevGrad/DANN-style schedule: lr / (1 + alpha*p) ** beta where
    p = cur_iter / (max_iter - 1); each group is scaled by its 'lr_mult'."""
    progress = 1.0 * cur_iter / (max_iter - 1)
    scaled = lr / pow(1.0 + alpha * progress, beta)
    for group in optimizer.param_groups:
        group['lr'] = scaled * group['lr_mult']
21
+
22
def adjust_learning_rate_inv(lr, optimizer, iters, alpha=0.001, beta=0.75):
    """Inverse-decay schedule: lr / (1 + alpha*iters) ** beta,
    scaled per param group by its 'lr_mult'."""
    scaled = lr / pow(1.0 + alpha * iters, beta)
    for group in optimizer.param_groups:
        group['lr'] = scaled * group['lr_mult']
26
+
27
def adjust_learning_rate_step(lr, optimizer, iters, steps, beta=0.1):
    """Piecewise-constant decay: multiply lr by *beta* once for every
    milestone in *steps* (ascending) that *iters* has reached or passed;
    each param group is scaled by its 'lr_mult'."""
    passed = 0
    for milestone in steps:
        if iters < milestone:
            break
        passed += 1

    scaled = lr * (beta ** passed)
    for group in optimizer.param_groups:
        group['lr'] = scaled * group['lr_mult']
37
+
38
def adjust_learning_rate_poly(lr, optimizer, iters, max_iter, power=0.9):
    """Polynomial decay: lr * (1 - iters/max_iter) ** power,
    scaled per param group by its 'lr_mult'."""
    scaled = lr * ((1.0 - 1.0 * iters / max_iter) ** power)
    for group in optimizer.param_groups:
        group['lr'] = scaled * group['lr_mult']
42
+
43
def set_param_groups(net, lr_mult_dict={}):
    """Build optimizer param groups from a network's top-level modules.

    Each group carries the 'lr_mult' found in *lr_mult_dict* for that
    module's name, defaulting to 1.0. NOTE: the mutable default is kept
    for signature compatibility; the dict is only read, never mutated.
    """
    if hasattr(net, "module"):  # unwrap DataParallel-style containers
        net = net.module

    params = []
    for name, module in net._modules.items():
        multiplier = lr_mult_dict[name] if name in lr_mult_dict else 1.0
        params.append({'params': module.parameters(), 'lr_mult': multiplier})

    return params
58
+
59
def LSR(x, dim=1, thres=10.0):
    """Entropy-style regularizer on log-probabilities *x*.

    Computes lsr = -mean(x, dim). With thres > 0 each per-sample term is
    reweighted by the detached factor (lsr/thres - 1); otherwise the plain
    mean of lsr is returned.
    """
    lsr = -1.0 * torch.mean(x, dim=dim)
    if thres <= 0.0:
        return torch.mean(lsr)
    weight = (lsr / thres - 1.0).detach()
    return torch.mean(weight * lsr)
65
+
66
def crop(feats, preds, gt, h, w):
    """Randomly sample an h x w grid of spatial positions per sample.

    feats, preds: (N, C, H, W); gt: (N, H, W). For each sample, h rows and
    w columns are drawn without replacement and the same grid is applied to
    all three tensors. Returns (cropped_feats, softmax(preds, dim=1),
    cropped_gt).
    """
    H, W = feats.shape[-2:]
    batch = feats.size(0)
    feat_list, pred_list, gt_list = [], [], []
    for i in range(batch):
        rows = torch.randperm(H)[0:h]
        cols = torch.randperm(W)[0:w]
        feat_list.append(feats[i, :, rows[:, None], cols])
        pred_list.append(preds[i, :, rows[:, None], cols])
        gt_list.append(gt[i, rows[:, None], cols])

    new_feats = torch.stack(feat_list, dim=0)
    new_gt = torch.stack(gt_list, dim=0)
    new_preds = torch.stack(pred_list, dim=0)
    probs = F.softmax(new_preds, dim=1)
    return new_feats, probs, new_gt
vgg.cpython-37.pyc ADDED
Binary file (3.75 kB). View file
 
vgg.cpython-38.pyc ADDED
Binary file (3.78 kB). View file
 
vgg.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+
5
+
6
class VGG_Backbone(nn.Module):
    """VGG16 feature backbone used by DCFM.

    ``forward(x)`` returns:
      * ``x1`` — 7x7 adaptively-pooled conv5 features,
      * ``pred_vector`` — the 1000-way VGG classifier output, and
      * ``x2`` — raw conv5 feature maps (before pooling).

    BUGFIX(review): the original ``forward`` referenced ``self.conv4_1`` /
    ``self.conv5_1`` / ``self.conv4_2`` / ``self.conv5_2``, none of which
    were ever created in ``__init__`` (only ``conv4`` and ``conv5`` exist,
    and they are the only blocks ``_initialize_weights`` fills), so every
    forward call raised AttributeError. Both branches now share the
    ``conv4`` / ``conv5`` modules that actually exist.
    """

    # Pooling layers sit at the front of each block (conv2..conv5).
    def __init__(self, pretrained_path="/scratch/wej36how/codes/DCFM-master/vgg16-397923af.pth"):
        super(VGG_Backbone, self).__init__()
        conv1 = nn.Sequential()
        conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
        conv1.add_module('relu1_1', nn.ReLU(inplace=True))
        conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
        conv1.add_module('relu1_2', nn.ReLU(inplace=True))
        self.conv1 = conv1

        conv2 = nn.Sequential()
        conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
        conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
        conv2.add_module('relu2_1', nn.ReLU())
        conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
        conv2.add_module('relu2_2', nn.ReLU())
        self.conv2 = conv2

        conv3 = nn.Sequential()
        conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
        conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
        conv3.add_module('relu3_1', nn.ReLU())
        conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_2', nn.ReLU())
        conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_3', nn.ReLU())
        self.conv3 = conv3

        conv4 = nn.Sequential()
        conv4.add_module('pool3', nn.MaxPool2d(2, stride=2))
        conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
        conv4.add_module('relu4_1', nn.ReLU())
        conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('relu4_2', nn.ReLU())
        conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('relu4_3', nn.ReLU())
        self.conv4 = conv4

        conv5 = nn.Sequential()
        conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
        conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_1', nn.ReLU())
        conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_2', nn.ReLU())
        conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_3', nn.ReLU())
        self.conv5 = conv5

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 1000),
        )

        # Torchvision's vgg16-397923af.pth state dict. The path is now a
        # parameter (default preserves the original hard-coded location)
        # so the backbone is not tied to one cluster filesystem.
        # pre_train = torch.load(os.path.dirname(__file__) + '/vgg16-397923af.pth')
        pre_train = torch.load(pretrained_path)
        self._initialize_weights(pre_train)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Branch 1: pooled features fed through the VGG classifier head.
        x1 = self.conv4(x)
        x1 = self.conv5(x1)
        x1 = self.avgpool(x1)
        _x1 = x1.view(x1.size(0), -1)
        pred_vector = self.classifier(_x1)

        # Branch 2: raw conv5 feature maps from the same shared weights.
        x2 = self.conv4(x)
        x2 = self.conv5(x2)
        return x1, pred_vector, x2

    def _initialize_weights(self, pre_train):
        # Copy torchvision VGG16 weights by positional key order: the state
        # dict alternates weight/bias for each conv and linear layer.
        keys = list(pre_train.keys())
        self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
        self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
        self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
        self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
        self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
        self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
        self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
        self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
        self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
        self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
        self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
        self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
        self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])

        self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
        self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
        self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
        self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
        self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
        self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
        self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
        self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
        self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
        self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
        self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
        self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
        self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])

        self.classifier[0].weight.data.copy_(pre_train[keys[26]])
        self.classifier[0].bias.data.copy_(pre_train[keys[27]])
        self.classifier[3].weight.data.copy_(pre_train[keys[28]])
        self.classifier[3].bias.data.copy_(pre_train[keys[29]])
        self.classifier[6].weight.data.copy_(pre_train[keys[30]])
        self.classifier[6].bias.data.copy_(pre_train[keys[31]])
120
+ self.classifier[6].bias.data.copy_(pre_train[keys[31]])