jingyi0305 commited on
Commit
206ebf9
·
verified ·
1 Parent(s): 0752931

Upload 8 files

Browse files
data/LOLdataset.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data as data
6
+ import numpy as np
7
+ from os import listdir
8
+ from os.path import join
9
+ from data.util import *
10
+ from torchvision import transforms as t
11
+
12
+
13
+ class LOLDatasetFromFolder(data.Dataset):
14
+ def __init__(self, data_dir, transform=None):
15
+ super(LOLDatasetFromFolder, self).__init__()
16
+ self.data_dir = data_dir
17
+ self.transform = transform
18
+ self.norm = t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+
20
+ def __getitem__(self, index):
21
+
22
+ folder = self.data_dir+'/low'
23
+ folder2= self.data_dir+'/high'
24
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
25
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
26
+ num = len(data_filenames)
27
+
28
+ im1 = load_img(data_filenames[index])
29
+ im2 = load_img(data_filenames2[index])
30
+ _, file1 = os.path.split(data_filenames[index])
31
+ _, file2 = os.path.split(data_filenames2[index])
32
+ seed = random.randint(1, 1000000)
33
+ seed = np.random.randint(seed) # make a seed with numpy generator
34
+ if self.transform:
35
+ random.seed(seed) # apply this seed to img tranfsorms
36
+ torch.manual_seed(seed) # needed for torchvision 0.7
37
+ im1 = self.transform(im1)
38
+ random.seed(seed)
39
+ torch.manual_seed(seed)
40
+ im2 = self.transform(im2)
41
+ return im1, im2, file1, file2
42
+
43
+ def __len__(self):
44
+ return 485
45
+
46
+
47
+ class LOLv2DatasetFromFolder(data.Dataset):
48
+ def __init__(self, data_dir, transform=None):
49
+ super(LOLv2DatasetFromFolder, self).__init__()
50
+ self.data_dir = data_dir
51
+ self.transform = transform
52
+
53
+ def __getitem__(self, index):
54
+
55
+ folder = self.data_dir+'/Low'
56
+ folder2= self.data_dir+'/Normal'
57
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
58
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
59
+
60
+ im1 = load_img(data_filenames[index])
61
+ im2 = load_img(data_filenames2[index])
62
+ _, file1 = os.path.split(data_filenames[index])
63
+ _, file2 = os.path.split(data_filenames2[index])
64
+ seed = random.randint(1, 1000000)
65
+ seed = np.random.randint(seed) # make a seed with numpy generator
66
+ if self.transform:
67
+ random.seed(seed) # apply this seed to img tranforms
68
+ torch.manual_seed(seed) # needed for torchvision 0.7
69
+ im1 = self.transform(im1)
70
+ random.seed(seed) # apply this seed to img tranforms
71
+ torch.manual_seed(seed) # needed for torchvision 0.7
72
+ im2 = self.transform(im2)
73
+ return im1, im2, file1, file2
74
+
75
+ def __len__(self):
76
+ return 685
77
+
78
+
79
+
80
+ class LOLv2SynDatasetFromFolder(data.Dataset):
81
+ def __init__(self, data_dir, transform=None):
82
+ super(LOLv2SynDatasetFromFolder, self).__init__()
83
+ self.data_dir = data_dir
84
+ self.transform = transform
85
+ self.norm = t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
86
+
87
+ def __getitem__(self, index):
88
+
89
+ folder = self.data_dir+'/Low'
90
+ folder2= self.data_dir+'/Normal'
91
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
92
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
93
+
94
+
95
+ im1 = load_img(data_filenames[index])
96
+ im2 = load_img(data_filenames2[index])
97
+ _, file1 = os.path.split(data_filenames[index])
98
+ _, file2 = os.path.split(data_filenames2[index])
99
+ seed = random.randint(1, 1000000)
100
+ seed = np.random.randint(seed) # make a seed with numpy generator
101
+ if self.transform:
102
+ random.seed(seed) # apply this seed to img tranfsorms
103
+ torch.manual_seed(seed) # needed for torchvision 0.7
104
+ im1 = self.transform(im1)
105
+ random.seed(seed)
106
+ torch.manual_seed(seed)
107
+ im2 = self.transform(im2)
108
+ return im1, im2, file1, file2
109
+
110
+ def __len__(self):
111
+ return 900
112
+
113
+
114
+
115
+
116
+
data/SICE_blur_SID.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data as data
6
+ import numpy as np
7
+ from os import listdir
8
+ from os.path import join
9
+ from PIL import Image
10
+ from data.util import *
11
+ from torchvision import transforms as t
12
+ import torch.nn.functional as F
13
+
14
+ class LOLBlurDatasetFromFolder(data.Dataset):
15
+ def __init__(self, data_dir, transform=None):
16
+ super(LOLBlurDatasetFromFolder, self).__init__()
17
+ self.data_dir = data_dir
18
+ self.transform = transform
19
+
20
+ def __getitem__(self, index):
21
+ while True:
22
+ seed = random.randint(1, 1000000)
23
+ random.seed(seed)
24
+ index = random.randint(0, 259)
25
+ fill_index = str(index+1).zfill(4)
26
+ folder = join(self.data_dir+'/low_blur', fill_index)
27
+ folder2 = join(self.data_dir+'/high_sharp_scaled', fill_index)
28
+ if not os.path.exists(folder):
29
+ continue
30
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
31
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
32
+ num = len(data_filenames)
33
+ if num != 0: break
34
+ index1 = random.randint(1,num)
35
+
36
+ im1 = load_img(data_filenames[index1-1])
37
+ im2 = load_img(data_filenames2[index1-1])
38
+ seed = random.randint(1, 1000000)
39
+ seed = np.random.randint(seed) # make a seed with numpy generator
40
+ if self.transform:
41
+ random.seed(seed) # apply this seed to img tranfsorms
42
+ torch.manual_seed(seed) # needed for torchvision 0.7
43
+ im1 = self.transform(im1)
44
+ random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ im2 = self.transform(im2)
47
+ return im1, im2, data_filenames[index1-1], data_filenames2[index1-1]
48
+
49
+ def __len__(self):
50
+ return 10200
51
+
52
+
53
+ class SIDDatasetFromFolder(data.Dataset):
54
+ def __init__(self, data_dir, transform=None):
55
+ super(SIDDatasetFromFolder, self).__init__()
56
+ self.data_dir = data_dir
57
+ self.transform = transform
58
+
59
+ def __getitem__(self, index):
60
+ while True:
61
+ seed = random.randint(1, 1000000)
62
+ random.seed(seed)
63
+ index = random.randint(0, 233)
64
+ fill_index = str(index+1).zfill(5)
65
+ folder = join(self.data_dir+'/short', fill_index)
66
+ folder2 = join(self.data_dir+'/long', fill_index)
67
+ if os.path.exists(folder):
68
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
69
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
70
+ num = len(data_filenames)
71
+ break
72
+ else:
73
+ continue
74
+ index1 = random.randint(1,num)
75
+
76
+
77
+ im1 = load_img(data_filenames[index1-1])
78
+ im2 = load_img(data_filenames2[0])
79
+ _, file1 = os.path.split(data_filenames[index1-1])
80
+ _, file2 = os.path.split(data_filenames2[0])
81
+ seed = np.random.randint(random.randint(1, 1000000)) # make a seed with numpy generator
82
+ if self.transform:
83
+ random.seed(seed) # apply this seed to img tranfsorms
84
+ torch.manual_seed(seed) # needed for torchvision 0.7
85
+ im1 = self.transform(im1)
86
+ random.seed(seed)
87
+ torch.manual_seed(seed)
88
+ im2 = self.transform(im2)
89
+ return im1, im2, file1, file2
90
+
91
+ def __len__(self):
92
+ return 2099
93
+
94
+
95
+
96
+ class SICEDatasetFromFolder(data.Dataset):
97
+ def __init__(self, data_dir, transform=None):
98
+ super(SICEDatasetFromFolder, self).__init__()
99
+ self.data_dir = data_dir
100
+ self.transform = transform
101
+
102
+ def __getitem__(self, index):
103
+ while True:
104
+ seed = random.randint(1, 1000000)
105
+ random.seed(seed)
106
+ index = random.randint(0, 590)
107
+ fill_index = str(index+1)
108
+ train, tail = os.path.split(self.data_dir)
109
+ folder = join(self.data_dir, fill_index)
110
+ data_gt = join(train+'/label', fill_index+'.JPG')
111
+ if os.path.exists(folder):
112
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
113
+ num = len(data_filenames)
114
+ break
115
+ else:
116
+ continue
117
+ index1 = random.randint(1,num)
118
+
119
+ im1 = load_img(data_filenames[index1-1])
120
+ im2 = load_img(data_gt)
121
+ _, file1 = os.path.split(data_filenames[index1-1])
122
+ _, file2 = os.path.split(data_gt)
123
+ seed = np.random.randint(random.randint(1, 1000000)) # make a seed with numpy generator
124
+ if self.transform:
125
+ random.seed(seed) # apply this seed to img tranfsorms
126
+ torch.manual_seed(seed) # needed for torchvision 0.7
127
+ im1 = self.transform(im1)
128
+ random.seed(seed)
129
+ torch.manual_seed(seed)
130
+ im2 = self.transform(im2)
131
+ return im1, im2, file1, file2
132
+
133
+ def __len__(self):
134
+ return 4803
data/data.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import Compose, ToTensor, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip
2
+ from data.LOLdataset import *
3
+ from data.eval_sets import *
4
+ from data.SICE_blur_SID import *
5
+ from data.fivek import *
6
+
7
+ def transform1(size=256):
8
+ return Compose([
9
+ RandomCrop((size, size)),
10
+ RandomHorizontalFlip(),
11
+ RandomVerticalFlip(),
12
+ ToTensor(),
13
+ ])
14
+
15
+ def transform2():
16
+ return Compose([ToTensor()])
17
+
18
+
19
+
20
+ def get_lol_training_set(data_dir,size):
21
+ return LOLDatasetFromFolder(data_dir, transform=transform1(size))
22
+
23
+
24
+ def get_lol_v2_training_set(data_dir,size):
25
+ return LOLv2DatasetFromFolder(data_dir, transform=transform1(size))
26
+
27
+
28
+ def get_training_set_blur(data_dir,size):
29
+ return LOLBlurDatasetFromFolder(data_dir, transform=transform1(size))
30
+
31
+
32
+ def get_lol_v2_syn_training_set(data_dir,size):
33
+ return LOLv2SynDatasetFromFolder(data_dir, transform=transform1(size))
34
+
35
+
36
+ def get_SID_training_set(data_dir,size):
37
+ return SIDDatasetFromFolder(data_dir, transform=transform1(size))
38
+
39
+
40
+ def get_SICE_training_set(data_dir,size):
41
+ return SICEDatasetFromFolder(data_dir, transform=transform1(size))
42
+
43
+ def get_SICE_eval_set(data_dir):
44
+ return SICEDatasetFromFolderEval(data_dir, transform=transform2())
45
+
46
+ def get_eval_set(data_dir):
47
+ return DatasetFromFolderEval(data_dir, transform=transform2())
48
+
49
+ def get_fivek_training_set(data_dir,size):
50
+ return FiveKDatasetFromFolder(data_dir, transform=transform1(size))
51
+
52
+ def get_fivek_eval_set(data_dir):
53
+ return SICEDatasetFromFolderEval(data_dir, transform=transform2())
data/eval_sets.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch.utils.data as data
4
+ from os import listdir
5
+ from os.path import join
6
+ from data.util import *
7
+ import torch.nn.functional as F
8
+
9
+ class SICEDatasetFromFolderEval(data.Dataset):
10
+ def __init__(self, data_dir, transform=None):
11
+ super(SICEDatasetFromFolderEval, self).__init__()
12
+ data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
13
+ data_filenames.sort()
14
+ self.data_filenames = data_filenames
15
+ self.transform = transform
16
+
17
+ def __getitem__(self, index):
18
+ input = load_img(self.data_filenames[index])
19
+ _, file = os.path.split(self.data_filenames[index])
20
+
21
+ if self.transform:
22
+ input = self.transform(input)
23
+ factor = 8
24
+ h, w = input.shape[1], input.shape[2]
25
+ H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
26
+ padh = H - h if h % factor != 0 else 0
27
+ padw = W - w if w % factor != 0 else 0
28
+ input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect').squeeze(0)
29
+ return input, file, h, w
30
+
31
+ def __len__(self):
32
+ return len(self.data_filenames)
33
+
34
+
35
+ class DatasetFromFolderEval(data.Dataset):
36
+ def __init__(self, data_dir, transform=None):
37
+ super(DatasetFromFolderEval, self).__init__()
38
+ data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
39
+ data_filenames.sort()
40
+ self.data_filenames = data_filenames
41
+ self.transform = transform
42
+
43
+ def __getitem__(self, index):
44
+ input = load_img(self.data_filenames[index])
45
+ _, file = os.path.split(self.data_filenames[index])
46
+
47
+ if self.transform:
48
+ input = self.transform(input)
49
+ return input, file
50
+
51
+ def __len__(self):
52
+ return len(self.data_filenames)
data/fivek.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add new fivek dataset follow Retinexformer(https://github.com/caiyuanhao1998/Retinexformer)
2
+
3
+ import os
4
+ import random
5
+ import torch
6
+ import torch.utils.data as data
7
+ import numpy as np
8
+ from os import listdir
9
+ from os.path import join
10
+ from data.util import *
11
+
12
+ class FiveKDatasetFromFolder(data.Dataset):
13
+ def __init__(self, data_dir, transform=None):
14
+ super(FiveKDatasetFromFolder, self).__init__()
15
+ self.data_dir = data_dir
16
+ self.transform = transform
17
+
18
+ def __getitem__(self, index):
19
+
20
+ folder = self.data_dir+'/input'
21
+ folder2= self.data_dir+'/target'
22
+ data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
23
+ data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
24
+
25
+
26
+ im1 = load_img(data_filenames[index])
27
+ im2 = load_img(data_filenames2[index])
28
+ _, file1 = os.path.split(data_filenames[index])
29
+ _, file2 = os.path.split(data_filenames2[index])
30
+ seed = random.randint(1, 1000000)
31
+ seed = np.random.randint(seed) # make a seed with numpy generator
32
+ if self.transform:
33
+ random.seed(seed) # apply this seed to img tranfsorms
34
+ torch.manual_seed(seed) # needed for torchvision 0.7
35
+ im1 = self.transform(im1)
36
+ random.seed(seed)
37
+ torch.manual_seed(seed)
38
+ im2 = self.transform(im2)
39
+ return im1, im2, file1, file2
40
+
41
+ def __len__(self):
42
+ return 4500
data/options.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ def _str2bool(v):
4
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
5
+ return True
6
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
7
+ return False
8
+ else:
9
+ raise argparse.ArgumentTypeError('Boolean value expected.')
10
+
11
+ def option():
12
+ # Training settings
13
+ parser = argparse.ArgumentParser(description='CIDNet')
14
+ parser.add_argument('--batchSize', type=int, default=8, help='training batch size')
15
+ parser.add_argument('--cropSize', type=int, default=256, help='image crop size (patch size)')
16
+ parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for end')
17
+ parser.add_argument('--start_epoch', type=int, default=0, help='number of epochs to start, >0 is retrained a pre-trained pth')
18
+ parser.add_argument('--snapshots', type=int, default=10, help='Snapshots for save checkpoints pth')
19
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate')
20
+ parser.add_argument('--gpu_mode', type=_str2bool, default=True)
21
+ parser.add_argument('--shuffle', type=_str2bool, default=True)
22
+ parser.add_argument('--threads', type=int, default=16, help='number of threads for dataloader to use')
23
+
24
+ # choose a scheduler
25
+ parser.add_argument('--cos_restart_cyclic', type=_str2bool, default=False)
26
+ parser.add_argument('--cos_restart', type=_str2bool, default=True)
27
+
28
+ # warmup training
29
+ parser.add_argument('--warmup_epochs', type=int, default=3, help='warmup_epochs')
30
+ parser.add_argument('--start_warmup', type=_str2bool, default=True, help='turn False to train without warmup')
31
+
32
+ # train datasets
33
+ parser.add_argument('--data_train_lol_blur' , type=str, default='./datasets/LOL_blur/train')
34
+ parser.add_argument('--data_train_lol_v1' , type=str, default='./datasets/LOLdataset/our485')
35
+ parser.add_argument('--data_train_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Train')
36
+ parser.add_argument('--data_train_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Train')
37
+ parser.add_argument('--data_train_SID' , type=str, default='./datasets/Sony_total_dark/train')
38
+ parser.add_argument('--data_train_SICE' , type=str, default='./datasets/SICE/Dataset/train')
39
+ parser.add_argument('--data_train_fivek' , type=str, default='./datasets/FiveK/train')
40
+
41
+ # validation input
42
+ parser.add_argument('--data_val_lol_blur' , type=str, default='./datasets/LOL_blur/eval/low_blur')
43
+ parser.add_argument('--data_val_lol_v1' , type=str, default='./datasets/LOLdataset/eval15/low')
44
+ parser.add_argument('--data_val_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Test/Low')
45
+ parser.add_argument('--data_val_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Test/Low')
46
+ parser.add_argument('--data_val_SID' , type=str, default='./datasets/Sony_total_dark/eval/short')
47
+ parser.add_argument('--data_val_SICE_mix' , type=str, default='./datasets/SICE/Dataset/eval/test')
48
+ parser.add_argument('--data_val_SICE_grad' , type=str, default='./datasets/SICE/Dataset/eval/test')
49
+ parser.add_argument('--data_test_fivek' , type=str, default='./datasets/FiveK/test/input')
50
+
51
+ # validation groundtruth
52
+ parser.add_argument('--data_valgt_lol_blur' , type=str, default='./datasets/LOL_blur/eval/high_sharp_scaled/')
53
+ parser.add_argument('--data_valgt_lol_v1' , type=str, default='./datasets/LOLdataset/eval15/high/')
54
+ parser.add_argument('--data_valgt_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Test/Normal/')
55
+ parser.add_argument('--data_valgt_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Test/Normal/')
56
+ parser.add_argument('--data_valgt_SID' , type=str, default='./datasets/Sony_total_dark/eval/long/')
57
+ parser.add_argument('--data_valgt_SICE_mix' , type=str, default='./datasets/SICE/Dataset/eval/target/')
58
+ parser.add_argument('--data_valgt_SICE_grad' , type=str, default='./datasets/SICE/Dataset/eval/target/')
59
+ parser.add_argument('--data_valgt_fivek' , type=str, default='./datasets/FiveK/test/target/')
60
+
61
+ parser.add_argument('--val_folder', default='./results/', help='Location to save validation datasets')
62
+
63
+ # loss weights
64
+ parser.add_argument('--HVI_weight', type=float, default=1.0)
65
+ parser.add_argument('--L1_weight', type=float, default=1.0)
66
+ parser.add_argument('--D_weight', type=float, default=0.5)
67
+ parser.add_argument('--E_weight', type=float, default=50.0)
68
+ parser.add_argument('--P_weight', type=float, default=1e-2)
69
+
70
+ # use random gamma function (enhancement curve) to improve generalization
71
+ parser.add_argument('--gamma', type=_str2bool, default=False)
72
+ parser.add_argument('--start_gamma', type=int, default=60)
73
+ parser.add_argument('--end_gamma', type=int, default=120)
74
+
75
+ # auto grad, turn off to speed up training
76
+ parser.add_argument('--grad_detect', type=_str2bool, default=False, help='if gradient explosion occurs, turn-on it')
77
+ parser.add_argument('--grad_clip', type=_str2bool, default=True, help='if gradient fluctuates too much, turn-on it')
78
+
79
+
80
+ # choose which dataset you want to train
81
+ parser.add_argument('--dataset', type=str, default='lol_v1',
82
+ choices=['lol_v1',
83
+ 'lolv2_real',
84
+ 'lolv2_syn',
85
+ 'lol_blur',
86
+ 'SID',
87
+ 'SICE_mix',
88
+ 'SICE_grad',
89
+ 'fivek'],
90
+ help='Select the dataset to train on (default: %(default)s)')
91
+
92
+ return parser
data/scheduler.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim.lr_scheduler import _LRScheduler
2
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
3
+ import math
4
+
5
+ class GradualWarmupScheduler(_LRScheduler):
6
+ """ Gradually warm-up(increasing) learning rate in optimizer.
7
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8
+
9
+ Args:
10
+ optimizer (Optimizer): Wrapped optimizer.
11
+ multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
12
+ total_epoch: target learning rate is reached at total_epoch, gradually
13
+ after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
14
+ """
15
+
16
+ def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17
+ self.multiplier = multiplier
18
+ if self.multiplier < 1.:
19
+ raise ValueError('multiplier should be greater thant or equal to 1.')
20
+ self.total_epoch = total_epoch
21
+ self.after_scheduler = after_scheduler
22
+ self.finished = False
23
+ super(GradualWarmupScheduler, self).__init__(optimizer)
24
+
25
+ def get_lr(self):
26
+ if self.last_epoch > self.total_epoch:
27
+ if self.after_scheduler:
28
+ if not self.finished:
29
+ self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30
+ self.finished = True
31
+ return self.after_scheduler.get_lr()
32
+ return [base_lr * self.multiplier for base_lr in self.base_lrs]
33
+
34
+ if self.multiplier == 1.0:
35
+ return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36
+ else:
37
+ return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38
+
39
+ def step_ReduceLROnPlateau(self, metrics, epoch=None):
40
+ if epoch is None:
41
+ epoch = self.last_epoch + 1
42
+ self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43
+ if self.last_epoch <= self.total_epoch:
44
+ warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45
+ for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46
+ param_group['lr'] = lr
47
+ else:
48
+ if epoch is None:
49
+ self.after_scheduler.step(metrics, None)
50
+ else:
51
+ self.after_scheduler.step(metrics, epoch - self.total_epoch)
52
+
53
+ def step(self, epoch=None, metrics=None):
54
+ if type(self.after_scheduler) != ReduceLROnPlateau:
55
+ if self.finished and self.after_scheduler:
56
+ if epoch is None:
57
+ self.after_scheduler.step(None)
58
+ else:
59
+ self.after_scheduler.step(epoch - self.total_epoch)
60
+ else:
61
+ return super(GradualWarmupScheduler, self).step(epoch)
62
+ else:
63
+ self.step_ReduceLROnPlateau(metrics, epoch)
64
+
65
+ def get_position_from_periods(iteration, cumulative_period):
66
+ """Get the position from a period list.
67
+
68
+ It will return the index of the right-closest number in the period list.
69
+ For example, the cumulative_period = [100, 200, 300, 400],
70
+ if iteration == 50, return 0;
71
+ if iteration == 210, return 2;
72
+ if iteration == 300, return 2.
73
+
74
+ Args:
75
+ iteration (int): Current iteration.
76
+ cumulative_period (list[int]): Cumulative period list.
77
+
78
+ Returns:
79
+ int: The position of the right-closest number in the period list.
80
+ """
81
+ for i, period in enumerate(cumulative_period):
82
+ if iteration <= period:
83
+ return i
84
+
85
+ class CosineAnnealingRestartCyclicLR(_LRScheduler):
86
+ """ Cosine annealing with restarts learning rate scheme.
87
+ An example of config:
88
+ periods = [10, 10, 10, 10]
89
+ restart_weights = [1, 0.5, 0.5, 0.5]
90
+ eta_min=1e-7
91
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
92
+ scheduler will restart with the weights in restart_weights.
93
+ Args:
94
+ optimizer (torch.nn.optimizer): Torch optimizer.
95
+ periods (list): Period for each cosine anneling cycle.
96
+ restart_weights (list): Restart weights at each restart iteration.
97
+ Default: [1].
98
+ eta_min (float): The mimimum lr. Default: 0.
99
+ last_epoch (int): Used in _LRScheduler. Default: -1.
100
+ """
101
+
102
+ def __init__(self,
103
+ optimizer,
104
+ periods,
105
+ restart_weights=(1, ),
106
+ eta_mins=(0, ),
107
+ last_epoch=-1):
108
+ self.periods = periods
109
+ self.restart_weights = restart_weights
110
+ self.eta_mins = eta_mins
111
+ assert (len(self.periods) == len(self.restart_weights)
112
+ ), 'periods and restart_weights should have the same length.'
113
+ self.cumulative_period = [
114
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
115
+ ]
116
+ super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
117
+
118
+ def get_lr(self):
119
+ idx = get_position_from_periods(self.last_epoch,
120
+ self.cumulative_period)
121
+ current_weight = self.restart_weights[idx]
122
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
123
+ current_period = self.periods[idx]
124
+ eta_min = self.eta_mins[idx]
125
+
126
+ return [
127
+ eta_min + current_weight * 0.5 * (base_lr - eta_min) *
128
+ (1 + math.cos(math.pi * (
129
+ (self.last_epoch - nearest_restart) / current_period)))
130
+ for base_lr in self.base_lrs
131
+ ]
132
+
133
+ class CosineAnnealingRestartLR(_LRScheduler):
134
+ """ Cosine annealing with restarts learning rate scheme.
135
+
136
+ An example of config:
137
+ periods = [10, 10, 10, 10]
138
+ restart_weights = [1, 0.5, 0.5, 0.5]
139
+ eta_min=1e-7
140
+
141
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
142
+ scheduler will restart with the weights in restart_weights.
143
+
144
+ Args:
145
+ optimizer (torch.nn.optimizer): Torch optimizer.
146
+ periods (list): Period for each cosine anneling cycle.
147
+ restart_weights (list): Restart weights at each restart iteration.
148
+ Default: [1].
149
+ eta_min (float): The mimimum lr. Default: 0.
150
+ last_epoch (int): Used in _LRScheduler. Default: -1.
151
+ """
152
+
153
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
154
+ self.periods = periods
155
+ self.restart_weights = restart_weights
156
+ self.eta_min = eta_min
157
+ assert (len(self.periods) == len(
158
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
159
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
160
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
161
+
162
+ def get_lr(self):
163
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
164
+ current_weight = self.restart_weights[idx]
165
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
166
+ current_period = self.periods[idx]
167
+
168
+ return [
169
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
170
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
171
+ for base_lr in self.base_lrs
172
+ ]
data/util.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from PIL import Image
3
+
4
+ def is_image_file(filename):
5
+ return any(filename.endswith(extension) for extension in [".png", ".jpg", ".bmp", ".JPG", ".jpeg"])
6
+
7
+ def load_img(filepath):
8
+ img = Image.open(filepath).convert('RGB')
9
+ return img