Upload 30 files
Browse files
- NWRD_dataset.py +40 -0
- README.md +52 -3
- config.py +17 -0
- dataloader.cpython-37.pyc +0 -0
- dataloader.py +50 -0
- dataset.cpython-38.pyc +0 -0
- dataset.py +259 -0
- dataset_preprocessing.ipynb +1046 -0
- evaluator.cpython-37.pyc +0 -0
- evaluator.py +490 -0
- hist_of_pixel_values.py +41 -0
- loss.py +56 -0
- main.cpython-37.pyc +0 -0
- main.cpython-38.pyc +0 -0
- main.py +315 -0
- preprocessing.py +131 -0
- requirements.txt +12 -0
- segmentation_metrics.py +60 -0
- select_results.py +63 -0
- sort_results.py +156 -0
- test.py +91 -0
- train.py +289 -0
- train_wandb.ipynb +459 -0
- util.cpython-38.pyc +0 -0
- util.py +92 -0
- utils.ipynb +0 -0
- utils.py +83 -0
- vgg.cpython-37.pyc +0 -0
- vgg.cpython-38.pyc +0 -0
- vgg.py +120 -0
NWRD_dataset.py
ADDED
@@ -0,0 +1,40 @@
+from torch.utils.data import Dataset
+from PIL import Image
+import os
+from torchvision import transforms
+
+class NWRD(Dataset):
+    def __init__(self, root_dir, transform=None, train=True):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.images = []
+        self.labels = []
+        self.load_data()
+
+    def load_data(self):
+        rust_dir = os.path.join(self.root_dir, "rust")
+        non_rust_dir = os.path.join(self.root_dir, "non_rust")
+
+        # Load rust images
+        for filename in os.listdir(rust_dir):
+            filepath = os.path.join(rust_dir, filename)
+            self.images.append(filepath)
+            self.labels.append(1)
+
+        # Load non-rust images
+        for filename in os.listdir(non_rust_dir):
+            filepath = os.path.join(non_rust_dir, filename)
+            self.images.append(filepath)
+            self.labels.append(0)
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image_path = self.images[idx]
+        image = Image.open(image_path).convert('RGB')
+
+        label = int(self.labels[idx])
+        if self.transform:
+            image = self.transform(image)
+        return image, label
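A minimal usage sketch for this class, assuming a root directory that contains `rust/` and `non_rust/` subfolders of image files; the path and transform below are illustrative, not part of the repo:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from NWRD_dataset import NWRD

# Illustrative preprocessing; the repo does not pin a specific transform here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Hypothetical path: any directory with rust/ and non_rust/ image subfolders works.
dataset = NWRD("./data/classification", transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for images, labels in loader:  # labels: 1 = rust, 0 = non-rust
    print(images.shape, labels[:8])
    break
```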
README.md
CHANGED
@@ -1,3 +1,52 @@
-
-
-
+# DCFM
+The official repo of the paper `Democracy Does Matter: Comprehensive Feature Mining for Co-Salient Object Detection`.
+
+## Environment Requirement
+Create the environment and install the dependencies as follows:
+`pip install -r requirements.txt`
+
+## Data Format
+trainset: CoCo-SEG
+
+testset: CoCA, CoSOD3k, Cosal2015
+
+Put the [CoCo-SEG](https://drive.google.com/file/d/1GbA_WKvJm04Z1tR8pTSzBdYVQ75avg4f/view), [CoCA](http://zhaozhang.net/coca.html), [CoSOD3k](http://dpfan.net/CoSOD3K/) and [Cosal2015](https://drive.google.com/u/0/uc?id=1mmYpGx17t8WocdPcw2WKeuFpz6VHoZ6K&export=download) datasets to `DCFM/data` as the following structure:
+```
+DCFM
+├── other codes
+├── ...
+│
+└── data
+
+    ├── CoCo-SEG (CoCo-SEG's image files)
+    ├── CoCA (CoCA's image files)
+    ├── CoSOD3k (CoSOD3k's image files)
+    └── Cosal2015 (Cosal2015's image files)
+```
+
+## Trained model
+
+The trained model can be downloaded from [papermodel](https://drive.google.com/file/d/1cfuq4eJoCwvFR9W1XOJX7Y0ttd8TGjlp/view?usp=sharing).
+
+Run `test.py` for inference.
+
+For the evaluation tool, please follow: https://github.com/zzhanghub/eval-co-sod
+
+
+<!-- USAGE EXAMPLES -->
+## Usage
+Download the pretrained backbone model [VGG](https://drive.google.com/file/d/1Z1aAYXMyJ6txQ1Z9N7gtxLOIai4dxrXd/view?usp=sharing).
+
+Run `train.py` for training.
+
+## Prediction results
+The co-saliency maps of DCFM can be found at [preds](https://drive.google.com/file/d/1wGeNHXFWVSyqvmL4NIUmEFdlHDovEtQR/view?usp=sharing).
+
+## Reproduction
+My reproductions on a 2080 Ti can be found at [reproduction1](https://drive.google.com/file/d/1vovii0RtYR_EC0Y2zxjY_cTWKWM3WaxP/view?usp=sharing) and [reproduction2](https://drive.google.com/file/d/1YPOKZ5kBtmZrCDhHpP3-w1GMVR5BfDoU/view?usp=sharing).
+
+My reproduction on a TITAN X can be found at [reproduction3](https://drive.google.com/file/d/1bnGFtRTYkVXqI2dcjeWFRDXnqqbUUBJr/view?usp=sharing).
+
+## Others
+The code is based on [GCoNet](https://github.com/fanq15/GCoNet).
+I've added a validation part to help select the model for closer results. This validation part is based on [GCoNet_plus](https://github.com/ZhengPeng7/GCoNet_plus). You can try different evaluation metrics to select the model.
config.py
ADDED
@@ -0,0 +1,17 @@
+import os
+
+
+class Config():
+    def __init__(self) -> None:
+
+        # Performance of GCoNet
+        self.val_measures = {
+            'Emax': {'CoCA': 0.783, 'CoSOD3k': 0.874, 'CoSal2015': 0.892},
+            'Smeasure': {'CoCA': 0.710, 'CoSOD3k': 0.810, 'CoSal2015': 0.838},
+            'Fmax': {'CoCA': 0.598, 'CoSOD3k': 0.805, 'CoSal2015': 0.856},
+        }
+
+        # others
+
+
+        self.validation = True
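These baseline numbers exist so a validation run can be compared against GCoNet. A small sketch of that comparison, with made-up `my_scores` values purely for illustration:

```python
from config import Config

config = Config()

# Hypothetical metrics from one validation run (values are illustrative).
my_scores = {'Emax': {'CoCA': 0.790, 'CoSOD3k': 0.870, 'CoSal2015': 0.895}}

if config.validation:
    for measure, per_dataset in my_scores.items():
        for dataset_name, score in per_dataset.items():
            baseline = config.val_measures[measure][dataset_name]
            print(f"{measure}/{dataset_name}: {score:.3f} "
                  f"({score - baseline:+.3f} vs GCoNet)")
```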
dataloader.cpython-37.pyc
ADDED
Binary file (1.82 kB).
dataloader.py
ADDED
@@ -0,0 +1,50 @@
+from torch.utils import data
+import os
+from PIL import Image, ImageFile
+
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+class EvalDataset(data.Dataset):
+    def __init__(self, pred_root, label_root, return_predpath=False, return_gtpath=False):
+        self.return_predpath = return_predpath
+        self.return_gtpath = return_gtpath
+        pred_dirs = os.listdir(pred_root)
+        label_dirs = os.listdir(label_root)
+
+        dir_name_list = []
+        for idir in pred_dirs:
+            if idir in label_dirs:
+                pred_names = os.listdir(os.path.join(pred_root, idir))
+                label_names = os.listdir(os.path.join(label_root, idir))
+                for iname in pred_names:
+                    if iname in label_names:
+                        dir_name_list.append(os.path.join(idir, iname))
+
+        self.image_path = list(
+            map(lambda x: os.path.join(pred_root, x), dir_name_list))
+        self.label_path = list(
+            map(lambda x: os.path.join(label_root, x), dir_name_list))
+
+        self.labels = []
+        for p in self.label_path:
+            self.labels.append(Image.open(p).convert('L'))
+
+    def __getitem__(self, item):
+        predpath = self.image_path[item]
+        gtpath = self.label_path[item]
+        pred = Image.open(predpath).convert('L')
+        gt = self.labels[item]
+        if pred.size != gt.size:
+            pred = pred.resize(gt.size, Image.BILINEAR)
+        returns = [pred, gt]
+        if self.return_predpath:
+            returns.append(predpath)
+        if self.return_gtpath:
+            returns.append(gtpath)
+        return returns
+
+    def __len__(self):
+        return len(self.image_path)
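A minimal sketch of driving `EvalDataset`, assuming prediction and ground-truth roots that hold matching `<class>/<image>.png` files; the directory names are placeholders:

```python
import numpy as np
from dataloader import EvalDataset

# Placeholder roots; both must contain the same class/image layout.
eval_set = EvalDataset('./preds/CoCA', './data/CoCA/gt', return_predpath=True)

for pred, gt, predpath in eval_set:
    # pred and gt are same-size grayscale PIL images; compare as floats in [0, 1].
    p = np.asarray(pred, dtype=np.float32) / 255.0
    g = np.asarray(gt, dtype=np.float32) / 255.0
    print(predpath, f"MAE={np.abs(p - g).mean():.4f}")
    break
```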
dataset.cpython-38.pyc
ADDED
Binary file (8.83 kB).
dataset.py
ADDED
@@ -0,0 +1,259 @@
+import os
+from PIL import Image, ImageOps, ImageFilter
+import torch
+import random
+import numpy as np
+from torch.utils import data
+from torchvision import transforms
+from torchvision.transforms import functional as F
+import numbers
+import pandas as pd
+
+
+class CoData(data.Dataset):
+    def __init__(self, img_root, gt_root, img_size, transform, max_num, is_train):
+
+        class_list = os.listdir(img_root)
+        self.size = [img_size, img_size]
+        self.img_dirs = list(
+            map(lambda x: os.path.join(img_root, x), class_list))
+        self.gt_dirs = list(
+            map(lambda x: os.path.join(gt_root, x), class_list))
+        self.transform = transform
+        self.max_num = max_num
+        self.is_train = is_train
+
+    def __getitem__(self, item):
+        names = os.listdir(self.img_dirs[item])
+        num = len(names)
+        img_paths = list(
+            map(lambda x: os.path.join(self.img_dirs[item], x), names))
+        gt_paths = list(
+            map(lambda x: os.path.join(self.gt_dirs[item], x[:-4] + '.png'), names))
+
+        if self.is_train:
+            # Subsample each class group to at most max_num images.
+            final_num = min(num, self.max_num)
+            sampled_list = random.sample(range(num), final_num)
+            img_paths = [img_paths[i] for i in sampled_list]
+            gt_paths = [gt_paths[i] for i in sampled_list]
+        else:
+            final_num = num
+
+        imgs = torch.Tensor(final_num, 3, self.size[0], self.size[1])
+        gts = torch.Tensor(final_num, 1, self.size[0], self.size[1])
+
+        subpaths = []
+        ori_sizes = []
+        for idx in range(final_num):
+            img = Image.open(img_paths[idx]).convert('RGB')
+            gt = Image.open(gt_paths[idx]).convert('L')
+
+            subpaths.append(os.path.join(img_paths[idx].split('/')[-2], img_paths[idx].split('/')[-1][:-4] + '.png'))
+            ori_sizes.append((img.size[1], img.size[0]))
+
+            [img, gt] = self.transform(img, gt)
+
+            imgs[idx] = img
+            gts[idx] = gt
+        if self.is_train:
+            cls_ls = [item] * int(final_num)
+            return imgs, gts, subpaths, ori_sizes, cls_ls
+        else:
+            return imgs, gts, subpaths, ori_sizes
+
+    def __len__(self):
+        return len(self.img_dirs)
+
+
+class FixedResize(object):
+    def __init__(self, size):
+        self.size = (size, size)  # size: (h, w)
+
+    def __call__(self, img, gt):
+        img = img.resize(self.size, Image.BILINEAR)
+        gt = gt.resize(self.size, Image.NEAREST)
+        return img, gt
+
+
+class ToTensor(object):
+    def __call__(self, img, gt):
+        return F.to_tensor(img), F.to_tensor(gt)
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+    Args:
+        mean (tuple): means for each channel.
+        std (tuple): standard deviations for each channel.
+    """
+
+    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, img, gt):
+        img = F.normalize(img, self.mean, self.std)
+        return img, gt
+
+
+class RandomHorizontalFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, gt):
+        if random.random() < self.p:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+            gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
+        return img, gt
+
+
+class RandomScaleCrop(object):
+    def __init__(self, base_size, crop_size, fill=0):
+        self.base_size = base_size
+        self.crop_size = crop_size
+        self.fill = fill
+
+    def __call__(self, img, mask):
+        # random scale (short edge)
+        short_size = random.randint(int(self.base_size * 0.8), int(self.base_size * 1.2))
+        w, h = img.size
+        if h > w:
+            ow = short_size
+            oh = int(1.0 * h * ow / w)
+        else:
+            oh = short_size
+            ow = int(1.0 * w * oh / h)
+        img = img.resize((ow, oh), Image.BILINEAR)
+        mask = mask.resize((ow, oh), Image.NEAREST)
+        # pad crop
+        if short_size < self.crop_size:
+            padh = self.crop_size - oh if oh < self.crop_size else 0
+            padw = self.crop_size - ow if ow < self.crop_size else 0
+            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
+        # random crop crop_size
+        w, h = img.size
+        x1 = random.randint(0, w - self.crop_size)
+        y1 = random.randint(0, h - self.crop_size)
+        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+
+        return img, mask
+
+
+class RandomRotation(object):
+    def __init__(self, degrees, resample=False, expand=False, center=None):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            if len(degrees) != 2:
+                raise ValueError("If degrees is a sequence, it must be of len 2.")
+            self.degrees = degrees
+
+        self.resample = resample
+        self.expand = expand
+        self.center = center
+
+    @staticmethod
+    def get_params(degrees):
+        angle = random.uniform(degrees[0], degrees[1])
+        return angle
+
+    def __call__(self, img, gt):
+        """
+        img (PIL Image): Image to be rotated.
+
+        Returns:
+            PIL Image: Rotated image.
+        """
+        angle = self.get_params(self.degrees)
+        return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), F.rotate(gt, angle, Image.NEAREST, self.expand, self.center)
+
+
+class Compose(object):
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, gt):
+        for t in self.transforms:
+            img, gt = t(img, gt)
+        return img, gt
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+# get the dataloader (Note: without data augmentation)
+def get_loader(img_root, gt_root, img_size, batch_size, max_num=float('inf'), istrain=True, shuffle=False, num_workers=0, pin=False):
+    if istrain:
+        transform = Compose([
+            RandomScaleCrop(img_size * 2, img_size * 2),
+            FixedResize(img_size),
+            RandomHorizontalFlip(),
+            RandomRotation((-90, 90)),
+            ToTensor(),
+            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+    else:
+        transform = Compose([
+            FixedResize(img_size),
+            ToTensor(),
+            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    dataset = CoData(img_root, gt_root, img_size, transform, max_num, is_train=istrain)
+    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+                                  pin_memory=pin)
+    return data_loader
+
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    img_root = './data/testtrain/img/'
+    gt_root = './data/testtrain/gt/'
+    loader = get_loader(img_root, gt_root, 20, 1, 16, istrain=False)
+    for batch in loader:
+        b, c, h, w = batch[0][0].shape
+        for i in range(b):
+            img = batch[0].squeeze(0)[i].permute(1, 2, 0).cpu().numpy() * std + mean
+            image = img * 255
+            mask = batch[1].squeeze(0)[i].squeeze().cpu().numpy()
+            plt.subplot(121)
+            plt.imshow(np.uint8(image))
+            plt.subplot(122)
+            plt.imshow(mask)
+            plt.show(block=True)
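The `__main__` block above exercises the test-mode loader. A training-mode call, sketched below with placeholder paths, additionally returns the per-group class indices `cls_ls` and subsamples each class group to at most `max_num` images:

```python
from dataset import get_loader

# Training-mode usage (placeholder paths); batch_size=1 means one class group per step.
train_loader = get_loader('./data/CoCo-SEG/img', './data/CoCo-SEG/gt',
                          img_size=224, batch_size=1, max_num=16,
                          istrain=True, shuffle=True, num_workers=4, pin=True)

for imgs, gts, subpaths, ori_sizes, cls_ls in train_loader:
    # imgs: [1, final_num, 3, 224, 224]; gts: [1, final_num, 1, 224, 224]
    print(imgs.shape, gts.shape, len(cls_ls))
    break
```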
dataset_preprocessing.ipynb
ADDED
@@ -0,0 +1,1046 @@
Cell 1 (code): imports.
```python
import glob
import os
import shutil
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import torchvision.transforms as T
import torchvision.transforms.functional as TF
```
Output (stderr):
```
/home/wej36how/.conda/envs/dmt/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
```

Cell 2 (code): configuration.
```python
source = "/scratch/wej36how/Datasets/NWRD/val"
dest = "/scratch/wej36how/Datasets/NWRDProcessed/val"
patch_size = 224
rust_threshold = 150  # patches with more than this many rust-mask pixels count as "rust"
max_number_of_images_per_group = 12
```

Cell 3 (markdown): This snippet will make patches of the images in the destination/patches directory.

Cell 4 (code): output locations for the patches.
```python
patches_path = os.path.join(dest, "patches")
images_dir = os.path.join(patches_path, "images")
masks_dir = os.path.join(patches_path, "masks")
```
Cell 5 (code): cut each image and mask into non-overlapping 224x224 patches.
```python
masks_paths = glob.glob(f'{os.path.join(source, "masks", "*")}')
images_paths = glob.glob(f'{os.path.join(source, "images", "*")}')
images_paths.sort()
masks_paths.sort()


os.makedirs(patches_path)
os.makedirs(images_dir)
os.makedirs(masks_dir)

def create_patches(fname):
    x = 0
    y = 0
    patches = []
    img = cv2.imread(fname)
    print("image shape:", img.shape)
    while (y + patch_size < img.shape[0]):

        if (x + patch_size > img.shape[1]):
            x = 0
            y += patch_size
        if y + patch_size <= img.shape[0] and x + patch_size <= img.shape[1]:
            patches.append([x, y])
        x += patch_size
    print("total patches: ", len(patches))
    return patches

total_count = 0
for u in images_paths:
    print(u)
    patches = create_patches(u)
    bgr = cv2.imread(u)
    image_name = u.split('/')[-1].split('.')[0]
    total_count += len(patches)

    for count, P in enumerate(patches):
        cv2.imwrite(os.path.join(images_dir, f"{image_name}_{count}.png"), bgr[P[1]:P[1]+patch_size, P[0]:P[0]+patch_size])

print("total image count:", total_count)

total_count = 0
for u in masks_paths:
    print(u)
    patches = create_patches(u)
    bgr = cv2.imread(u)
    image_name = u.split('/')[-1].split('.')[0]

    total_count += len(patches)

    for count, P in enumerate(patches):
        cv2.imwrite(os.path.join(masks_dir, f"{image_name}_{count}.png"), bgr[P[1]:P[1]+patch_size, P[0]:P[0]+patch_size])

print("total masks count:", total_count)
```
Output (abridged; one entry per image, then one per mask):
```
/scratch/wej36how/Datasets/NWRD/train/images/10.jpg
image shape: (4000, 6016, 3)
total patches: 442
/scratch/wej36how/Datasets/NWRD/train/images/100.jpg
image shape: (4608, 3456, 3)
total patches: 300
/scratch/wej36how/Datasets/NWRD/train/images/101.jpg
image shape: (3456, 3612, 3)
total patches: 240
...
total image count: 28523
/scratch/wej36how/Datasets/NWRD/train/masks/10.png
image shape: (4000, 6016, 3)
total patches: 442
...
total masks count: 28523
```
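For the common case where the image height is not an exact multiple of `patch_size`, `create_patches` yields `floor(H / patch_size) * floor(W / patch_size)` patches; at exact multiples the `y + patch_size < h` loop test drops most of the final row. A small check against the logged counts above (pure arithmetic, no images needed):

```python
def grid_count(h, w, ps=224):
    # Mirror the loop in create_patches without reading any image.
    x = y = n = 0
    while y + ps < h:
        if x + ps > w:
            x = 0
            y += ps
        if y + ps <= h and x + ps <= w:
            n += 1
        x += ps
    return n

# (H, W) shapes taken from the log above.
for h, w in [(4000, 6016), (4608, 3456), (3584, 3456)]:
    print(f"{h}x{w} -> {grid_count(h, w)} patches")
# 4000x6016 -> 442 and 4608x3456 -> 300, matching the log; 3584x3456 -> 226
# (not 16 * 15 = 240) because 3584 is an exact multiple of 224.
```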
Cell 6 (markdown): This will separate the rust and non-rust patches and put them in the directory destination/RustNonRustSplit.

Cell 7 (code): split patches by counting rust pixels in each mask.
```python
destination = os.path.join(dest, "RustNonRustSplit")
root = patches_path

os.makedirs(destination)
os.makedirs(os.path.join(destination, "non_rust", "images"))
os.makedirs(os.path.join(destination, "non_rust", "masks"))
os.makedirs(os.path.join(destination, "rust", "images"))
os.makedirs(os.path.join(destination, "rust", "masks"))

masks_path = os.path.join(root, "masks", "*.png")
masks_paths = glob.glob(masks_path)
minimum = 1000
min_patch = 0
rust_count = 0
non_rust_count = 0

for mask_path in masks_paths:
    patch_name = mask_path.split("/")[-1].split(".")[0]

    patch_mask = cv2.imread(mask_path, 0)
    patch_img = cv2.imread(os.path.join(root, "images", patch_name + ".png"))

    # 150 here is a pixel-intensity cutoff for "white" mask pixels;
    # rust_threshold (also 150) is the pixel-count cutoff used below.
    condition = (patch_mask > 150)
    count = np.sum(condition)

    if count <= rust_threshold:
        cv2.imwrite(os.path.join(destination, "non_rust", "images", f"{patch_name}.png"), patch_img)
        cv2.imwrite(os.path.join(destination, "non_rust", "masks", f"{patch_name}.png"), patch_mask)
        non_rust_count += 1
    else:
        if (count <= minimum):
            minimum = count
            min_patch = patch_name
        cv2.imwrite(os.path.join(destination, "rust", "images", f"{patch_name}.png"), patch_img)
        cv2.imwrite(os.path.join(destination, "rust", "masks", f"{patch_name}.png"), patch_mask)
        rust_count += 1

print("minimum rust patch:", min_patch)
print("minimum rust patch white pixels:", minimum)
print("rust count=", rust_count)
print("non rust count=", non_rust_count)
```
| 704 |
+
{
|
| 705 |
+
"cell_type": "markdown",
|
| 706 |
+
"metadata": {},
|
| 707 |
+
"source": [
|
| 708 |
+
"Run the next two code snippets for training only. The following code will augment the images in the destination/RustNonRustSplit/images and destination/RustNonRustSplit/masks folder. "
|
| 709 |
+
]
|
| 710 |
+
},
|
| 711 |
+
{
|
| 712 |
+
"cell_type": "code",
|
| 713 |
+
"execution_count": null,
|
| 714 |
+
"metadata": {},
|
| 715 |
+
"outputs": [],
|
| 716 |
+
"source": [
|
| 717 |
+
"#flip images horizontally\n",
|
| 718 |
+
"def flip_images_hor(input_image):\n",
|
| 719 |
+
" # Iterate over the images in the input directory\n",
|
| 720 |
+
" transform_hflip = T.RandomHorizontalFlip(p=1.0) # Set probability to 1.0 to always flip\n",
|
| 721 |
+
" return transform_hflip(input_image)\n",
|
| 722 |
+
"\n",
|
| 723 |
+
"#flip images vertically\n",
|
| 724 |
+
"def flip_images_ver(input_image):\n",
|
| 725 |
+
" # Iterate over the images in the input directory\n",
|
| 726 |
+
" transform_vflip = T.RandomVerticalFlip(p=1.0) # Set probability to 1.0 to always flip\n",
|
| 727 |
+
" return transform_vflip(input_image) \n",
|
| 728 |
+
" \n",
|
| 729 |
+
"def shear_vertical(input_image, shear_factor=45):\n",
|
| 730 |
+
" # Apply vertical shear\n",
|
| 731 |
+
" sheared_image = TF.affine(input_image, angle=0, translate=(0, 0), scale=1, shear=(0, shear_factor))\n",
|
| 732 |
+
" return sheared_image\n",
|
| 733 |
+
"\n",
|
| 734 |
+
"def shear_horizontal(input_image, shear_factor=45): # Increased shear for testing\n",
|
| 735 |
+
" sheared_image = TF.affine(input_image, angle=0, translate=(0, 0), scale=1, shear=(shear_factor, 0))\n",
|
| 736 |
+
" return sheared_image\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"def rotate_images(input_image, angle=45):\n",
|
| 739 |
+
" # Convert PIL Image to NumPy array\n",
|
| 740 |
+
" input_array = np.array(input_image)\n",
|
| 741 |
+
" # Rotate the image\n",
|
| 742 |
+
" height, width = input_array.shape[:2]\n",
|
| 743 |
+
" rotation_matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)\n",
|
| 744 |
+
" rotated_array = cv2.warpAffine(input_array, rotation_matrix, (width, height))\n",
|
| 745 |
+
" # Convert NumPy array back to PIL Image\n",
|
| 746 |
+
" rotated_image = Image.fromarray(rotated_array)\n",
|
| 747 |
+
" return rotated_image\n",
|
| 748 |
+
"\n",
|
| 749 |
+
"def dark(input_image,gamma):\n",
|
| 750 |
+
" dark_image= TF.adjust_gamma(input_image, gamma)\n",
|
| 751 |
+
" return dark_image\n",
|
| 752 |
+
"\n",
|
| 753 |
+
"def augment_image(img_path):\n",
|
| 754 |
+
"\n",
|
| 755 |
+
" # Apply the transformations\n",
|
| 756 |
+
" \n",
|
| 757 |
+
" #orig_image\n",
|
| 758 |
+
" orig_img = Image.open(Path(img_path))\n",
|
| 759 |
+
" \n",
|
| 760 |
+
" #flip images\n",
|
| 761 |
+
" img_hflipped = flip_images_hor(orig_img)\n",
|
| 762 |
+
" img_vflipped = flip_images_ver(orig_img)\n",
|
| 763 |
+
" \n",
|
| 764 |
+
" \n",
|
| 765 |
+
" #shear images\n",
|
| 766 |
+
" hor_shear = shear_horizontal(orig_img)\n",
|
| 767 |
+
" ver_shear = shear_vertical(orig_img)\n",
|
| 768 |
+
" \n",
|
| 769 |
+
" #dark\n",
|
| 770 |
+
" img_dark = dark(img_hflipped, 2)\n",
|
| 771 |
+
" img_rot = rotate_images(orig_img, angle=45)\n",
|
| 772 |
+
" \n",
|
| 773 |
+
" return [img_dark,img_hflipped,img_vflipped,hor_shear,ver_shear, img_rot]\n",
|
| 774 |
+
"\n",
|
| 775 |
+
"def creating_file_with_augmented_images(file_path_master_dataset, file_path_augmented_images):\n",
|
| 776 |
+
" master_dataset_folder = file_path_master_dataset\n",
|
| 777 |
+
" files_in_master_dataset = os.listdir(file_path_master_dataset)\n",
|
| 778 |
+
" augmented_images_folder = file_path_augmented_images\n",
|
| 779 |
+
" \n",
|
| 780 |
+
" for image_name in files_in_master_dataset:\n",
|
| 781 |
+
" image_path = os.path.join(master_dataset_folder, image_name)\n",
|
| 782 |
+
" required_images = augment_image(image_path) # Assuming augment_image is defined elsewhere\n",
|
| 783 |
+
" i = 0\n",
|
| 784 |
+
" for augmented_image in required_images:\n",
|
| 785 |
+
" # Convert RGBA to RGB if necessary\n",
|
| 786 |
+
" if augmented_image.mode == 'RGBA':\n",
|
| 787 |
+
" augmented_image = augmented_image.convert('RGB')\n",
|
| 788 |
+
" \n",
|
| 789 |
+
" # Save as png\n",
|
| 790 |
+
" augmented_image_path = os.path.join(augmented_images_folder, f\"aug{i}_{image_name}\")\n",
|
| 791 |
+
" augmented_image.save(augmented_image_path, format='png')\n",
|
| 792 |
+
" i += 1\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"master_dataset = os.path.join(destination,\"rust\",\"images\")\n",
|
| 795 |
+
"augmented_dataset = os.path.join(destination,\"rust\",\"images\")\n",
|
| 796 |
+
"creating_file_with_augmented_images(master_dataset,augmented_dataset)\n",
|
| 797 |
+
"\n",
|
| 798 |
+
"master_dataset = os.path.join(destination,\"rust\",\"masks\")\n",
|
| 799 |
+
"augmented_dataset = os.path.join(destination,\"rust\",\"masks\")\n",
|
| 800 |
+
"creating_file_with_augmented_images(master_dataset,augmented_dataset)"
|
| 801 |
+
]
|
| 802 |
+
},
|
| 803 |
+
{
|
| 804 |
+
"cell_type": "markdown",
|
| 805 |
+
"metadata": {},
|
| 806 |
+
"source": [
|
| 807 |
+
"Run next snippet only for training dataset. To remove patches that have their rust removed becuase of their augmentations"
|
| 808 |
+
]
|
| 809 |
+
},
|
| 810 |
+
{
|
| 811 |
+
"cell_type": "code",
|
| 812 |
+
"execution_count": null,
|
| 813 |
+
"metadata": {},
|
| 814 |
+
"outputs": [
|
| 815 |
+
{
|
| 816 |
+
"name": "stdout",
|
| 817 |
+
"output_type": "stream",
|
| 818 |
+
"text": [
|
| 819 |
+
"minimum rust patch: 0\n",
|
| 820 |
+
"minimum rust patch white pixels: 1000\n",
|
| 821 |
+
"rust count= 0\n",
|
| 822 |
+
"non rust count= 0\n"
|
| 823 |
+
]
|
| 824 |
+
}
|
| 825 |
+
],
|
| 826 |
+
"source": [
|
| 827 |
+
"root = os.path.join(destination,\"rust\")\n",
|
| 828 |
+
"\n",
|
| 829 |
+
"non_rust_images_dir = os.path.join(destination,\"non_rust\",\"images\")\n",
|
| 830 |
+
"non_rust_masks_dir = os.path.join(destination,\"non_rust\",\"masks\")\n",
|
| 831 |
+
"\n",
|
| 832 |
+
"masks_path = os.path.join(root, \"masks\", \"*.png\")\n",
|
| 833 |
+
"masks_paths = glob.glob(masks_path)\n",
|
| 834 |
+
"minimum=1000\n",
|
| 835 |
+
"min_patch=0\n",
|
| 836 |
+
"rust_count=0\n",
|
| 837 |
+
"non_rust_count=0\n",
|
| 838 |
+
"\n",
|
| 839 |
+
"for mask_path in masks_paths:\n",
|
| 840 |
+
" patch_name = mask_path.split(\"/\")[-1].split(\".\")[0]\n",
|
| 841 |
+
" \n",
|
| 842 |
+
" patch_mask = cv2.imread(mask_path, 0)\n",
|
| 843 |
+
" patch_img = cv2.imread(os.path.join(root, \"images\",patch_name+\".png\"))\n",
|
| 844 |
+
"\n",
|
| 845 |
+
" condition = (patch_mask > 150)\n",
|
| 846 |
+
" count = np.sum(condition)\n",
|
| 847 |
+
" \n",
|
| 848 |
+
" if count<=rust_threshold:\n",
|
| 849 |
+
" os.remove(mask_path)\n",
|
| 850 |
+
" os.remove(os.path.join(root, \"images\",patch_name+\".png\"))\n",
|
| 851 |
+
"\n",
|
| 852 |
+
" cv2.imwrite(os.path.join(non_rust_images_dir,f\"{patch_name}.png\"), patch_img)\n",
|
| 853 |
+
" cv2.imwrite(os.path.join(non_rust_masks_dir,f\"{patch_name}.png\"), patch_mask)\n",
|
| 854 |
+
" non_rust_count+=1\n",
|
| 855 |
+
" else:\n",
|
| 856 |
+
" if (count<=minimum):\n",
|
| 857 |
+
" minimum=count\n",
|
| 858 |
+
" min_patch = patch_name\n",
|
| 859 |
+
" # cv2.imwrite(os.path.join(destination,\"rust\",\"images\",f\"{patch_name}.png\"), patch_img)\n",
|
| 860 |
+
" # cv2.imwrite(os.path.join(destination,\"rust\",\"masks\",f\"{patch_name}.png\"), patch_mask)\n",
|
| 861 |
+
" rust_count+=1\n",
|
| 862 |
+
"\n",
|
| 863 |
+
"print(\"minimum rust patch:\",min_patch)\n",
|
| 864 |
+
"print(\"minimum rust patch white pixels:\",minimum)\n",
|
| 865 |
+
"print(\"rust count=\", rust_count)\n",
|
| 866 |
+
"print(\"non rust count=\", non_rust_count)"
|
| 867 |
+
]
|
| 868 |
+
},
|
| 869 |
+
{
|
| 870 |
+
"cell_type": "markdown",
|
| 871 |
+
"metadata": {},
|
| 872 |
+
"source": [
|
| 873 |
+
"Create a dataset for classification model"
|
| 874 |
+
]
|
| 875 |
+
},
|
| 876 |
+
{
|
| 877 |
+
"cell_type": "code",
|
| 878 |
+
"execution_count": null,
|
| 879 |
+
"metadata": {},
|
| 880 |
+
"outputs": [
|
| 881 |
+
{
|
| 882 |
+
"data": {
|
| 883 |
+
"text/plain": [
|
| 884 |
+
"'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDFprocessed\\\\train\\\\calssification\\\\non_rust'"
|
| 885 |
+
]
|
| 886 |
+
},
|
| 887 |
+
"execution_count": 7,
|
| 888 |
+
"metadata": {},
|
| 889 |
+
"output_type": "execute_result"
|
| 890 |
+
}
|
| 891 |
+
],
|
| 892 |
+
"source": [
|
| 893 |
+
"rust_images_dir = os.path.join(destination,\"rust\",\"images\")\n",
|
| 894 |
+
"non_rust_images_dir = os.path.join(destination,\"non_rust\",\"images\")\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"rustClassificationDir = os.path.join(dest, \"calssification\", \"rust\")\n",
|
| 897 |
+
"nonRustClassificationDir = os.path.join(dest, \"calssification\", \"non_rust\")\n",
|
| 898 |
+
"os.makedirs(rustClassificationDir, exist_ok=True)\n",
|
| 899 |
+
"os.makedirs(nonRustClassificationDir, exist_ok=True)\n",
|
| 900 |
+
"\n",
|
| 901 |
+
"shutil.copytree(rust_images_dir,rustClassificationDir, dirs_exist_ok=True)\n",
|
| 902 |
+
"shutil.copytree(non_rust_images_dir,nonRustClassificationDir, dirs_exist_ok=True)\n"
|
| 903 |
+
]
|
| 904 |
+
},
|
| 905 |
+
{
|
| 906 |
+
"cell_type": "markdown",
|
| 907 |
+
"metadata": {},
|
| 908 |
+
"source": [
|
| 909 |
+
"Run the next code snippet for training dataset only. It deletes non-rust patches to match rust patches in the classification folder only."
|
| 910 |
+
]
|
| 911 |
+
},
|
| 912 |
+
{
|
| 913 |
+
"cell_type": "code",
|
| 914 |
+
"execution_count": null,
|
| 915 |
+
"metadata": {},
|
| 916 |
+
"outputs": [],
|
| 917 |
+
"source": [
|
| 918 |
+
"import os\n",
|
| 919 |
+
"import glob\n",
|
| 920 |
+
"\n",
|
| 921 |
+
"def delete_extra_images(directory, target_count):\n",
|
| 922 |
+
" # Get a list of all image files in the directory\n",
|
| 923 |
+
" image_files = glob.glob(os.path.join(directory, '*.JPG')) + glob.glob(os.path.join(directory, '*.jpeg')) + glob.glob(os.path.join(directory, '*.png'))\n",
|
| 924 |
+
" \n",
|
| 925 |
+
" # Check if the number of images exceeds the target count\n",
|
| 926 |
+
" if len(image_files) > target_count:\n",
|
| 927 |
+
" # Calculate the number of images to delete\n",
|
| 928 |
+
" num_to_delete = len(image_files) - target_count\n",
|
| 929 |
+
" # Sort the images by modification time (oldest first)\n",
|
| 930 |
+
" image_files.sort(key=os.path.getmtime)\n",
|
| 931 |
+
" # Delete the extra images\n",
|
| 932 |
+
" for i in range(num_to_delete):\n",
|
| 933 |
+
" os.remove(image_files[i])\n",
|
| 934 |
+
" print(f\"{num_to_delete} images deleted.\")\n",
|
| 935 |
+
" elif len(image_files) < target_count:\n",
|
| 936 |
+
" print(\"Warning: Number of images in directory is less than the target count.\")\n",
|
| 937 |
+
"\n",
|
| 938 |
+
"if len(os.listdir(rustClassificationDir))< len(os.listdir(nonRustClassificationDir)):\n",
|
| 939 |
+
" delete_extra_images(nonRustClassificationDir, len(os.listdir(rustClassificationDir)))\n",
|
| 940 |
+
"else:\n",
|
| 941 |
+
" delete_extra_images(rustClassificationDir, len(os.listdir(nonRustClassificationDir)))\n"
|
| 942 |
+
]
|
| 943 |
+
},
|
| 944 |
+
{
|
| 945 |
+
"cell_type": "markdown",
|
| 946 |
+
"metadata": {},
|
| 947 |
+
"source": [
|
| 948 |
+
"The following code creates a coslaiency style structure for co-saliency models training"
|
| 949 |
+
]
|
| 950 |
+
},
|
| 951 |
+
{
|
| 952 |
+
"cell_type": "code",
|
| 953 |
+
"execution_count": null,
|
| 954 |
+
"metadata": {},
|
| 955 |
+
"outputs": [],
|
| 956 |
+
"source": [
|
| 957 |
+
"rust_dir = os.path.join(destination,\"rust\")\n",
|
| 958 |
+
"rustCosaliencynDir = os.path.join(dest, \"cosaliency\")\n",
|
| 959 |
+
"shutil.copytree(rust_dir,rustCosaliencynDir, dirs_exist_ok=True)\n",
|
| 960 |
+
"\n",
|
| 961 |
+
"# Function to split images into folders based on image number\n",
|
| 962 |
+
"def split_images_into_folders(source_dir, destination_dir):\n",
|
| 963 |
+
" # Create destination directory if it doesn't exist\n",
|
| 964 |
+
" if not os.path.exists(destination_dir):\n",
|
| 965 |
+
" os.makedirs(destination_dir)\n",
|
| 966 |
+
" # Iterate through files in the source directory\n",
|
| 967 |
+
" for filename in os.listdir(source_dir):\n",
|
| 968 |
+
" if filename.endswith('.png'):\n",
|
| 969 |
+
" image_no = filename.split('_')[0] # Extract image number from filename\n",
|
| 970 |
+
" if not image_no.isdigit():\n",
|
| 971 |
+
" image_no = filename.split('_')[1]\n",
|
| 972 |
+
" destination_subdir = os.path.join(destination_dir, image_no)\n",
|
| 973 |
+
" # Create subdirectory if it doesn't exist\n",
|
| 974 |
+
" if not os.path.exists(destination_subdir):\n",
|
| 975 |
+
" os.makedirs(destination_subdir)\n",
|
| 976 |
+
" # Move the image file to the respective subdirectory\n",
|
| 977 |
+
" shutil.move(os.path.join(source_dir, filename), destination_subdir)\n",
|
| 978 |
+
"\n",
|
| 979 |
+
"def organize_images(main_directory):\n",
|
| 980 |
+
" # Ensure the main directory exists\n",
|
| 981 |
+
" if not os.path.exists(main_directory):\n",
|
| 982 |
+
" print(f\"The specified main directory '{main_directory}' does not exist.\")\n",
|
| 983 |
+
" return\n",
|
| 984 |
+
"\n",
|
| 985 |
+
" # Get a list of subdirectories in the main directory\n",
|
| 986 |
+
" subdirectories = [d for d in os.listdir(main_directory) if os.path.isdir(os.path.join(main_directory, d))]\n",
|
| 987 |
+
"\n",
|
| 988 |
+
" # Process each subdirectory\n",
|
| 989 |
+
" for subdir in subdirectories:\n",
|
| 990 |
+
" subdir_path = os.path.join(main_directory, subdir)\n",
|
| 991 |
+
"\n",
|
| 992 |
+
" # Get a list of images in the subdirectory\n",
|
| 993 |
+
" images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]\n",
|
| 994 |
+
" # Determine the number of images per subdirectory\n",
|
| 995 |
+
" images_per_subdir = 12\n",
|
| 996 |
+
" num_subdirectories = len(images) // images_per_subdir\n",
|
| 997 |
+
" n=0\n",
|
| 998 |
+
" # Create additional subdirectories if needed\n",
|
| 999 |
+
" for i in range(num_subdirectories - 1):\n",
|
| 1000 |
+
" new_subdir_name = f\"{subdir}_part{i + 1}\"\n",
|
| 1001 |
+
" new_subdir_path = os.path.join(main_directory, new_subdir_name)\n",
|
| 1002 |
+
"\n",
|
| 1003 |
+
" # Create the new subdirectory\n",
|
| 1004 |
+
" os.makedirs(new_subdir_path)\n",
|
| 1005 |
+
"\n",
|
| 1006 |
+
" # Move images to the new subdirectory\n",
|
| 1007 |
+
" for j in range(images_per_subdir):\n",
|
| 1008 |
+
" old_image_path = os.path.join(subdir_path, images[n])\n",
|
| 1009 |
+
" new_image_path = os.path.join(new_subdir_path, images[n])\n",
|
| 1010 |
+
" shutil.move(old_image_path, new_image_path)\n",
|
| 1011 |
+
" n+=1\n",
|
| 1012 |
+
"\n",
|
| 1013 |
+
"source_directory = os.path.join(dest, \"cosaliency\", \"images\")\n",
|
| 1014 |
+
"destination_directory = os.path.join(dest, \"cosaliency\", \"images\")\n",
|
| 1015 |
+
"split_images_into_folders(source_directory, destination_directory)\n",
|
| 1016 |
+
"organize_images(destination_directory)\n",
|
| 1017 |
+
"\n",
|
| 1018 |
+
"source_directory = os.path.join(dest, \"cosaliency\", \"masks\")\n",
|
| 1019 |
+
"destination_directory = os.path.join(dest, \"cosaliency\", \"masks\")\n",
|
| 1020 |
+
"split_images_into_folders(source_directory, destination_directory)\n",
|
| 1021 |
+
"organize_images(destination_directory)"
|
| 1022 |
+
]
|
| 1023 |
+
}
|
| 1024 |
+
],
|
| 1025 |
+
"metadata": {
|
| 1026 |
+
"kernelspec": {
|
| 1027 |
+
"display_name": "segformer",
|
| 1028 |
+
"language": "python",
|
| 1029 |
+
"name": "python3"
|
| 1030 |
+
},
|
| 1031 |
+
"language_info": {
|
| 1032 |
+
"codemirror_mode": {
|
| 1033 |
+
"name": "ipython",
|
| 1034 |
+
"version": 3
|
| 1035 |
+
},
|
| 1036 |
+
"file_extension": ".py",
|
| 1037 |
+
"mimetype": "text/x-python",
|
| 1038 |
+
"name": "python",
|
| 1039 |
+
"nbconvert_exporter": "python",
|
| 1040 |
+
"pygments_lexer": "ipython3",
|
| 1041 |
+
"version": "3.8.0"
|
| 1042 |
+
}
|
| 1043 |
+
},
|
| 1044 |
+
"nbformat": 4,
|
| 1045 |
+
"nbformat_minor": 2
|
| 1046 |
+
}
|
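A minimal, self-contained sketch of the patch grid that the notebook's `create_patches()` walks: the stride equals the patch size, and only windows that fit entirely inside the image are kept. Note that the notebook's `while (y + patch_size < img.shape[0])` uses a strict comparison, so a final row that fits exactly is skipped; the range-based version below keeps it. `patch_size=512` is an illustrative value, not the one the notebook was run with.

```python
import numpy as np

def patch_grid(height, width, patch_size=512):
    """Top-left (x, y) corners of all non-overlapping, fully contained patches."""
    return [(x, y)
            for y in range(0, height - patch_size + 1, patch_size)
            for x in range(0, width - patch_size + 1, patch_size)]

img = np.zeros((4608, 3456, 3), dtype=np.uint8)  # shape reported in the output above
corners = patch_grid(*img.shape[:2])
print("total patches:", len(corners))            # (4608 // 512) * (3456 // 512) = 54
```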
evaluator.cpython-37.pyc
ADDED
|
Binary file (12.3 kB).
|
|
|
evaluator.py
ADDED
|
@@ -0,0 +1,490 @@
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.io import savemat
|
| 7 |
+
import torch
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
|
| 10 |
+
from PIL import ImageFile
|
| 11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Eval_thread():
|
| 15 |
+
def __init__(self, loader, method='', dataset='', output_dir='', epoch='', cuda=True):
|
| 16 |
+
self.loader = loader
|
| 17 |
+
self.method = method
|
| 18 |
+
self.dataset = dataset
|
| 19 |
+
self.cuda = cuda
|
| 20 |
+
self.output_dir = output_dir
|
| 21 |
+
self.epoch = epoch.split('ep')[-1]
|
| 22 |
+
self.logfile = os.path.join(output_dir, 'result.txt')
|
| 23 |
+
self.dataset2smeasure_bottom_bound = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845} # S_measures of GCoNet
|
| 24 |
+
|
| 25 |
+
def run(self, AP=False, AUC=False, save_metrics=False, continue_eval=True):
|
| 26 |
+
Res = {}
|
| 27 |
+
start_time = time.time()
|
| 28 |
+
|
| 29 |
+
if continue_eval:
|
| 30 |
+
s = self.Eval_Smeasure()
|
| 31 |
+
if s > self.dataset2smeasure_bottom_bound[self.dataset]:
|
| 32 |
+
mae = self.Eval_mae()
|
| 33 |
+
Em = self.Eval_Emeasure()
|
| 34 |
+
max_e = Em.max().item()
|
| 35 |
+
mean_e = Em.mean().item()
|
| 36 |
+
Em = Em.cpu().numpy()
|
| 37 |
+
Fm, prec, recall = self.Eval_fmeasure()
|
| 38 |
+
max_f = Fm.max().item()
|
| 39 |
+
mean_f = Fm.mean().item()
|
| 40 |
+
Fm = Fm.cpu().numpy()
|
| 41 |
+
else:
|
| 42 |
+
mae = 1
|
| 43 |
+
Em = torch.zeros(255).cpu().numpy()
|
| 44 |
+
max_e = 0
|
| 45 |
+
mean_e = 0
|
| 46 |
+
Fm, prec, recall = 0, 0, 0
|
| 47 |
+
max_f = 0
|
| 48 |
+
mean_f = 0
|
| 49 |
+
continue_eval = False
|
| 50 |
+
else:
|
| 51 |
+
s = 0
|
| 52 |
+
mae = 1
|
| 53 |
+
Em = torch.zeros(255).cpu().numpy()
|
| 54 |
+
max_e = 0
|
| 55 |
+
mean_e = 0
|
| 56 |
+
Fm, prec, recall = 0, 0, 0
|
| 57 |
+
max_f = 0
|
| 58 |
+
mean_f = 0
|
| 59 |
+
continue_eval = False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if AP:
|
| 63 |
+
prec = prec.cpu().numpy()
|
| 64 |
+
recall = recall.cpu().numpy()
|
| 65 |
+
avg_p = self.Eval_AP(prec, recall)
|
| 66 |
+
|
| 67 |
+
if AUC:
|
| 68 |
+
auc, TPR, FPR = self.Eval_auc()
|
| 69 |
+
TPR = TPR.cpu().numpy()
|
| 70 |
+
FPR = FPR.cpu().numpy()
|
| 71 |
+
|
| 72 |
+
if save_metrics:
|
| 73 |
+
os.makedirs(os.path.join(self.output_dir, self.method, self.epoch), exist_ok=True)
|
| 74 |
+
Res['Sm'] = s
|
| 75 |
+
if s > self.dataset2smeasure_bottom_bound[self.dataset]:
|
| 76 |
+
Res['MAE'] = mae
|
| 77 |
+
Res['MaxEm'] = max_e
|
| 78 |
+
Res['MeanEm'] = mean_e
|
| 79 |
+
Res['Em'] = Em
|
| 80 |
+
Res['Fm'] = Fm
|
| 81 |
+
else:
|
| 82 |
+
Res['MAE'] = 1
|
| 83 |
+
Res['MaxEm'] = 0
|
| 84 |
+
Res['MeanEm'] = 0
|
| 85 |
+
Res['Em'] = torch.zeros(255).cpu().numpy()
|
| 86 |
+
Res['Fm'] = 0
|
| 87 |
+
|
| 88 |
+
if AP:
|
| 89 |
+
Res['MaxFm'] = max_f
|
| 90 |
+
Res['MeanFm'] = mean_f
|
| 91 |
+
Res['AP'] = avg_p
|
| 92 |
+
Res['Prec'] = prec
|
| 93 |
+
Res['Recall'] = recall
|
| 94 |
+
|
| 95 |
+
if AUC:
|
| 96 |
+
Res['AUC'] = auc
|
| 97 |
+
Res['TPR'] = TPR
|
| 98 |
+
Res['FPR'] = FPR
|
| 99 |
+
|
| 100 |
+
os.makedirs(os.path.join(self.output_dir, self.method, self.epoch), exist_ok=True)
|
| 101 |
+
savemat(os.path.join(self.output_dir, self.method, self.epoch, self.dataset + '.mat'), Res)
|
| 102 |
+
|
| 103 |
+
info = '{} ({}): {:.4f} max-Emeasure || {:.4f} S-measure || {:.4f} max-fm || {:.4f} mae || {:.4f} mean-Emeasure || {:.4f} mean-fm'.format(
|
| 104 |
+
self.dataset, self.method+'-ep{}'.format(self.epoch), max_e, s, max_f, mae, mean_e, mean_f
|
| 105 |
+
)
|
| 106 |
+
if AP:
|
| 107 |
+
info += ' || {:.4f} AP'.format(avg_p)
|
| 108 |
+
if AUC:
|
| 109 |
+
info += ' || {:.4f} AUC'.format(auc)
|
| 110 |
+
info += '.'
|
| 111 |
+
self.LOG(info + '\n')
|
| 112 |
+
|
| 113 |
+
return '[cost:{:.4f}s] '.format(time.time() - start_time) + info, continue_eval
|
| 114 |
+
|
| 115 |
+
def Eval_mae(self):
|
| 116 |
+
if self.epoch:
|
| 117 |
+
print('Evaluating MAE...')
|
| 118 |
+
avg_mae, img_num = 0.0, 0.0
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 121 |
+
for pred, gt in self.loader:
|
| 122 |
+
if self.cuda:
|
| 123 |
+
pred = trans(pred).cuda()
|
| 124 |
+
gt = trans(gt).cuda()
|
| 125 |
+
else:
|
| 126 |
+
pred = trans(pred)
|
| 127 |
+
gt = trans(gt)
|
| 128 |
+
mea = torch.abs(pred - gt).mean()
|
| 129 |
+
if mea == mea: # for Nan
|
| 130 |
+
avg_mae += mea
|
| 131 |
+
img_num += 1.0
|
| 132 |
+
avg_mae /= img_num
|
| 133 |
+
return avg_mae.item()
|
| 134 |
+
|
| 135 |
+
def Eval_fmeasure(self):
|
| 136 |
+
print('Evaluating FMeasure...')
|
| 137 |
+
beta2 = 0.3
|
| 138 |
+
avg_f, avg_p, avg_r, img_num = 0.0, 0.0, 0.0, 0.0
|
| 139 |
+
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 142 |
+
for pred, gt in self.loader:
|
| 143 |
+
if self.cuda:
|
| 144 |
+
pred = trans(pred).cuda()
|
| 145 |
+
gt = trans(gt).cuda()
|
| 146 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 147 |
+
torch.min(pred) + 1e-20)
|
| 148 |
+
else:
|
| 149 |
+
pred = trans(pred)
|
| 150 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 151 |
+
torch.min(pred) + 1e-20)
|
| 152 |
+
gt = trans(gt)
|
| 153 |
+
prec, recall = self._eval_pr(pred, gt, 255)
|
| 154 |
+
f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
|
| 155 |
+
f_score[f_score != f_score] = 0 # for Nan
|
| 156 |
+
avg_f += f_score
|
| 157 |
+
avg_p += prec
|
| 158 |
+
avg_r += recall
|
| 159 |
+
img_num += 1.0
|
| 160 |
+
Fm = avg_f / img_num
|
| 161 |
+
avg_p = avg_p / img_num
|
| 162 |
+
avg_r = avg_r / img_num
|
| 163 |
+
return Fm, avg_p, avg_r
|
| 164 |
+
|
| 165 |
+
def Eval_auc(self):
|
| 166 |
+
print('Evaluating AUC...')
|
| 167 |
+
|
| 168 |
+
avg_tpr, avg_fpr, avg_auc, img_num = 0.0, 0.0, 0.0, 0.0
|
| 169 |
+
|
| 170 |
+
with torch.no_grad():
|
| 171 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 172 |
+
for pred, gt in self.loader:
|
| 173 |
+
if self.cuda:
|
| 174 |
+
pred = trans(pred).cuda()
|
| 175 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 176 |
+
torch.min(pred) + 1e-20)
|
| 177 |
+
gt = trans(gt).cuda()
|
| 178 |
+
else:
|
| 179 |
+
pred = trans(pred)
|
| 180 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 181 |
+
torch.min(pred) + 1e-20)
|
| 182 |
+
gt = trans(gt)
|
| 183 |
+
TPR, FPR = self._eval_roc(pred, gt, 255)
|
| 184 |
+
avg_tpr += TPR
|
| 185 |
+
avg_fpr += FPR
|
| 186 |
+
img_num += 1.0
|
| 187 |
+
avg_tpr = avg_tpr / img_num
|
| 188 |
+
avg_fpr = avg_fpr / img_num
|
| 189 |
+
|
| 190 |
+
sorted_idxes = torch.argsort(avg_fpr)
|
| 191 |
+
avg_tpr = avg_tpr[sorted_idxes]
|
| 192 |
+
avg_fpr = avg_fpr[sorted_idxes]
|
| 193 |
+
avg_auc = torch.trapz(avg_tpr, avg_fpr)
|
| 194 |
+
|
| 195 |
+
return avg_auc.item(), avg_tpr, avg_fpr
|
| 196 |
+
|
| 197 |
+
def Eval_Emeasure(self):
|
| 198 |
+
print('Evaluating EMeasure...')
|
| 199 |
+
avg_e, img_num = 0.0, 0.0
|
| 200 |
+
with torch.no_grad():
|
| 201 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 202 |
+
Em = torch.zeros(255)
|
| 203 |
+
if self.cuda:
|
| 204 |
+
Em = Em.cuda()
|
| 205 |
+
for pred, gt in self.loader:
|
| 206 |
+
if self.cuda:
|
| 207 |
+
pred = trans(pred).cuda()
|
| 208 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 209 |
+
torch.min(pred) + 1e-20)
|
| 210 |
+
gt = trans(gt).cuda()
|
| 211 |
+
else:
|
| 212 |
+
pred = trans(pred)
|
| 213 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 214 |
+
torch.min(pred) + 1e-20)
|
| 215 |
+
gt = trans(gt)
|
| 216 |
+
Em += self._eval_e(pred, gt, 255)
|
| 217 |
+
img_num += 1.0
|
| 218 |
+
|
| 219 |
+
Em /= img_num
|
| 220 |
+
return Em
|
| 221 |
+
|
| 222 |
+
def select_by_Smeasure(self, bar=0.9, loader_comp=None, bar_comp=0.1):
|
| 223 |
+
print('Evaluating SMeasure...')
|
| 224 |
+
good_ones = []
|
| 225 |
+
good_ones_comp = []
|
| 226 |
+
good_ones_gt = []
|
| 227 |
+
alpha, avg_q, img_num = 0.5, 0.0, 0.0
|
| 228 |
+
with torch.no_grad():
|
| 229 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 230 |
+
for (pred, gt, predpath, gtpath), (pred_comp, gt_comp, predpath_comp) in zip(self.loader, loader_comp):
|
| 231 |
+
# pred X gt
|
| 232 |
+
if self.cuda:
|
| 233 |
+
pred = trans(pred).cuda()
|
| 234 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 235 |
+
torch.min(pred) + 1e-20)
|
| 236 |
+
gt = trans(gt).cuda()
|
| 237 |
+
else:
|
| 238 |
+
pred = trans(pred)
|
| 239 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 240 |
+
torch.min(pred) + 1e-20)
|
| 241 |
+
gt = trans(gt)
|
| 242 |
+
y = gt.mean()
|
| 243 |
+
if y == 0:
|
| 244 |
+
x = pred.mean()
|
| 245 |
+
Q = 1.0 - x
|
| 246 |
+
elif y == 1:
|
| 247 |
+
x = pred.mean()
|
| 248 |
+
Q = x
|
| 249 |
+
else:
|
| 250 |
+
gt[gt >= 0.5] = 1
|
| 251 |
+
gt[gt < 0.5] = 0
|
| 252 |
+
Q = alpha * self._S_object(
|
| 253 |
+
pred, gt) + (1 - alpha) * self._S_region(pred, gt)
|
| 254 |
+
if Q.item() < 0:
|
| 255 |
+
Q = torch.FloatTensor([0.0])
|
| 256 |
+
img_num += 1.0
|
| 257 |
+
avg_q += Q.item()
|
| 258 |
+
# pred_comp X gt
|
| 259 |
+
if self.cuda:
|
| 260 |
+
pred_comp = trans(pred_comp).cuda()
|
| 261 |
+
pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) -
|
| 262 |
+
torch.min(pred_comp) + 1e-20)
|
| 263 |
+
gt_comp = trans(gt_comp).cuda()
|
| 264 |
+
else:
|
| 265 |
+
pred_comp = trans(pred_comp)
|
| 266 |
+
pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) -
|
| 267 |
+
torch.min(pred_comp) + 1e-20)
|
| 268 |
+
gt_comp = trans(gt_comp)
|
| 269 |
+
y = gt_comp.mean()
|
| 270 |
+
if y == 0:
|
| 271 |
+
x = pred_comp.mean()
|
| 272 |
+
Q_comp = 1.0 - x
|
| 273 |
+
elif y == 1:
|
| 274 |
+
x = pred_comp.mean()
|
| 275 |
+
Q_comp = x
|
| 276 |
+
else:
|
| 277 |
+
gt_comp[gt_comp >= 0.5] = 1
|
| 278 |
+
gt_comp[gt_comp < 0.5] = 0
|
| 279 |
+
Q_comp = alpha * self._S_object(
|
| 280 |
+
pred_comp, gt_comp) + (1 - alpha) * self._S_region(pred_comp, gt_comp)
|
| 281 |
+
if Q_comp.item() < 0:
|
| 282 |
+
Q_comp = torch.FloatTensor([0.0])
|
| 283 |
+
if Q.item() > bar and (Q.item() - Q_comp.item()) > bar_comp:
|
| 284 |
+
good_ones.append(predpath)
|
| 285 |
+
good_ones_comp.append(predpath_comp)
|
| 286 |
+
good_ones_gt.append(gtpath)
|
| 287 |
+
avg_q /= img_num
|
| 288 |
+
return avg_q, good_ones, good_ones_comp, good_ones_gt
|
| 289 |
+
|
| 290 |
+
def Eval_Smeasure(self):
|
| 291 |
+
print('Evaluating SMeasure...')
|
| 292 |
+
alpha, avg_q, img_num = 0.5, 0.0, 0.0
|
| 293 |
+
with torch.no_grad():
|
| 294 |
+
trans = transforms.Compose([transforms.ToTensor()])
|
| 295 |
+
for pred, gt in self.loader:
|
| 296 |
+
if self.cuda:
|
| 297 |
+
pred = trans(pred).cuda()
|
| 298 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 299 |
+
torch.min(pred) + 1e-20)
|
| 300 |
+
gt = trans(gt).cuda()
|
| 301 |
+
else:
|
| 302 |
+
pred = trans(pred)
|
| 303 |
+
pred = (pred - torch.min(pred)) / (torch.max(pred) -
|
| 304 |
+
torch.min(pred) + 1e-20)
|
| 305 |
+
gt = trans(gt)
|
| 306 |
+
y = gt.mean()
|
| 307 |
+
if y == 0:
|
| 308 |
+
x = pred.mean()
|
| 309 |
+
Q = 1.0 - x
|
| 310 |
+
elif y == 1:
|
| 311 |
+
x = pred.mean()
|
| 312 |
+
Q = x
|
| 313 |
+
else:
|
| 314 |
+
gt[gt >= 0.5] = 1
|
| 315 |
+
gt[gt < 0.5] = 0
|
| 316 |
+
Q = alpha * self._S_object(
|
| 317 |
+
pred, gt) + (1 - alpha) * self._S_region(pred, gt)
|
| 318 |
+
if Q.item() < 0:
|
| 319 |
+
Q = torch.FloatTensor([0.0])
|
| 320 |
+
img_num += 1.0
|
| 321 |
+
avg_q += Q.item()
|
| 322 |
+
avg_q /= img_num
|
| 323 |
+
return avg_q
|
| 324 |
+
|
| 325 |
+
def LOG(self, output):
|
| 326 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
| 327 |
+
with open(self.logfile, 'a') as f:
|
| 328 |
+
f.write(output)
|
| 329 |
+
|
| 330 |
+
def _eval_e(self, y_pred, y, num):
|
| 331 |
+
if self.cuda:
|
| 332 |
+
score = torch.zeros(num).cuda()
|
| 333 |
+
thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
|
| 334 |
+
else:
|
| 335 |
+
score = torch.zeros(num)
|
| 336 |
+
thlist = torch.linspace(0, 1 - 1e-10, num)
|
| 337 |
+
for i in range(num):
|
| 338 |
+
y_pred_th = (y_pred >= thlist[i]).float()
|
| 339 |
+
fm = y_pred_th - y_pred_th.mean()
|
| 340 |
+
gt = y - y.mean()
|
| 341 |
+
align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
|
| 342 |
+
enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
|
| 343 |
+
score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
|
| 344 |
+
return score
|
| 345 |
+
|
| 346 |
+
def _eval_pr(self, y_pred, y, num):
|
| 347 |
+
if self.cuda:
|
| 348 |
+
prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
|
| 349 |
+
thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
|
| 350 |
+
else:
|
| 351 |
+
prec, recall = torch.zeros(num), torch.zeros(num)
|
| 352 |
+
thlist = torch.linspace(0, 1 - 1e-10, num)
|
| 353 |
+
for i in range(num):
|
| 354 |
+
y_temp = (y_pred >= thlist[i]).float()
|
| 355 |
+
tp = (y_temp * y).sum()
|
| 356 |
+
prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
|
| 357 |
+
return prec, recall
|
| 358 |
+
|
| 359 |
+
def _eval_roc(self, y_pred, y, num):
|
| 360 |
+
if self.cuda:
|
| 361 |
+
TPR, FPR = torch.zeros(num).cuda(), torch.zeros(num).cuda()
|
| 362 |
+
thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
|
| 363 |
+
else:
|
| 364 |
+
TPR, FPR = torch.zeros(num), torch.zeros(num)
|
| 365 |
+
thlist = torch.linspace(0, 1 - 1e-10, num)
|
| 366 |
+
for i in range(num):
|
| 367 |
+
y_temp = (y_pred >= thlist[i]).float()
|
| 368 |
+
tp = (y_temp * y).sum()
|
| 369 |
+
fp = (y_temp * (1 - y)).sum()
|
| 370 |
+
tn = ((1 - y_temp) * (1 - y)).sum()
|
| 371 |
+
fn = ((1 - y_temp) * y).sum()
|
| 372 |
+
|
| 373 |
+
TPR[i] = tp / (tp + fn + 1e-20)
|
| 374 |
+
FPR[i] = fp / (fp + tn + 1e-20)
|
| 375 |
+
|
| 376 |
+
return TPR, FPR
|
| 377 |
+
|
| 378 |
+
def _S_object(self, pred, gt):
|
| 379 |
+
fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
|
| 380 |
+
bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
|
| 381 |
+
o_fg = self._object(fg, gt)
|
| 382 |
+
o_bg = self._object(bg, 1 - gt)
|
| 383 |
+
u = gt.mean()
|
| 384 |
+
Q = u * o_fg + (1 - u) * o_bg
|
| 385 |
+
return Q
|
| 386 |
+
|
| 387 |
+
def _object(self, pred, gt):
|
| 388 |
+
temp = pred[gt == 1]
|
| 389 |
+
x = temp.mean()
|
| 390 |
+
sigma_x = temp.std()
|
| 391 |
+
score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
|
| 392 |
+
|
| 393 |
+
return score
|
| 394 |
+
|
| 395 |
+
def _S_region(self, pred, gt):
|
| 396 |
+
X, Y = self._centroid(gt)
|
| 397 |
+
gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
|
| 398 |
+
p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
|
| 399 |
+
Q1 = self._ssim(p1, gt1)
|
| 400 |
+
Q2 = self._ssim(p2, gt2)
|
| 401 |
+
Q3 = self._ssim(p3, gt3)
|
| 402 |
+
Q4 = self._ssim(p4, gt4)
|
| 403 |
+
Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
|
| 404 |
+
return Q
|
| 405 |
+
|
| 406 |
+
def _centroid(self, gt):
|
| 407 |
+
rows, cols = gt.size()[-2:]
|
| 408 |
+
gt = gt.view(rows, cols)
|
| 409 |
+
if gt.sum() == 0:
|
| 410 |
+
if self.cuda:
|
| 411 |
+
X = torch.eye(1).cuda() * round(cols / 2)
|
| 412 |
+
Y = torch.eye(1).cuda() * round(rows / 2)
|
| 413 |
+
else:
|
| 414 |
+
X = torch.eye(1) * round(cols / 2)
|
| 415 |
+
Y = torch.eye(1) * round(rows / 2)
|
| 416 |
+
else:
|
| 417 |
+
total = gt.sum()
|
| 418 |
+
if self.cuda:
|
| 419 |
+
i = torch.from_numpy(np.arange(0, cols)).cuda().float()
|
| 420 |
+
j = torch.from_numpy(np.arange(0, rows)).cuda().float()
|
| 421 |
+
else:
|
| 422 |
+
i = torch.from_numpy(np.arange(0, cols)).float()
|
| 423 |
+
j = torch.from_numpy(np.arange(0, rows)).float()
|
| 424 |
+
X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20)
|
| 425 |
+
Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20)
|
| 426 |
+
return X.long(), Y.long()
|
| 427 |
+
|
| 428 |
+
def _divideGT(self, gt, X, Y):
|
| 429 |
+
h, w = gt.size()[-2:]
|
| 430 |
+
area = h * w
|
| 431 |
+
gt = gt.view(h, w)
|
| 432 |
+
LT = gt[:Y, :X]
|
| 433 |
+
RT = gt[:Y, X:w]
|
| 434 |
+
LB = gt[Y:h, :X]
|
| 435 |
+
RB = gt[Y:h, X:w]
|
| 436 |
+
X = X.float()
|
| 437 |
+
Y = Y.float()
|
| 438 |
+
w1 = X * Y / area
|
| 439 |
+
w2 = (w - X) * Y / area
|
| 440 |
+
w3 = X * (h - Y) / area
|
| 441 |
+
w4 = 1 - w1 - w2 - w3
|
| 442 |
+
return LT, RT, LB, RB, w1, w2, w3, w4
|
| 443 |
+
|
| 444 |
+
def _dividePrediction(self, pred, X, Y):
|
| 445 |
+
h, w = pred.size()[-2:]
|
| 446 |
+
pred = pred.view(h, w)
|
| 447 |
+
LT = pred[:Y, :X]
|
| 448 |
+
RT = pred[:Y, X:w]
|
| 449 |
+
LB = pred[Y:h, :X]
|
| 450 |
+
RB = pred[Y:h, X:w]
|
| 451 |
+
return LT, RT, LB, RB
|
| 452 |
+
|
| 453 |
+
def _ssim(self, pred, gt):
|
| 454 |
+
gt = gt.float()
|
| 455 |
+
h, w = pred.size()[-2:]
|
| 456 |
+
N = h * w
|
| 457 |
+
x = pred.mean()
|
| 458 |
+
y = gt.mean()
|
| 459 |
+
sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
|
| 460 |
+
sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
|
| 461 |
+
sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
|
| 462 |
+
|
| 463 |
+
alpha = 4 * x * y * sigma_xy
|
| 464 |
+
beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
|
| 465 |
+
|
| 466 |
+
if alpha != 0:
|
| 467 |
+
Q = alpha / (beta + 1e-20)
|
| 468 |
+
elif alpha == 0 and beta == 0:
|
| 469 |
+
Q = 1.0
|
| 470 |
+
else:
|
| 471 |
+
Q = 0
|
| 472 |
+
return Q
|
| 473 |
+
|
| 474 |
+
def Eval_AP(self, prec, recall):
|
| 475 |
+
# Ref:
|
| 476 |
+
# https://github.com/facebookresearch/Detectron/blob/05d04d3a024f0991339de45872d02f2f50669b3d/lib/datasets/voc_eval.py#L54
|
| 477 |
+
print('Evaluating AP...')
|
| 478 |
+
ap_r = np.concatenate(([0.], recall, [1.]))
|
| 479 |
+
ap_p = np.concatenate(([0.], prec, [0.]))
|
| 480 |
+
sorted_idxes = np.argsort(ap_r)
|
| 481 |
+
ap_r = ap_r[sorted_idxes]
|
| 482 |
+
ap_p = ap_p[sorted_idxes]
|
| 483 |
+
count = ap_r.shape[0]
|
| 484 |
+
|
| 485 |
+
for i in range(count - 1, 0, -1):
|
| 486 |
+
ap_p[i - 1] = max(ap_p[i], ap_p[i - 1])
|
| 487 |
+
|
| 488 |
+
i = np.where(ap_r[1:] != ap_r[:-1])[0]
|
| 489 |
+
ap = np.sum((ap_r[i + 1] - ap_r[i]) * ap_p[i + 1])
|
| 490 |
+
return ap
|
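For reference, a minimal sketch of the thresholded precision/recall sweep behind `Eval_fmeasure()`/`_eval_pr()`, run on random tensors instead of the prediction/ground-truth loader. `beta2 = 0.3` matches the evaluator; the extra `1e-20` in the F-measure denominator stands in for the evaluator's explicit NaN-to-zero step, and all shapes are illustrative.

```python
import torch

def eval_pr(y_pred, y, num=255):
    # Sweep `num` thresholds over [0, 1) and compute precision/recall at each.
    thlist = torch.linspace(0, 1 - 1e-10, num)
    prec, recall = torch.zeros(num), torch.zeros(num)
    for i in range(num):
        y_temp = (y_pred >= thlist[i]).float()
        tp = (y_temp * y).sum()
        prec[i] = tp / (y_temp.sum() + 1e-20)
        recall[i] = tp / (y.sum() + 1e-20)
    return prec, recall

pred = torch.rand(1, 64, 64)                # saliency map in [0, 1]
gt = (torch.rand(1, 64, 64) > 0.5).float()  # binary ground truth
prec, recall = eval_pr(pred, gt)
beta2 = 0.3
f = (1 + beta2) * prec * recall / (beta2 * prec + recall + 1e-20)
print("max-F: {:.4f} | mean-F: {:.4f}".format(f.max().item(), f.mean().item()))
```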
hist_of_pixel_values.py
ADDED
|
@@ -0,0 +1,41 @@
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
root_dir = os.path.join([rd for rd in os.listdir('.') if 'gconet_' in rd][0], 'CoCA/Accordion')
|
| 8 |
+
image_paths = [os.path.join(root_dir, p) for p in os.listdir(root_dir)]
|
| 9 |
+
pixel_values = []
|
| 10 |
+
for image_path in image_paths:
|
| 11 |
+
image = cv2.imread(image_path)
|
| 12 |
+
pixel_value = image.flatten().squeeze().tolist()
|
| 13 |
+
pixel_values += pixel_value
|
| 14 |
+
|
| 15 |
+
pixel_values = np.array(pixel_values)
|
| 16 |
+
|
| 17 |
+
non_zero_values = pixel_values[pixel_values >= 0]
|
| 18 |
+
margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values <= 0)) / non_zero_values.shape[0] * 100
|
| 19 |
+
print('histing...')
|
| 20 |
+
plt.hist(x=non_zero_values)
|
| 21 |
+
plt.title('(0+>230)/all, {:.1f} % are margin values'.format(margin_values_percent))
|
| 22 |
+
plt.savefig('hist_(0+>230)|all.png')
|
| 23 |
+
plt.show()
|
| 24 |
+
|
| 25 |
+
non_zero_values = pixel_values[pixel_values >= 0]
|
| 26 |
+
margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values < 0)) / non_zero_values.shape[0] * 100
|
| 27 |
+
print('histing...')
|
| 28 |
+
plt.figure()
|
| 29 |
+
plt.hist(x=non_zero_values)
|
| 30 |
+
plt.title('(230)/all, {:.1f} % are margin values'.format(margin_values_percent))
|
| 31 |
+
plt.savefig('hist_(230)|all.png')
|
| 32 |
+
plt.show()
|
| 33 |
+
|
| 34 |
+
non_zero_values = pixel_values[pixel_values > 0]
|
| 35 |
+
margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values <= 0)) / non_zero_values.shape[0] * 100
|
| 36 |
+
print('histing...')
|
| 37 |
+
plt.figure()
|
| 38 |
+
plt.hist(x=non_zero_values)
|
| 39 |
+
plt.title('(0+>230)/(all-0), {:.1f} % are margin values'.format(margin_values_percent))
|
| 40 |
+
plt.savefig('hist_(0+>230)|(all-0).png')
|
| 41 |
+
plt.show()
|
loss.py
ADDED
|
@@ -0,0 +1,56 @@
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.autograd import Variable
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class IoU_loss(torch.nn.Module):
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(IoU_loss, self).__init__()
|
| 12 |
+
|
| 13 |
+
def forward(self, pred, target):
|
| 14 |
+
b = pred.shape[0]
|
| 15 |
+
IoU = 0.0
|
| 16 |
+
for i in range(0, b):
|
| 17 |
+
#compute the IoU of the foreground
|
| 18 |
+
Iand1 = torch.sum(target[i, :, :, :]*pred[i, :, :, :])
|
| 19 |
+
Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :])-Iand1
|
| 20 |
+
IoU1 = Iand1/(Ior1 + 1e-5)
|
| 21 |
+
#IoU loss is (1-IoU1)
|
| 22 |
+
IoU = IoU + (1-IoU1)
|
| 23 |
+
|
| 24 |
+
return IoU/b
|
| 25 |
+
#return IoU
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Scale_IoU(nn.Module):
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super(Scale_IoU, self).__init__()
|
| 31 |
+
self.iou = IoU_loss()
|
| 32 |
+
|
| 33 |
+
def forward(self, scaled_preds, gt):
|
| 34 |
+
loss = 0
|
| 35 |
+
for pred_lvl in scaled_preds[0:]:
|
| 36 |
+
loss += self.iou(torch.sigmoid(pred_lvl), gt) + self.iou(1-torch.sigmoid(pred_lvl), 1-gt)
|
| 37 |
+
return loss
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compute_cos_dis(x_sup, x_que):
|
| 41 |
+
x_sup = x_sup.view(x_sup.size()[0], x_sup.size()[1], -1)
|
| 42 |
+
x_que = x_que.view(x_que.size()[0], x_que.size()[1], -1)
|
| 43 |
+
|
| 44 |
+
x_que_norm = torch.norm(x_que, p=2, dim=1, keepdim=True)
|
| 45 |
+
x_sup_norm = torch.norm(x_sup, p=2, dim=1, keepdim=True)
|
| 46 |
+
|
| 47 |
+
x_que_norm = x_que_norm.permute(0, 2, 1)
|
| 48 |
+
x_qs_norm = torch.matmul(x_que_norm, x_sup_norm)
|
| 49 |
+
|
| 50 |
+
x_que = x_que.permute(0, 2, 1)
|
| 51 |
+
|
| 52 |
+
x_qs = torch.matmul(x_que, x_sup)
|
| 53 |
+
x_qs = x_qs / (x_qs_norm + 1e-5)
|
| 54 |
+
return x_qs
|
| 55 |
+
|
| 56 |
+
|
main.cpython-37.pyc
ADDED
|
Binary file (10.5 kB).
|
|
|
main.cpython-38.pyc
ADDED
|
Binary file (10.3 kB).
|
|
|
main.py
ADDED
|
@@ -0,0 +1,315 @@
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from models.vgg import VGG_Backbone
|
| 5 |
+
from util import *
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def weights_init(module):
|
| 9 |
+
if isinstance(module, nn.Conv2d):
|
| 10 |
+
nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
|
| 11 |
+
if module.bias is not None:
|
| 12 |
+
nn.init.zeros_(module.bias)
|
| 13 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 14 |
+
nn.init.ones_(module.weight)
|
| 15 |
+
if module.bias is not None:
|
| 16 |
+
nn.init.zeros_(module.bias)
|
| 17 |
+
elif isinstance(module, nn.Linear):
|
| 18 |
+
nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
|
| 19 |
+
if module.bias is not None:
|
| 20 |
+
nn.init.zeros_(module.bias)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class EnLayer(nn.Module):
|
| 24 |
+
def __init__(self, in_channel=64):
|
| 25 |
+
super(EnLayer, self).__init__()
|
| 26 |
+
self.enlayer = nn.Sequential(
|
| 27 |
+
nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
|
| 28 |
+
nn.ReLU(inplace=True),
|
| 29 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
x = self.enlayer(x)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class LatLayer(nn.Module):
|
| 38 |
+
def __init__(self, in_channel):
|
| 39 |
+
super(LatLayer, self).__init__()
|
| 40 |
+
self.convlayer = nn.Sequential(
|
| 41 |
+
nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
|
| 42 |
+
nn.ReLU(inplace=True),
|
| 43 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
x = self.convlayer(x)
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DSLayer(nn.Module):
|
| 52 |
+
def __init__(self, in_channel=64):
|
| 53 |
+
super(DSLayer, self).__init__()
|
| 54 |
+
self.enlayer = nn.Sequential(
|
| 55 |
+
nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
|
| 56 |
+
nn.ReLU(inplace=True),
|
| 57 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
| 58 |
+
nn.ReLU(inplace=True),
|
| 59 |
+
)
|
| 60 |
+
self.predlayer = nn.Sequential(
|
| 61 |
+
nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0))#, nn.Sigmoid())
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
x = self.enlayer(x)
|
| 65 |
+
x = self.predlayer(x)
|
| 66 |
+
return x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class half_DSLayer(nn.Module):
|
| 70 |
+
def __init__(self, in_channel=512):
|
| 71 |
+
super(half_DSLayer, self).__init__()
|
| 72 |
+
self.enlayer = nn.Sequential(
|
| 73 |
+
nn.Conv2d(in_channel, int(in_channel/4), kernel_size=3, stride=1, padding=1),
|
| 74 |
+
nn.ReLU(inplace=True),
|
| 75 |
+
)
|
| 76 |
+
self.predlayer = nn.Sequential(
|
| 77 |
+
nn.Conv2d(int(in_channel/4), 1, kernel_size=1, stride=1, padding=0)) #, nn.Sigmoid())
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
x = self.enlayer(x)
|
| 81 |
+
x = self.predlayer(x)
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class AugAttentionModule(nn.Module):
|
| 86 |
+
def __init__(self, input_channels=512):
|
| 87 |
+
super(AugAttentionModule, self).__init__()
|
| 88 |
+
self.query_transform = nn.Sequential(
|
| 89 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 90 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 91 |
+
)
|
| 92 |
+
self.key_transform = nn.Sequential(
|
| 93 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 94 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 95 |
+
)
|
| 96 |
+
self.value_transform = nn.Sequential(
|
| 97 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 98 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 99 |
+
)
|
| 100 |
+
self.scale = 1.0 / (input_channels ** 0.5)
|
| 101 |
+
self.conv = nn.Sequential(
|
| 102 |
+
nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
|
| 103 |
+
nn.ReLU(inplace=True),
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
B, C, H, W = x.size()
|
| 108 |
+
x = self.conv(x)
|
| 109 |
+
x_query = self.query_transform(x).view(B, C, -1).permute(0, 2, 1) # B,HW,C
|
| 110 |
+
# x_key: C,BHW
|
| 111 |
+
x_key = self.key_transform(x).view(B, C, -1) # B, C,HW
|
| 112 |
+
# x_value: BHW, C
|
| 113 |
+
x_value = self.value_transform(x).view(B, C, -1).permute(0, 2, 1) # B,HW,C
|
| 114 |
+
attention_bmm = torch.bmm(x_query, x_key)*self.scale # B, HW, HW
|
| 115 |
+
attention = F.softmax(attention_bmm, dim=-1)
|
| 116 |
+
attention_sort = torch.sort(attention_bmm, dim=-1, descending=True)[1]
|
| 117 |
+
attention_sort = torch.sort(attention_sort, dim=-1)[1]
|
| 118 |
+
#####
|
| 119 |
+
attention_positive_num = torch.ones_like(attention).cuda()
|
| 120 |
+
attention_positive_num[attention_bmm < 0] = 0
|
| 121 |
+
att_pos_mask = attention_positive_num.clone()
|
| 122 |
+
attention_positive_num = torch.sum(attention_positive_num, dim=-1, keepdim=True).expand_as(attention_sort)
|
| 123 |
+
attention_sort_pos = attention_sort.float().clone()
|
| 124 |
+
apn = attention_positive_num-1
|
| 125 |
+
attention_sort_pos[attention_sort > apn] = 0
|
| 126 |
+
attention_mask = ((attention_sort_pos+1)**3)*att_pos_mask + (1-att_pos_mask)
|
| 127 |
+
out = torch.bmm(attention*attention_mask, x_value)
|
| 128 |
+
out = out.view(B, H, W, C).permute(0, 3, 1, 2)
|
| 129 |
+
return out+x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AttLayer(nn.Module):
|
| 133 |
+
def __init__(self, input_channels=512):
|
| 134 |
+
super(AttLayer, self).__init__()
|
| 135 |
+
self.query_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
|
| 136 |
+
self.key_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
|
| 137 |
+
self.scale = 1.0 / (input_channels ** 0.5)
|
| 138 |
+
self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
|
| 139 |
+
|
| 140 |
+
def correlation(self, x5, seeds):
|
| 141 |
+
B, C, H5, W5 = x5.size()
|
| 142 |
+
if self.training:
|
| 143 |
+
correlation_maps = F.conv2d(x5, weight=seeds) # B,B,H,W
|
| 144 |
+
else:
|
| 145 |
+
correlation_maps = torch.relu(F.conv2d(x5, weight=seeds)) # B,B,H,W
|
| 146 |
+
correlation_maps = correlation_maps.mean(1).view(B, -1)
|
| 147 |
+
min_value = torch.min(correlation_maps, dim=1, keepdim=True)[0]
|
| 148 |
+
max_value = torch.max(correlation_maps, dim=1, keepdim=True)[0]
|
| 149 |
+
correlation_maps = (correlation_maps - min_value) / (max_value - min_value + 1e-12) # shape=[B, HW]
|
| 150 |
+
correlation_maps = correlation_maps.view(B, 1, H5, W5) # shape=[B, 1, H, W]
|
| 151 |
+
return correlation_maps
|
| 152 |
+
|
| 153 |
+
def forward(self, x5):
|
| 154 |
+
# x: B,C,H,W
|
| 155 |
+
x5 = self.conv(x5)+x5
|
+        B, C, H5, W5 = x5.size()
+        x_query = self.query_transform(x5).view(B, C, -1)
+        # x_query: B,HW,C
+        x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C)  # BHW, C
+        # x_key: B,C,HW
+        x_key = self.key_transform(x5).view(B, C, -1)
+        x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1)  # C, BHW
+        # W = Q^T K: B,HW,HW
+        x_w1 = torch.matmul(x_query, x_key) * self.scale  # BHW, BHW
+        x_w = x_w1.view(B * H5 * W5, B, H5 * W5)
+        x_w = torch.max(x_w, -1).values  # BHW, B
+        x_w = x_w.mean(-1)
+        x_w = x_w.view(B, -1)  # B, HW
+        x_w = F.softmax(x_w, dim=-1)  # B, HW
+        ##### mine ######
+        # x_w_max = torch.max(x_w, -1)
+        # max_indices0 = x_w_max.indices.unsqueeze(-1).unsqueeze(-1)
+        norm0 = F.normalize(x5, dim=1)
+        # norm = norm0.view(B, C, -1)
+        # max_indices = max_indices0.expand(B, C, -1)
+        # seeds = torch.gather(norm, 2, max_indices).unsqueeze(-1)
+        x_w = x_w.unsqueeze(1)
+        # Keep only each image's maximum-attention location as a one-hot seed mask.
+        x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w)
+        mask = torch.zeros_like(x_w).cuda()
+        mask[x_w == x_w_max] = 1
+        mask = mask.view(B, 1, H5, W5)
+        seeds = norm0 * mask
+        seeds = seeds.sum(3).sum(2).unsqueeze(2).unsqueeze(3)
+        cormap = self.correlation(norm0, seeds)
+        x51 = x5 * cormap
+        proto1 = torch.mean(x51, (0, 2, 3), True)
+        return x5, proto1, x5 * proto1 + x51, mask
+
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.toplayer = nn.Sequential(
+            nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0))
+        self.latlayer4 = LatLayer(in_channel=512)
+        self.latlayer3 = LatLayer(in_channel=256)
+        self.latlayer2 = LatLayer(in_channel=128)
+        self.latlayer1 = LatLayer(in_channel=64)
+
+        self.enlayer4 = EnLayer()
+        self.enlayer3 = EnLayer()
+        self.enlayer2 = EnLayer()
+        self.enlayer1 = EnLayer()
+
+        self.dslayer4 = DSLayer()
+        self.dslayer3 = DSLayer()
+        self.dslayer2 = DSLayer()
+        self.dslayer1 = DSLayer()
+
+    def _upsample_add(self, x, y):
+        [_, _, H, W] = y.size()
+        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
+        return x + y
+
+    def forward(self, weighted_x5, x4, x3, x2, x1, H, W):
+        preds = []
+        p5 = self.toplayer(weighted_x5)
+        p4 = self._upsample_add(p5, self.latlayer4(x4))
+        p4 = self.enlayer4(p4)
+        _pred = self.dslayer4(p4)
+        preds.append(
+            F.interpolate(_pred,
+                          size=(H, W),
+                          mode='bilinear', align_corners=False))
+
+        p3 = self._upsample_add(p4, self.latlayer3(x3))
+        p3 = self.enlayer3(p3)
+        _pred = self.dslayer3(p3)
+        preds.append(
+            F.interpolate(_pred,
+                          size=(H, W),
+                          mode='bilinear', align_corners=False))
+
+        p2 = self._upsample_add(p3, self.latlayer2(x2))
+        p2 = self.enlayer2(p2)
+        _pred = self.dslayer2(p2)
+        preds.append(
+            F.interpolate(_pred,
+                          size=(H, W),
+                          mode='bilinear', align_corners=False))
+
+        p1 = self._upsample_add(p2, self.latlayer1(x1))
+        p1 = self.enlayer1(p1)
+        _pred = self.dslayer1(p1)
+        preds.append(
+            F.interpolate(_pred,
+                          size=(H, W),
+                          mode='bilinear', align_corners=False))
+        return preds
+
+
+class DCFMNet(nn.Module):
+    """Class for extracting activations and
+    registering gradients from targeted intermediate layers."""
+    def __init__(self, mode='train'):
+        super(DCFMNet, self).__init__()
+        self.gradients = None
+        self.backbone = VGG_Backbone()
+        self.mode = mode
+        self.aug = AugAttentionModule()
+        self.fusion = AttLayer(512)
+        self.decoder = Decoder()
+
+    def set_mode(self, mode):
+        self.mode = mode
+
+    def forward(self, x, gt):
+        if self.mode == 'train':
+            preds = self._forward(x, gt)
+        else:
+            with torch.no_grad():
+                preds = self._forward(x, gt)
+        return preds
+
+    def featextract(self, x):
+        x1 = self.backbone.conv1(x)
+        x2 = self.backbone.conv2(x1)
+        x3 = self.backbone.conv3(x2)
+        x4 = self.backbone.conv4(x3)
+        x5 = self.backbone.conv5(x4)
+        return x5, x4, x3, x2, x1
+
+    def _forward(self, x, gt):
+        [B, _, H, W] = x.size()
+        x5, x4, x3, x2, x1 = self.featextract(x)
+        feat, proto, weighted_x5, cormap = self.fusion(x5)
+        feataug = self.aug(weighted_x5)
+        preds = self.decoder(feataug, x4, x3, x2, x1, H, W)
+        if self.training:
+            # Build positive/negative prototypes from GT-masked features for the SCL loss.
+            gt = F.interpolate(gt, size=weighted_x5.size()[2:], mode='bilinear', align_corners=False)
+            feat_pos, proto_pos, weighted_x5_pos, cormap_pos = self.fusion(x5 * gt)
+            feat_neg, proto_neg, weighted_x5_neg, cormap_neg = self.fusion(x5 * (1 - gt))
+            return preds, proto, proto_pos, proto_neg
+        return preds
+
+
+class DCFM(nn.Module):
+    def __init__(self, mode='train'):
+        super(DCFM, self).__init__()
+        set_seed(123)
+        self.dcfmnet = DCFMNet()
+        self.mode = mode
+
+    def set_mode(self, mode):
+        self.mode = mode
+        self.dcfmnet.set_mode(self.mode)
+
+    def forward(self, x, gt):
+        ########## Co-SOD ############
+        preds = self.dcfmnet(x, gt)
+        return preds
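
For orientation, below is a minimal, self-contained sketch of the seed-selection step in AttLayer.forward above, run on random CPU tensors. Shapes and values are illustrative stand-ins, not taken from the repo; the real layer additionally applies the learned query/key transforms and runs on CUDA.

import torch
import torch.nn.functional as F

B, C, H5, W5 = 4, 512, 7, 7
x5 = torch.randn(B, C, H5, W5)

# Stand-in attention weights: one softmax score per spatial location and image.
x_w = F.softmax(torch.randn(B, H5 * W5), dim=-1).unsqueeze(1)  # B, 1, HW

# Keep only each image's maximum-attention location as a one-hot mask.
x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w)
mask = torch.zeros_like(x_w)
mask[x_w == x_w_max] = 1
mask = mask.view(B, 1, H5, W5)

# Pool the normalized feature at that location into one seed vector per image.
norm0 = F.normalize(x5, dim=1)
seeds = (norm0 * mask).sum(3).sum(2).unsqueeze(2).unsqueeze(3)  # B, C, 1, 1

print(mask.sum(dim=(1, 2, 3)))  # one selected location per image (barring ties)
print(seeds.shape)              # torch.Size([4, 512, 1, 1])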
preprocessing.py
ADDED
@@ -0,0 +1,131 @@
+import glob
+import os
+import shutil
+import cv2
+import numpy as np
+from PIL import Image
+from pathlib import Path
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
+source = "/scratch/wej36how/Datasets/NWRD/train"
+dest = "/scratch/wej36how/Datasets/NWRDProcessed/train"
+patch_size = 224
+rust_threshold = 150
+max_number_of_images_per_group = 12
+
+patches_path = os.path.join(dest, "patches")
+images_dir = os.path.join(patches_path, "images")
+masks_dir = os.path.join(patches_path, "masks")
+
+destination = os.path.join(dest, "RustNonRustSplit")
+root = patches_path
+
+rust_images_dir = os.path.join(destination, "rust", "images")
+non_rust_images_dir = os.path.join(destination, "non_rust", "images")
+
+# "calssification" (sic) is kept as-is: it is the on-disk directory name that
+# the training notebook also references.
+rustClassificationDir = os.path.join(dest, "calssification", "rust")
+nonRustClassificationDir = os.path.join(dest, "calssification", "non_rust")
+os.makedirs(rustClassificationDir, exist_ok=True)
+os.makedirs(nonRustClassificationDir, exist_ok=True)
+
+shutil.copytree(rust_images_dir, rustClassificationDir, dirs_exist_ok=True)
+shutil.copytree(non_rust_images_dir, nonRustClassificationDir, dirs_exist_ok=True)
+
+
+def delete_extra_images(directory, target_count):
+    # Get a list of all image files in the directory
+    image_files = glob.glob(os.path.join(directory, '*.JPG')) + glob.glob(os.path.join(directory, '*.jpeg')) + glob.glob(os.path.join(directory, '*.png'))
+
+    # Check if the number of images exceeds the target count
+    if len(image_files) > target_count:
+        # Calculate the number of images to delete
+        num_to_delete = len(image_files) - target_count
+        # Sort the images by modification time (oldest first)
+        image_files.sort(key=os.path.getmtime)
+        # Delete the extra images
+        for i in range(num_to_delete):
+            os.remove(image_files[i])
+        print(f"{num_to_delete} images deleted.")
+    elif len(image_files) < target_count:
+        print("Warning: Number of images in directory is less than the target count.")
+
+
+# Equalize the two classification classes by trimming the larger one.
+if len(os.listdir(rustClassificationDir)) < len(os.listdir(nonRustClassificationDir)):
+    delete_extra_images(nonRustClassificationDir, len(os.listdir(rustClassificationDir)))
+else:
+    delete_extra_images(rustClassificationDir, len(os.listdir(nonRustClassificationDir)))
+
+rust_dir = os.path.join(destination, "rust")
+rustCosaliencynDir = os.path.join(dest, "cosaliency")
+shutil.copytree(rust_dir, rustCosaliencynDir, dirs_exist_ok=True)
+
+
+# Function to split images into folders based on image number
+def split_images_into_folders(source_dir, destination_dir):
+    # Create destination directory if it doesn't exist
+    if not os.path.exists(destination_dir):
+        os.makedirs(destination_dir)
+    # Iterate through files in the source directory
+    for filename in os.listdir(source_dir):
+        if filename.endswith('.png'):
+            image_no = filename.split('_')[0]  # Extract image number from filename
+            if not image_no.isdigit():
+                image_no = filename.split('_')[1]
+            destination_subdir = os.path.join(destination_dir, image_no)
+            # Create subdirectory if it doesn't exist
+            if not os.path.exists(destination_subdir):
+                os.makedirs(destination_subdir)
+            # Move the image file to the respective subdirectory
+            shutil.move(os.path.join(source_dir, filename), destination_subdir)
+
+
+def organize_images(main_directory):
+    # Ensure the main directory exists
+    if not os.path.exists(main_directory):
+        print(f"The specified main directory '{main_directory}' does not exist.")
+        return
+
+    # Get a list of subdirectories in the main directory
+    subdirectories = [d for d in os.listdir(main_directory) if os.path.isdir(os.path.join(main_directory, d))]
+
+    # Process each subdirectory
+    for subdir in subdirectories:
+        subdir_path = os.path.join(main_directory, subdir)
+
+        # Get a list of images in the subdirectory
+        images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
+        # Determine the number of images per subdirectory
+        images_per_subdir = 12
+        num_subdirectories = len(images) // images_per_subdir
+        n = 0
+        # Create additional subdirectories if needed
+        for i in range(num_subdirectories - 1):
+            new_subdir_name = f"{subdir}_part{i + 1}"
+            new_subdir_path = os.path.join(main_directory, new_subdir_name)
+
+            # Create the new subdirectory
+            os.makedirs(new_subdir_path)
+
+            # Move images to the new subdirectory
+            for j in range(images_per_subdir):
+                old_image_path = os.path.join(subdir_path, images[n])
+                new_image_path = os.path.join(new_subdir_path, images[n])
+                shutil.move(old_image_path, new_image_path)
+                n += 1
+
+
+source_directory = os.path.join(dest, "cosaliency", "images")
+destination_directory = os.path.join(dest, "cosaliency", "images")
+split_images_into_folders(source_directory, destination_directory)
+organize_images(destination_directory)
+
+source_directory = os.path.join(dest, "cosaliency", "masks")
+destination_directory = os.path.join(dest, "cosaliency", "masks")
+split_images_into_folders(source_directory, destination_directory)
+organize_images(destination_directory)
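
A quick sanity check of the grouping convention used by split_images_into_folders above: patches are routed into per-image group folders keyed by the numeric prefix of their filename. The filenames and temp directories below are made up for illustration.

import os
import shutil
import tempfile

src = tempfile.mkdtemp()
dst = tempfile.mkdtemp()
for name in ['12_0_0.png', '12_0_1.png', '7_3_2.png']:
    open(os.path.join(src, name), 'w').close()

# Same routing rule as split_images_into_folders: the numeric prefix
# before the first underscore names the per-image group folder.
for filename in os.listdir(src):
    if filename.endswith('.png'):
        image_no = filename.split('_')[0]
        subdir = os.path.join(dst, image_no)
        os.makedirs(subdir, exist_ok=True)
        shutil.move(os.path.join(src, filename), subdir)

print(sorted(os.listdir(dst)))                      # ['12', '7']
print(sorted(os.listdir(os.path.join(dst, '12'))))  # ['12_0_0.png', '12_0_1.png']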
requirements.txt
ADDED
@@ -0,0 +1,12 @@
+matplotlib==3.4.1
+numpy==1.19.2
+opencv_python==4.5.1.48
+pandas==1.2.4
+Pillow==9.1.0
+pytorch_toolbelt==0.4.3
+scikit_image==0.18.1
+skimage==0.0
+torch==1.7.1
+torchvision==0.2.2
+tqdm==4.60.0
+transformers
segmentation_metrics.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import cv2
+import numpy as np
+from sklearn.metrics import precision_score, recall_score, f1_score
+from sklearn.metrics import jaccard_score
+
+# Directories containing the prediction maps and ground truth masks
+dir1 = '/home/wej36how/codes/CoSOD-main/result/Predictions/NWRDFRust_concatenated'
+dir2 = '/home/wej36how/datasets/NWRDF/test/masks'
+
+# Initialize lists to store scores
+precisions = []
+recalls = []
+f1_scores = []
+iou_scores = []
+
+# Loop through all files in the prediction directory
+for filename in os.listdir(dir1):
+    pred_path = os.path.join(dir1, filename)
+    gt_path = os.path.join(dir2, filename)
+    print(pred_path)
+    # Ensure that the file exists in both directories
+    if os.path.exists(pred_path) and os.path.exists(gt_path):
+        # Load the prediction and ground truth images
+        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
+        gt_img = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
+
+        # Flatten the images to 1D arrays
+        pred_flat = pred_img.flatten()
+        gt_flat = gt_img.flatten()
+
+        # Binarize the images (assuming binary segmentation masks)
+        pred_flat = (pred_flat > 127).astype(np.uint8)
+        gt_flat = (gt_flat > 127).astype(np.uint8)
+
+        # Calculate precision, recall, F1 score, and IoU
+        precision = precision_score(gt_flat, pred_flat)
+        recall = recall_score(gt_flat, pred_flat)
+        f1 = f1_score(gt_flat, pred_flat)
+        iou = jaccard_score(gt_flat, pred_flat)
+
+        # Append the scores to the lists
+        precisions.append(precision)
+        recalls.append(recall)
+        f1_scores.append(f1)
+        iou_scores.append(iou)
+
+# Calculate average scores
+avg_precision = np.mean(precisions)
+avg_recall = np.mean(recalls)
+avg_f1_score = np.mean(f1_scores)
+avg_iou = np.mean(iou_scores)
+
+# Print the results
+print(f'Average Precision: {avg_precision:.4f}')
+print(f'Average Recall: {avg_recall:.4f}')
+print(f'Average F1 Score: {avg_f1_score:.4f}')
+print(f'Average IoU Score: {avg_iou:.4f}')
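
A tiny worked example of the metrics computed above, on a hand-made 2x2 prediction/ground-truth pair. With TP=1, FP=1, FN=1 we expect precision = recall = F1 = 0.5 and IoU = TP/(TP+FP+FN) = 1/3.

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

gt = np.array([[1, 1], [0, 0]]).flatten()    # ground truth: top row is rust
pred = np.array([[1, 0], [1, 0]]).flatten()  # prediction: left column is rust

print(precision_score(gt, pred))  # 0.5
print(recall_score(gt, pred))     # 0.5
print(f1_score(gt, pred))         # 0.5
print(jaccard_score(gt, pred))    # 0.3333...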
select_results.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import argparse
+import numpy as np
+import cv2
+
+from evaluator import Eval_thread
+from dataloader import EvalDataset
+
+import sys
+sys.path.append('..')
+
+
+def main(cfg):
+    dataset_names = cfg.datasets.split('+')
+    root_dir_predictions = [dr for dr in os.listdir('.') if 'gconet_' in dr]
+    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
+    print('root_dir_predictions:', root_dir_predictions)
+    root_dir_prediction = root_dir_predictions[0]
+    root_dir_good_ones = 'good_ones'
+    for dataset in dataset_names:
+        dir_prediction = os.path.join(root_dir_prediction, dataset)
+        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
+        dir_gt = os.path.join(cfg.gt_dir, dataset)
+        loader = EvalDataset(
+            dir_prediction,  # preds
+            dir_gt,          # GT
+            return_predpath=True,
+            return_gtpath=True
+        )
+        loader_comp = EvalDataset(
+            dir_prediction_comp,  # preds
+            dir_gt,               # GT
+            return_predpath=True
+        )
+        print('Selecting predictions from {}'.format(dir_prediction))
+        thread = Eval_thread(loader, cuda=cfg.cuda)
+        s_measure, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(bar=0.95, loader_comp=loader_comp, bar_comp=0.2)
+        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
+        os.makedirs(dir_good_ones, exist_ok=True)
+        print('have good_ones {}'.format(len(good_ones)))
+        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
+            dir_category = os.path.join(dir_good_ones, good_one.split('/')[-2])
+            os.makedirs(dir_category, exist_ok=True)
+            save_path = os.path.join(dir_category, good_one.split('/')[-1])
+            sal_map = cv2.imread(good_one)
+            sal_map_gt = cv2.imread(good_one_gt)
+            sal_map_comp = cv2.imread(good_one_comp)
+            image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg')
+            image = cv2.imread(image_path)
+            cv2.imwrite(save_path, sal_map)
+            # Compose image | GT | our prediction | competitor, separated by grey bars.
+            split_line = np.zeros((sal_map.shape[0], 10, 3)).astype(sal_map.dtype) + 127
+            comp = cv2.hconcat([image, split_line, sal_map_gt, split_line, sal_map, split_line, sal_map_comp])
+            save_path_comp = ''.join((save_path[:-4], '_comp', save_path[-4:]))
+            cv2.imwrite(save_path_comp, comp)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
+    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
+    parser.add_argument('--cuda', type=bool, default=True)
+    config = parser.parse_args()
+    main(config)
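
A minimal sketch of the side-by-side comparison image assembled above (image | GT | prediction | competitor, separated by 10-pixel grey bars). The arrays here are random stand-ins for the real saliency maps.

import numpy as np
import cv2

h, w = 64, 64
image, gt, pred, comp_pred = (np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
                              for _ in range(4))
split_line = np.zeros((h, 10, 3), dtype=np.uint8) + 127  # grey separator bar

comp = cv2.hconcat([image, split_line, gt, split_line, pred, split_line, comp_pred])
print(comp.shape)  # (64, 286, 3): four 64-px panels plus three 10-px bars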
sort_results.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import shutil
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+move_best_results_here = False
+
+record = ['dataset', 'ckpt', 'Emax', 'Smeasure', 'Fmax', 'MAE', 'Emean', 'Fmean']
+measurement = 'Emax'
+score_idx = record.index(measurement)
+
+with open('output/details/result.txt', 'r') as f:
+    res = f.read()
+
+res = res.replace('||', '').replace('(', '').replace(')', '')
+
+score = []
+for r in res.splitlines():
+    ds = r.split()
+    s = ds[:2]
+    for idx_d, d in enumerate(ds[2:]):
+        if idx_d % 2 == 0:
+            s.append(float(d))
+    score.append(s)
+
+ss = sorted(score, key=lambda x: (x[record.index('dataset')], x[record.index('Emax')], x[record.index('Smeasure')], x[record.index('Fmax')], x[record.index('ckpt')]), reverse=True)
+ss_ar = np.array(ss)
+np.savetxt('score_sorted.txt', ss_ar, fmt='%s')
+ckpt_coca = ss_ar[ss_ar[:, 0] == 'CoCA'][0][1]
+ckpt_cosod = ss_ar[ss_ar[:, 0] == 'CoSOD3k'][0][1]
+ckpt_cosal = ss_ar[ss_ar[:, 0] == 'CoSal2015'][0][1]
+
+best_coca_scores = ss_ar[ss_ar[:, 1] == ckpt_coca]
+best_cosod_scores = ss_ar[ss_ar[:, 1] == ckpt_cosod]
+best_cosal_scores = ss_ar[ss_ar[:, 1] == ckpt_cosal]
+print('Best (models may be different):')
+print('CoCA:\n', best_coca_scores)
+print('CoSOD3k:\n', best_cosod_scores)
+print('CoSal2015:\n', best_cosal_scores)
+
+# Overall relative Emax improvement on the three datasets
+if measurement == 'Emax':
+    gco_scores = {'CoCA': 0.760, 'CoSOD3k': 0.860, 'CoSal2015': 0.887}
+    gco_scores_Smeasure = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
+elif measurement == 'Smeasure':
+    gco_scores = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
+elif measurement == 'Fmax':
+    gco_scores = {'CoCA': 0.544, 'CoSOD3k': 0.777, 'CoSal2015': 0.847}
+elif measurement == 'Emean':
+    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
+elif measurement == 'Fmean':
+    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
+ckpts = list(set(ss_ar[:, 1].squeeze().tolist()))
+improvements_mean = []
+improvements_lst = []
+improvements_mean_Smeasure = []
+improvements_lst_Smeasure = []
+for ckpt in ckpts:
+    scores = ss_ar[ss_ar[:, 1] == ckpt]
+    if scores.shape[0] != len(gco_scores):
+        improvements_mean.append(-1)
+        improvements_lst.append([-1, -1, 1])
+        improvements_mean_Smeasure.append(-1)
+        improvements_lst_Smeasure.append([-1, -1, 1])
+        continue
+    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][score_idx])
+    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][score_idx])
+    score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][score_idx])
+    improvements = [
+        (score_coca - gco_scores['CoCA']) / gco_scores['CoCA'],
+        (score_cosod - gco_scores['CoSOD3k']) / gco_scores['CoSOD3k'],
+        (score_cosal - gco_scores['CoSal2015']) / gco_scores['CoSal2015']
+    ]
+    improvement_mean = np.mean(improvements)
+    improvements_mean.append(improvement_mean)
+    improvements_lst.append(improvements)
+
+    # Smeasure
+    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][record.index('Smeasure')])
+    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][record.index('Smeasure')])
+    score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][record.index('Smeasure')])
+    improvements_Smeasure = [
+        (score_coca - gco_scores_Smeasure['CoCA']) / gco_scores_Smeasure['CoCA'],
+        (score_cosod - gco_scores_Smeasure['CoSOD3k']) / gco_scores_Smeasure['CoSOD3k'],
+        (score_cosal - gco_scores_Smeasure['CoSal2015']) / gco_scores_Smeasure['CoSal2015']
+    ]
+    improvement_mean_Smeasure = np.mean(improvements_Smeasure)
+    improvements_mean_Smeasure.append(improvement_mean_Smeasure)
+    improvements_lst_Smeasure.append(improvements_Smeasure)
+best_measurement = 'Emax'
+if best_measurement == 'Emax':
+    best_improvement_index = np.argsort(improvements_mean).tolist()[-1]
+    best_ckpt = ckpts[best_improvement_index]
+    best_improvement_mean = improvements_mean[best_improvement_index]
+    best_improvements = improvements_lst[best_improvement_index]
+
+    best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index]
+    best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index]
+elif best_measurement == 'Smeasure':
+    best_improvement_index = np.argsort(improvements_mean_Smeasure).tolist()[-1]
+    best_ckpt = ckpts[best_improvement_index]
+    best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index]
+    best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index]
+
+    best_improvement_mean = improvements_mean[best_improvement_index]
+    best_improvements = improvements_lst[best_improvement_index]
+
+print('The overall best one:')
+print(ss_ar[ss_ar[:, 1] == best_ckpt])
+print('Got Emax improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format(
+    best_improvements[0]*100, best_improvements[1]*100, best_improvements[2]*100, best_improvement_mean*100
+))
+print('Got Smes improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format(
+    best_improvements_Smeasure[0]*100, best_improvements_Smeasure[1]*100, best_improvements_Smeasure[2]*100, best_improvement_mean_Smeasure*100
+))
+trial = int(best_ckpt.split('_')[-1].split('-')[0])
+ep = int(best_ckpt.split('ep')[-1].split(':')[0])
+if move_best_results_here:
+    trial, ep = 'gconet_{}'.format(trial), 'ep{}'.format(ep)
+    dr = os.path.join(trial, ep)
+    dst = '-'.join((trial, ep))
+    shutil.move(os.path.join('/root/datasets/sod/preds', dr), dst)
+
+
+# model_indices = sorted([fname.split('_')[-1] for fname in os.listdir('output/details') if 'gconet_' in fname])
+# emax = {}
+# for model_idx in model_indices:
+#     m = 'gconet_{}-'.format(model_idx)
+#     if m not in list(emax.keys()):
+#         emax[m] = []
+#     for s in score:
+#         if m in s[1]:
+#             ep = int(s[1].split('ep')[-1].rstrip('):'))
+#             emax[m].append([ep, s[2], s[0]])
+
+# for m, e in emax.items():
+#     plot_name = m[:-1]
+#     print('Saving {} ...'.format(plot_name))
+#     e = np.array(e)
+#     e_coca = e[e[:, -1] == 'CoCA']
+#     e_cosod = e[e[:, -1] == 'CoSOD3k']
+#     e_cosal = e[e[:, -1] == 'CoSal2015']
+#     eps = sorted(list(set(e_coca[:, 0].astype(float))))
+
+#     e_coca = np.array(sorted(e_coca, key=lambda x: int(x[0])))[:, 1].astype(float)
+#     e_cosod = np.array(sorted(e_cosod, key=lambda x: int(x[0])))[:, 1].astype(float)
+#     e_cosal = np.array(sorted(e_cosal, key=lambda x: int(x[0])))[:, 1].astype(float)
+
+#     plt.figure()
+#     plt.plot(eps, e_coca)
+#     plt.plot(eps, e_cosod)
+#     plt.plot(eps, e_cosal)
+#     plt.legend(['CoCA', 'CoSOD3k', 'CoSal2015'])
+#     plt.title(m)
+#     plt.savefig('{}.png'.format(plot_name))
test.py
ADDED
@@ -0,0 +1,91 @@
+# from PIL import Image
+from dataset import get_loader
+import torch
+from torchvision import transforms
+from util import save_tensor_img  # needed below; this import was commented out
+from tqdm import tqdm
+from torch import nn
+import os
+from models.main import *
+import argparse
+# import numpy as np
+# import cv2
+# from skimage import img_as_ubyte
+
+
+def main(args):
+    # Init model
+    device = torch.device("cuda")
+    model = DCFM()
+    model = model.to(device)
+    try:
+        # modelname = os.path.join(args.param_root, 'best_ep198_Smeasure0.7019.pth')
+        modelname = "/scratch/wej36how/codes/DCFM-master/best_ep12_Smeasure0.7256.pth"
+        dcfmnet_dict = torch.load(modelname)
+        print('loaded', modelname)
+    except:
+        dcfmnet_dict = torch.load(os.path.join(args.param_root, 'dcfm.pth'))
+
+    model.dcfmnet.load_state_dict(dcfmnet_dict)
+    model.eval()
+    model.set_mode('test')
+
+    tensor2pil = transforms.ToPILImage()
+    for testset in ['NWRD']:
+        if testset == 'CoCA':
+            test_img_path = './data/images/CoCA/'
+            test_gt_path = './data/gts/CoCA/'
+            saved_root = os.path.join(args.save_root, 'CoCA')
+        elif testset == 'CoSOD3k':
+            test_img_path = './data/images/CoSOD3k/'
+            test_gt_path = './data/gts/CoSOD3k/'
+            saved_root = os.path.join(args.save_root, 'CoSOD3k')
+        elif testset == 'CoSal2015':
+            test_img_path = './data/images/CoSal2015/'
+            test_gt_path = './data/gts/CoSal2015/'
+            saved_root = os.path.join(args.save_root, 'CoSal2015')
+        elif testset == 'NWRD':
+            test_img_path = '/home/wej36how/codes/crossvit/results/nwrd22/images/'
+            test_gt_path = '/home/wej36how/codes/crossvit/results/nwrd22/images/'
+            saved_root = os.path.join(args.save_root, 'NWRD')
+        else:
+            print('Unknown test dataset')
+            print(testset)
+
+        test_loader = get_loader(
+            test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)
+
+        for batch in tqdm(test_loader):
+            inputs = batch[0].to(device).squeeze(0)
+            gts = batch[1].to(device).squeeze(0)
+            subpaths = batch[2]
+            ori_sizes = batch[3]
+            scaled_preds = model(inputs, gts)
+            scaled_preds = torch.sigmoid(scaled_preds[-1])
+            os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
+            num = gts.shape[0]
+            for inum in range(num):
+                subpath = subpaths[inum][0]
+                ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
+                # Resize the prediction back to the original image size before saving.
+                res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear', align_corners=True)
+                save_tensor_img(res, os.path.join(saved_root, subpath))
+
+
+if __name__ == '__main__':
+    # Parameters from the command line
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--size',
+                        default=224,
+                        type=int,
+                        help='input size')
+    parser.add_argument('--param_root', default='/data1/dcfm/temp', type=str, help='model folder')
+    parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')
+
+    args = parser.parse_args()
+
+    main(args)
train.py
ADDED
@@ -0,0 +1,289 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from util import Logger, AverageMeter, save_checkpoint, save_tensor_img, set_seed
+import os
+import numpy as np
+from matplotlib import pyplot as plt
+import time
+import argparse
+from tqdm import tqdm
+from dataset import get_loader
+from loss import *
+from config import Config
+from evaluation.dataloader import EvalDataset
+from evaluation.evaluator import Eval_thread
+
+from models.main import *
+
+import torch.nn.functional as F
+import pytorch_toolbelt.losses as PTL
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# Parameters from the command line
+parser = argparse.ArgumentParser(description='')
+
+parser.add_argument('--loss',
+                    default='Scale_IoU',
+                    type=str,
+                    help="Options: '', ''")
+parser.add_argument('--bs', '--batch_size', default=1, type=int)
+parser.add_argument('--lr',
+                    '--learning_rate',
+                    default=1e-4,
+                    type=float,
+                    help='Initial learning rate')
+parser.add_argument('--resume',
+                    default=None,
+                    type=str,
+                    help='path to latest checkpoint')
+parser.add_argument('--epochs', default=200, type=int)
+parser.add_argument('--start_epoch',
+                    default=0,
+                    type=int,
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('--trainset',
+                    default='CoCo',
+                    type=str,
+                    help="Options: 'CoCo'")
+parser.add_argument('--testsets',
+                    default='CoCA',
+                    type=str,
+                    help="Options: 'CoCA','CoSal2015','CoSOD3k','iCoseg','MSRC'")
+parser.add_argument('--size',
+                    default=224,
+                    type=int,
+                    help='input size')
+parser.add_argument('--tmp', default='/data1/dcfm/temp', help='Temporary folder')
+parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')
+
+args = parser.parse_args()
+config = Config()
+
+# Prepare dataset
+if args.trainset == 'CoCo':
+    train_img_path = './data/CoCo/img/'
+    train_gt_path = './data/CoCo/gt/'
+    train_loader = get_loader(train_img_path,
+                              train_gt_path,
+                              args.size,
+                              args.bs,
+                              max_num=16,  # 20,
+                              istrain=True,
+                              shuffle=False,
+                              num_workers=8,  # 4,
+                              pin=True)
+else:
+    print('Unknown train dataset')
+    print(args.trainset)
+
+for testset in ['CoCA']:
+    if testset == 'CoCA':
+        test_img_path = './data/images/CoCA/'
+        test_gt_path = './data/gts/CoCA/'
+        saved_root = os.path.join(args.save_root, 'CoCA')
+    elif testset == 'CoSOD3k':
+        test_img_path = './data/images/CoSOD3k/'
+        test_gt_path = './data/gts/CoSOD3k/'
+        saved_root = os.path.join(args.save_root, 'CoSOD3k')
+    elif testset == 'CoSal2015':
+        test_img_path = './data/images/CoSal2015/'
+        test_gt_path = './data/gts/CoSal2015/'
+        saved_root = os.path.join(args.save_root, 'CoSal2015')
+    elif testset == 'CoCo':
+        test_img_path = './data/images/CoCo/'
+        test_gt_path = './data/gts/CoCo/'
+        saved_root = os.path.join(args.save_root, 'CoCo')
+    else:
+        print('Unknown test dataset')
+        print(testset)
+
+    test_loader = get_loader(
+        test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)
+
+# make dir for tmp
+os.makedirs(args.tmp, exist_ok=True)
+
+# Init log file
+logger = Logger(os.path.join(args.tmp, "log.txt"))
+set_seed(123)
+
+# Init model
+device = torch.device("cuda")
+
+model = DCFM()
+model = model.to(device)
+model.apply(weights_init)
+
+model.dcfmnet.backbone._initialize_weights(torch.load('./models/vgg16-397923af.pth'))
+
+backbone_params = list(map(id, model.dcfmnet.backbone.parameters()))
+base_params = filter(lambda p: id(p) not in backbone_params,
+                     model.dcfmnet.parameters())
+
+# The backbone is trained with a 10x smaller learning rate than the rest.
+all_params = [{'params': base_params}, {'params': model.dcfmnet.backbone.parameters(), 'lr': args.lr * 0.1}]
+
+# Setting optimizer
+optimizer = optim.Adam(params=all_params, lr=args.lr, weight_decay=1e-4, betas=[0.9, 0.99])
+
+# Freeze the backbone except its last conv block (conv5_3).
+for key, value in model.named_parameters():
+    if 'dcfmnet.backbone' in key and 'dcfmnet.backbone.conv5.conv5_3' not in key:
+        value.requires_grad = False
+
+for key, value in model.named_parameters():
+    print(key, value.requires_grad)
+
+# log model and optimizer parameters
+logger.info("Model details:")
+logger.info(model)
+logger.info("Optimizer details:")
+logger.info(optimizer)
+logger.info("Scheduler details:")
+# logger.info(scheduler)
+logger.info("Other hyperparameters:")
+logger.info(args)
+
+# Setting Loss
+exec('from loss import ' + args.loss)
+IOUloss = eval(args.loss + '()')
+
+
+def main():
+    val_measures = []
+    # Optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            logger.info("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            model.dcfmnet.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
+                args.resume, checkpoint['epoch']))
+        else:
+            logger.info("=> no checkpoint found at '{}'".format(args.resume))
+
+    print(args.epochs)
+    for epoch in range(args.start_epoch, args.epochs):
+        train_loss = train(epoch)
+        if config.validation:
+            measures = validate(model, test_loader, args.testsets)
+            val_measures.append(measures)
+            print(
+                'Validation: S_measure on CoCA for epoch-{} is {:.4f}. Best epoch is epoch-{} with S_measure {:.4f}'.format(
+                    epoch, measures[0], np.argmax(np.array(val_measures)[:, 0].squeeze()),
+                    np.max(np.array(val_measures)[:, 0]))
+            )
+        # Save checkpoint
+        save_checkpoint(
+            {
+                'epoch': epoch + 1,
+                'state_dict': model.dcfmnet.state_dict(),
+                # 'scheduler': scheduler.state_dict(),
+            },
+            path=args.tmp)
+        if config.validation:
+            if np.max(np.array(val_measures)[:, 0].squeeze()) == measures[0]:
+                # Keep only the single best checkpoint by S-measure.
+                best_weights_before = [os.path.join(args.tmp, weight_file) for weight_file in
+                                       os.listdir(args.tmp) if 'best_' in weight_file]
+                for best_weight_before in best_weights_before:
+                    os.remove(best_weight_before)
+                torch.save(model.dcfmnet.state_dict(),
+                           os.path.join(args.tmp, 'best_ep{}_Smeasure{:.4f}.pth'.format(epoch, measures[0])))
+        if (epoch + 1) % 10 == 0 or epoch == 0:
+            torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')
+
+        if epoch > 188:
+            torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')
+    # dcfmnet_dict = model.dcfmnet.state_dict()
+    # torch.save(dcfmnet_dict, os.path.join(args.tmp, 'final.pth'))
+
+
+def sclloss(x, xt, xb):
+    # Pull the prototype toward the GT prototype and away from the background one.
+    cosc = (1 + compute_cos_dis(x, xt)) * 0.5
+    cosb = (1 + compute_cos_dis(x, xb)) * 0.5
+    loss = -torch.log(cosc + 1e-5) - torch.log(1 - cosb + 1e-5)
+    return loss.sum()
+
+
+def train(epoch):
+    # Switch to train mode
+    model.train()
+    model.set_mode('train')
+    loss_sum = 0.0
+    loss_sumkl = 0.0
+    for batch_idx, batch in enumerate(train_loader):
+        inputs = batch[0].to(device).squeeze(0)
+        gts = batch[1].to(device).squeeze(0)
+        pred, proto, protogt, protobg = model(inputs, gts)
+        loss_iou = IOUloss(pred, gts)
+        loss_scl = sclloss(proto, protogt, protobg)
+        loss = loss_iou + 0.1 * loss_scl
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        loss_sum = loss_sum + loss_iou.detach().item()
+
+        if batch_idx % 20 == 0:
+            logger.info('Epoch[{0}/{1}] Iter[{2}/{3}] '
+                        'Train Loss: loss_iou: {4:.3f}, loss_scl: {5:.3f} '.format(
+                            epoch,
+                            args.epochs,
+                            batch_idx,
+                            len(train_loader),
+                            loss_iou,
+                            loss_scl,
+                        ))
+    loss_mean = loss_sum / len(train_loader)
+    return loss_sum
+
+
+def validate(model, test_loaders, testsets):
+    model.eval()
+
+    testsets = testsets.split('+')
+    measures = []
+    for testset in testsets[:1]:
+        print('Validating {}...'.format(testset))
+        # test_loader = test_loaders[testset]
+
+        saved_root = os.path.join(args.save_root, testset)
+
+        for batch in test_loader:
+            inputs = batch[0].to(device).squeeze(0)
+            gts = batch[1].to(device).squeeze(0)
+            subpaths = batch[2]
+            ori_sizes = batch[3]
+            with torch.no_grad():
+                scaled_preds = model(inputs, gts)[-1].sigmoid()
+
+            os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
+
+            num = len(scaled_preds)
+            for inum in range(num):
+                subpath = subpaths[inum][0]
+                ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
+                res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear',
+                                                align_corners=True)
+                save_tensor_img(res, os.path.join(saved_root, subpath))
+
+        eval_loader = EvalDataset(
+            saved_root,  # preds
+            os.path.join('./data/gts', testset)  # GT
+        )
+        evaler = Eval_thread(eval_loader, cuda=True)
+        # Use S_measure for validation
+        s_measure = evaler.Eval_Smeasure()
+        if s_measure > config.val_measures['Smeasure']['CoCA'] and 0:
+            # TODO: evaluate other measures if s_measure is very high.
+            e_max = evaler.Eval_Emeasure().max().item()
+            f_max = evaler.Eval_fmeasure().max().item()
+            print('Emax: {:.4f}, Fmax: {:.4f}'.format(e_max, f_max))
+        measures.append(s_measure)
+
+    model.train()
+    return measures
+
+
+if __name__ == '__main__':
+    main()
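
A small numerical check of sclloss above. compute_cos_dis comes from loss.py and is not shown here, so plain cosine similarity stands in for it (an assumption); the check only illustrates that the loss is near zero when the prototype aligns with the GT prototype and opposes the background one, and large in the reversed configuration.

import torch
import torch.nn.functional as F

def sclloss_demo(x, xt, xb):
    # Cosine terms rescaled to [0, 1], mirroring the sclloss formula above.
    cosc = (1 + F.cosine_similarity(x.flatten(1), xt.flatten(1))) * 0.5
    cosb = (1 + F.cosine_similarity(x.flatten(1), xb.flatten(1))) * 0.5
    return (-torch.log(cosc + 1e-5) - torch.log(1 - cosb + 1e-5)).sum()

proto = torch.randn(1, 512, 1, 1)
print(sclloss_demo(proto, proto, -proto))  # near 0: aligned with GT, opposed to BG
print(sclloss_demo(proto, -proto, proto))  # large (~23): the bad configuration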
train_wandb.ipynb
ADDED
@@ -0,0 +1,459 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"/home/wej36how/.conda/envs/vit/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 14 |
+
]
|
| 15 |
+
}
|
| 16 |
+
],
|
| 17 |
+
"source": [
|
| 18 |
+
"import torch\n",
|
| 19 |
+
"from torch.utils.data import DataLoader\n",
|
| 20 |
+
"from transformers import AdamW, ViTImageProcessor, ViTForImageClassification\n",
|
| 21 |
+
"from NWRD_dataset import NWRD\n",
|
| 22 |
+
"from tqdm import tqdm\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"import torch.nn.functional as F\n",
|
| 25 |
+
"import os\n",
|
| 26 |
+
"import torch.optim as optim\n",
|
| 27 |
+
"from torchvision import transforms\n"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": 2,
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": [
|
| 36 |
+
"seed = 42\n",
|
| 37 |
+
"torch.manual_seed(seed)\n",
|
| 38 |
+
"np.random.seed(seed)\n",
|
| 39 |
+
"# If you are using CUDA, set this for further deterministic behavior\n",
|
| 40 |
+
"if torch.cuda.is_available():\n",
|
| 41 |
+
" torch.cuda.manual_seed(seed)\n",
|
| 42 |
+
" torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.\n",
|
| 43 |
+
" # Below settings are recommended for deterministic behavior when using specific convolution operations,\n",
|
| 44 |
+
" # but may reduce performance\n",
|
| 45 |
+
" torch.backends.cudnn.deterministic = True\n",
|
| 46 |
+
" torch.backends.cudnn.benchmark = False"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 3,
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [
|
| 54 |
+
{
|
| 55 |
+
"name": "stdout",
|
| 56 |
+
"output_type": "stream",
|
| 57 |
+
"text": [
|
| 58 |
+
"cpu\n"
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
],
|
| 62 |
+
"source": [
|
| 63 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
| 64 |
+
"CUDA_LAUNCH_BLOCKING=1\n",
|
| 65 |
+
"TORCH_USE_CUDA_DSA=1\n",
|
| 66 |
+
"print(device)"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": 4,
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"outputs": [],
|
| 74 |
+
"source": [
|
| 75 |
+
"transformations = transforms.Compose([\n",
|
| 76 |
+
" transforms.Resize((224, 224)), # Resize the image to 224x224\n",
|
| 77 |
+
" transforms.ToTensor() # Convert the image to a PyTorch tensor\n",
|
| 78 |
+
"])"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": 5,
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [
|
| 86 |
+
{
|
| 87 |
+
"ename": "FileNotFoundError",
|
| 88 |
+
"evalue": "[Errno 2] No such file or directory: 'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification/rust'",
|
| 89 |
+
"output_type": "error",
|
| 90 |
+
"traceback": [
|
| 91 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 92 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
| 93 |
+
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_ds \u001b[38;5;241m=\u001b[39m \u001b[43mNWRD\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mC:\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mUsers\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mhasee\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mDesktop\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mGermany_2024\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mDataset\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mNWRDprocessed\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;130;43;01m\\\\\u001b[39;49;00m\u001b[38;5;124;43mcalssification\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformations\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m val_ds \u001b[38;5;241m=\u001b[39m NWRD(root_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC:\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mUsers\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mhasee\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mDesktop\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mGermany_2024\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mNWRDprocessed\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mval\u001b[39m\u001b[38;5;130;01m\\\\\u001b[39;00m\u001b[38;5;124mcalssification\u001b[39m\u001b[38;5;124m\"\u001b[39m, train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, transform\u001b[38;5;241m=\u001b[39mtransformations)\n\u001b[1;32m 4\u001b[0m train_loader \u001b[38;5;241m=\u001b[39m DataLoader(train_ds, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m8\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
| 94 |
+
"File \u001b[0;32m~/codes/crossvit/NWRD_dataset.py:12\u001b[0m, in \u001b[0;36mNWRD.__init__\u001b[0;34m(self, root_dir, transform, train)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimages \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabels \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 95 |
+
"File \u001b[0;32m~/codes/crossvit/NWRD_dataset.py:19\u001b[0m, in \u001b[0;36mNWRD.load_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 16\u001b[0m non_rust_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mroot_dir, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon_rust\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# Load rust images\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m filename \u001b[38;5;129;01min\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlistdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrust_dir\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 20\u001b[0m filepath \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(rust_dir, filename)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimages\u001b[38;5;241m.\u001b[39mappend(filepath)\n",
|
| 96 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification/rust'"
|
| 97 |
+
]
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"source": [
|
| 101 |
+
"train_ds = NWRD(root_dir=\"C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\train\\\\calssification\", train=True, transform=transformations)\n",
|
| 102 |
+
"val_ds = NWRD(root_dir=\"C:\\\\Users\\\\hasee\\\\Desktop\\\\Germany_2024\\\\Dataset\\\\NWRDprocessed\\\\val\\\\calssification\", train=False, transform=transformations)\n",
|
| 103 |
+
" \n",
|
| 104 |
+
"train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)\n",
|
| 105 |
+
"val_loader = DataLoader(val_ds, batch_size=8, shuffle=True)"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"execution_count": 6,
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"mean = [0.485, 0.456, 0.406] # Mean values for RGB channels\n",
|
| 115 |
+
"std = [0.229, 0.224, 0.225] # Standard deviation values for RGB channels\n",
|
| 116 |
+
"#processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224',transform={'mean': mean, 'std': std})\n",
|
| 117 |
+
"processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')\n",
|
| 118 |
+
"model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')\n",
|
| 119 |
+
"# processor.image_mean=mean\n",
|
| 120 |
+
"# processor.image_std=std\n",
|
| 121 |
+
"#print(processor)"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ViTForImageClassification(\n",
"  (vit): ViTModel(\n",
"    (embeddings): ViTEmbeddings(\n",
"      (patch_embeddings): ViTPatchEmbeddings(\n",
"        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
"      )\n",
"      (dropout): Dropout(p=0.0, inplace=False)\n",
"    )\n",
"    (encoder): ViTEncoder(\n",
"      (layer): ModuleList(\n",
"        (0-11): 12 x ViTLayer(\n",
"          (attention): ViTSdpaAttention(\n",
"            (attention): ViTSdpaSelfAttention(\n",
"              (query): Linear(in_features=768, out_features=768, bias=True)\n",
"              (key): Linear(in_features=768, out_features=768, bias=True)\n",
"              (value): Linear(in_features=768, out_features=768, bias=True)\n",
"              (dropout): Dropout(p=0.0, inplace=False)\n",
"            )\n",
"            (output): ViTSelfOutput(\n",
"              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
"              (dropout): Dropout(p=0.0, inplace=False)\n",
"            )\n",
"          )\n",
"          (intermediate): ViTIntermediate(\n",
"            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
"            (intermediate_act_fn): GELUActivation()\n",
"          )\n",
"          (output): ViTOutput(\n",
"            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
"            (dropout): Dropout(p=0.0, inplace=False)\n",
"          )\n",
"          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"        )\n",
"      )\n",
"    )\n",
"    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
"  )\n",
"  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
")"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.classifier = torch.nn.Linear(model.config.hidden_size, 2)  # replace the 1000-class head with a binary rust/non-rust head\n",
"model.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fine-tune the model starting from the pretrained weights."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# model_weights = torch.load('/home/Hirra/coding_files/crossvit/weights/wandb_vit_base_final_med_val_NWRD_epoch_50_lr_0.000000001_wd_0.001_batch_size_8_unaugmented_unequlaized/49.pth')\n",
"# model.load_state_dict(model_weights.state_dict())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.SGD(model.parameters(), lr=0.00000003, weight_decay=0.001)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"weights_directory = 'wandb_vit_base_final_for_time_NWRD_epoch_50_lr_0.000000003_wd_0.001_batch_size_8_unaugmented_training'\n",
"weight_loc = f\"weights/{weights_directory}\"\n",
"os.makedirs(weight_loc, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mgptautomated\u001b[0m (\u001b[33mtukl_labwork\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: C:\\Users\\hasee\\.netrc\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import wandb, os\n",
"#wandb.login()\n",
"wandb.login(key=\"4e8a21c26ae61cced8d70053c80bbe1b112fec12\")\n",
"#4e8a21c26ae61cced8d70053c80bbe1b112fec12"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: WANDB_PROJECT=crossvit_rust_classifier_new\n"
]
}
],
"source": [
"%env WANDB_PROJECT=crossvit_rust_classifier_new\n",
"os.environ[\"WANDB_PROJECT\"] = \"<crossvit>\"\n",
"os.environ[\"WANDB_REPORT_TO\"] = \"wandb\""
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"wandb version 0.17.3 is available! To upgrade, please run:\n",
" $ pip install wandb --upgrade"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.17.2"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>c:\\Users\\hasee\\Desktop\\Germany_2024\\codes\\crossvit\\wandb\\run-20240626_161631-bgtm3oyt</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt' target=\"_blank\">glamorous-wood-74</a></strong> to <a href='https://wandb.ai/tukl_labwork/uncategorized' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/tukl_labwork/uncategorized' target=\"_blank\">https://wandb.ai/tukl_labwork/uncategorized</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt' target=\"_blank\">https://wandb.ai/tukl_labwork/uncategorized/runs/bgtm3oyt</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"  0%|          | 0/241 [00:00<?, ?it/s]c:\\Users\\hasee\\miniconda3\\envs\\segformer\\Lib\\site-packages\\transformers\\models\\vit\\modeling_vit.py:253: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:455.)\n",
"  context_layer = torch.nn.functional.scaled_dot_product_attention(\n",
"Epoch 0 train Loss 0.6551:  21%|██        | 51/241 [00:27<01:42,  1.85it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
"Cell \u001b[1;32mIn[12], line 22\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[38;5;66;03m# print(\"logits\", logits)\u001b[39;00m\n\u001b[0;32m 18\u001b[0m \u001b[38;5;66;03m# print(\"prediction\", predication)\u001b[39;00m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;66;03m# print(\"labels\", labels)\u001b[39;00m\n\u001b[0;32m 21\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(logits, labels)\n\u001b[1;32m---> 22\u001b[0m train_losses\u001b[38;5;241m.\u001b[39mappend(\u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 23\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m 24\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"wandb.init()\n",
"\n",
"best_epoch = {}\n",
"train_losses = []\n",
"for epoch in range(50):\n",
"    model.train()  # enable training mode (dropout etc.)\n",
"    train_losses = []\n",
"    loop = tqdm(enumerate(train_loader), total=len(train_loader))\n",
"    for batch_idx, (images, labels) in loop:\n",
"        inputs = processor(images=images, return_tensors=\"pt\", do_rescale=False).to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        outputs = model(**inputs)\n",
"        logits = outputs.logits\n",
"        prediction = logits.argmax(dim=1)\n",
"\n",
"        # print(\"logits\", logits)\n",
"        # print(\"prediction\", prediction)\n",
"        # print(\"labels\", labels)\n",
"\n",
"        loss = criterion(logits, labels)\n",
"        train_losses.append(loss.item())\n",
"        optimizer.zero_grad()  # clear gradients from the previous batch\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        loop.set_description(f\"Epoch {epoch} train Loss {np.mean(train_losses):.4f}\")\n",
"\n",
"    print(\"Epoch \" + str(epoch) + \" Train Loss \" + str(np.mean(train_losses)))\n",
"    torch.save(model, weight_loc + '/{}.pth'.format(epoch))\n",
"    wandb.log({\"train_loss\": np.mean(train_losses), \"epoch\": epoch})\n",
"\n",
"    # validation\n",
"    model.eval()  # disable dropout for evaluation\n",
"    val_losses = []\n",
"\n",
"    loop = tqdm(enumerate(val_loader), total=len(val_loader))\n",
"    with torch.no_grad():\n",
"        for batch_idx, (images, labels) in loop:\n",
"            inputs = processor(images=images, return_tensors=\"pt\", do_rescale=False).to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(**inputs)\n",
"            logits = outputs.logits\n",
"            prediction = logits.argmax(dim=1)\n",
"\n",
"            loss = criterion(logits, labels)\n",
"            val_losses.append(loss.item())\n",
"\n",
"            loop.set_description(f\"Epoch {epoch} Val Loss {np.mean(val_losses):.4f}\")\n",
"    wandb.log({\"val_loss\": np.mean(val_losses), \"epoch\": epoch})\n",
"torch.cuda.empty_cache()\n"
]
}
],
"metadata": {
|
| 439 |
+
"kernelspec": {
|
| 440 |
+
"display_name": "crossvit",
|
| 441 |
+
"language": "python",
|
| 442 |
+
"name": "python3"
|
| 443 |
+
},
|
| 444 |
+
"language_info": {
|
| 445 |
+
"codemirror_mode": {
|
| 446 |
+
"name": "ipython",
|
| 447 |
+
"version": 3
|
| 448 |
+
},
|
| 449 |
+
"file_extension": ".py",
|
| 450 |
+
"mimetype": "text/x-python",
|
| 451 |
+
"name": "python",
|
| 452 |
+
"nbconvert_exporter": "python",
|
| 453 |
+
"pygments_lexer": "ipython3",
|
| 454 |
+
"version": "3.8.19"
|
| 455 |
+
}
|
| 456 |
+
},
|
| 457 |
+
"nbformat": 4,
|
| 458 |
+
"nbformat_minor": 2
|
| 459 |
+
}
|
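
The notebook above checkpoints the full model every epoch but only logs losses. As a complement, here is a minimal evaluation sketch that reports validation accuracy; it reuses the `val_loader`, `processor`, and `device` objects defined in the notebook, and the checkpoint path `weights/<run_name>/49.pth` is a hypothetical placeholder.

```python
# Evaluation sketch (assumes val_loader, processor and device from the notebook;
# the checkpoint path below is a placeholder, not a real file in this repo).
import torch

model = torch.load("weights/<run_name>/49.pth", map_location=device)  # full module was saved
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        inputs = processor(images=images, return_tensors="pt", do_rescale=False).to(device)
        labels = labels.to(device)
        logits = model(**inputs).logits
        preds = logits.argmax(dim=1)  # 0 = non-rust, 1 = rust
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"val accuracy: {correct / total:.4f}")
```
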
util.cpython-38.pyc
ADDED
Binary file (3.5 kB).

util.py
ADDED
@@ -0,0 +1,92 @@
import logging
import os
import torch
import shutil
from torchvision import transforms
import numpy as np
import random
import cv2


class Logger():
    def __init__(self, path="log.txt"):
        self.logger = logging.getLogger('DCFM')
        self.file_handler = logging.FileHandler(path, "w")
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

    def info(self, txt):
        self.logger.info(txt)

    def close(self):
        self.file_handler.close()
        self.stdout_handler.close()


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(state, path, filename="checkpoint.pth"):
    torch.save(state, os.path.join(path, filename))


def save_tensor_img(tensor_im, path):
    im = tensor_im.cpu().clone()
    im = im.squeeze(0)
    tensor2pil = transforms.ToPILImage()
    im = tensor2pil(im)
    im.save(path)


def save_tensor_merge(tensor_im, tensor_mask, path, colormap='HOT'):
    im = tensor_im.cpu().detach().clone()
    im = im.squeeze(0).numpy()
    im = ((im - np.min(im)) / (np.max(im) - np.min(im) + 1e-20)) * 255
    im = np.array(im, np.uint8)
    mask = tensor_mask.cpu().detach().clone()
    mask = mask.squeeze(0).numpy()
    mask = ((mask - np.min(mask)) / (np.max(mask) - np.min(mask) + 1e-20)) * 255
    mask = np.clip(mask, 0, 255)
    mask = np.array(mask, np.uint8)
    if colormap == 'HOT':
        mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_HOT)
    elif colormap == 'PINK':
        mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_PINK)
    elif colormap == 'BONE':
        mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_BONE)
    # exec('cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_' + colormap + ')')
    im = im.transpose((1, 2, 0))
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    mix = cv2.addWeighted(im, 0.3, mask, 0.7, 0)
    cv2.imwrite(path, mix)


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
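
A short usage sketch for these helpers (the log-file name, batch size, and loss values are illustrative): `set_seed` pins all RNGs before anything random happens, `Logger` tees messages to stdout and a file, and `AverageMeter` keeps a running mean, e.g. of a per-batch loss.

```python
# Illustrative use of util.py; the numbers here are made up.
from util import Logger, AverageMeter, set_seed

set_seed(42)                        # fix torch / numpy / random seeds for reproducibility
logger = Logger(path="train_log.txt")
loss_meter = AverageMeter()

for step, batch_loss in enumerate([0.71, 0.64, 0.58]):  # stand-ins for real batch losses
    loss_meter.update(batch_loss, n=8)                  # n = number of samples in the batch
    logger.info(f"step {step}: loss {loss_meter.val:.3f} (running avg {loss_meter.avg:.3f})")

logger.close()
```
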
utils.ipynb
ADDED
The diff for this file is too large to render.

utils.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np


def adjust_rate_poly(cur_iter, max_iter, power=0.9):
    return (1.0 - 1.0 * cur_iter / max_iter) ** power


def adjust_learning_rate_exp(lr, optimizer, iters, decay_rate=0.1, decay_step=25):
    lr = lr * (decay_rate ** (iters // decay_step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']


def adjust_learning_rate_RevGrad(lr, optimizer, max_iter, cur_iter,
                                 alpha=10, beta=0.75):
    p = 1.0 * cur_iter / (max_iter - 1)
    lr = lr / pow(1.0 + alpha * p, beta)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']


def adjust_learning_rate_inv(lr, optimizer, iters, alpha=0.001, beta=0.75):
    lr = lr / pow(1.0 + alpha * iters, beta)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']


def adjust_learning_rate_step(lr, optimizer, iters, steps, beta=0.1):
    n = 0
    for step in steps:
        if iters < step:
            break
        n += 1

    lr = lr * (beta ** n)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']


def adjust_learning_rate_poly(lr, optimizer, iters, max_iter, power=0.9):
    lr = lr * (1.0 - 1.0 * iters / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']


def set_param_groups(net, lr_mult_dict={}):
    params = []
    if hasattr(net, "module"):
        net = net.module

    modules = net._modules
    for name in modules:
        module = modules[name]
        if name in lr_mult_dict:
            params += [{'params': module.parameters(),
                        'lr_mult': lr_mult_dict[name]}]
        else:
            params += [{'params': module.parameters(), 'lr_mult': 1.0}]

    return params


def LSR(x, dim=1, thres=10.0):
    lsr = -1.0 * torch.mean(x, dim=dim)
    if thres > 0.0:
        return torch.mean((lsr / thres - 1.0).detach() * lsr)
    else:
        return torch.mean(lsr)


def crop(feats, preds, gt, h, w):
    H, W = feats.shape[-2:]
    tmp_feats = []
    tmp_preds = []
    tmp_gt = []
    N = feats.size(0)
    for i in range(N):
        inds_H = torch.randperm(H)[0:h]
        inds_W = torch.randperm(W)[0:w]
        tmp_feats += [feats[i, :, inds_H[:, None], inds_W]]
        tmp_preds += [preds[i, :, inds_H[:, None], inds_W]]
        tmp_gt += [gt[i, inds_H[:, None], inds_W]]

    new_feats = torch.stack(tmp_feats, dim=0)
    new_gt = torch.stack(tmp_gt, dim=0)
    new_preds = torch.stack(tmp_preds, dim=0)
    probs = F.softmax(new_preds, dim=1)
    return new_feats, probs, new_gt
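
One non-obvious detail: every `adjust_learning_rate_*` helper scales the scheduled rate by `param_group['lr_mult']`, so the optimizer must be built from `set_param_groups`; an optimizer created from plain `model.parameters()` would raise a `KeyError`. A minimal sketch of that pairing (the toy model and hyperparameters are illustrative):

```python
# Pairing set_param_groups with poly LR decay; model and values are illustrative.
import torch
import torch.nn as nn
from utils import set_param_groups, adjust_learning_rate_poly

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 2, 1))
params = set_param_groups(net, lr_mult_dict={'0': 0.1})   # first submodule learns 10x slower
optimizer = torch.optim.SGD(params, lr=1e-2, momentum=0.9)

max_iter = 100
for it in range(max_iter):
    adjust_learning_rate_poly(1e-2, optimizer, it, max_iter, power=0.9)
    # ... forward pass, loss.backward(), optimizer.step() ...
```
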
vgg.cpython-37.pyc
ADDED
Binary file (3.75 kB).

vgg.cpython-38.pyc
ADDED
Binary file (3.78 kB).

vgg.py
ADDED
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import os


class VGG_Backbone(nn.Module):
    # VGG16 with two branches
    # pooling layer at the front of block
    def __init__(self):
        super(VGG_Backbone, self).__init__()
        conv1 = nn.Sequential()
        conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
        conv1.add_module('relu1_1', nn.ReLU(inplace=True))
        conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
        conv1.add_module('relu1_2', nn.ReLU(inplace=True))
        self.conv1 = conv1

        conv2 = nn.Sequential()
        conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
        conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
        conv2.add_module('relu2_1', nn.ReLU())
        conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
        conv2.add_module('relu2_2', nn.ReLU())
        self.conv2 = conv2

        conv3 = nn.Sequential()
        conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
        conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
        conv3.add_module('relu3_1', nn.ReLU())
        conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_2', nn.ReLU())
        conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_3', nn.ReLU())
        self.conv3 = conv3

        conv4 = nn.Sequential()
        conv4.add_module('pool3', nn.MaxPool2d(2, stride=2))
        conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
        conv4.add_module('relu4_1', nn.ReLU())
        conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('relu4_2', nn.ReLU())
        conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('relu4_3', nn.ReLU())
        self.conv4 = conv4

        conv5 = nn.Sequential()
        conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
        conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_1', nn.ReLU())
        conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_2', nn.ReLU())
        conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('relu5_3', nn.ReLU())
        self.conv5 = conv5

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 1000),
        )

        # pre_train = torch.load(os.path.dirname(__file__) + '/vgg16-397923af.pth')
        pre_train = torch.load("/scratch/wej36how/codes/DCFM-master/vgg16-397923af.pth")
        self._initialize_weights(pre_train)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # both branches run through the single conv4/conv5 stack defined above
        x1 = self.conv4(x)
        x1 = self.conv5(x1)
        x1 = self.avgpool(x1)
        _x1 = x1.view(x1.size(0), -1)
        pred_vector = self.classifier(_x1)

        x2 = self.conv4(x)
        x2 = self.conv5(x2)
        return x1, pred_vector, x2

    def _initialize_weights(self, pre_train):
        keys = list(pre_train.keys())
        self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
        self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
        self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
        self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
        self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
        self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
        self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
        self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
        self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
        self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
        self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
        self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
        self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])

        self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
        self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
        self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
        self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
        self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
        self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
        self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
        self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
        self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
        self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
        self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
        self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
        self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])

        self.classifier[0].weight.data.copy_(pre_train[keys[26]])
        self.classifier[0].bias.data.copy_(pre_train[keys[27]])
        self.classifier[3].weight.data.copy_(pre_train[keys[28]])
        self.classifier[3].bias.data.copy_(pre_train[keys[29]])
        self.classifier[6].weight.data.copy_(pre_train[keys[30]])
        self.classifier[6].bias.data.copy_(pre_train[keys[31]])
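
A quick shape sanity-check for the backbone, assuming the `vgg16-397923af.pth` path hard-coded in `__init__` points at a local copy of the torchvision VGG16 weights:

```python
# Shape check for VGG_Backbone; requires the VGG16 checkpoint referenced in __init__.
import torch
from vgg import VGG_Backbone

net = VGG_Backbone().eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    x1, pred_vector, x2 = net(x)

print(x1.shape)           # torch.Size([2, 512, 7, 7])   -- pooled conv5 features
print(pred_vector.shape)  # torch.Size([2, 1000])        -- ImageNet classifier logits
print(x2.shape)           # torch.Size([2, 512, 14, 14]) -- conv5 feature map
```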