Spaces:
Sleeping
Sleeping
Upload 8 files
Browse files- data/LOLdataset.py +116 -0
- data/SICE_blur_SID.py +134 -0
- data/data.py +53 -0
- data/eval_sets.py +52 -0
- data/fivek.py +42 -0
- data/options.py +92 -0
- data/scheduler.py +172 -0
- data/util.py +9 -0
data/LOLdataset.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data as data
|
| 6 |
+
import numpy as np
|
| 7 |
+
from os import listdir
|
| 8 |
+
from os.path import join
|
| 9 |
+
from data.util import *
|
| 10 |
+
from torchvision import transforms as t
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LOLDatasetFromFolder(data.Dataset):
|
| 14 |
+
def __init__(self, data_dir, transform=None):
|
| 15 |
+
super(LOLDatasetFromFolder, self).__init__()
|
| 16 |
+
self.data_dir = data_dir
|
| 17 |
+
self.transform = transform
|
| 18 |
+
self.norm = t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 19 |
+
|
| 20 |
+
def __getitem__(self, index):
|
| 21 |
+
|
| 22 |
+
folder = self.data_dir+'/low'
|
| 23 |
+
folder2= self.data_dir+'/high'
|
| 24 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 25 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 26 |
+
num = len(data_filenames)
|
| 27 |
+
|
| 28 |
+
im1 = load_img(data_filenames[index])
|
| 29 |
+
im2 = load_img(data_filenames2[index])
|
| 30 |
+
_, file1 = os.path.split(data_filenames[index])
|
| 31 |
+
_, file2 = os.path.split(data_filenames2[index])
|
| 32 |
+
seed = random.randint(1, 1000000)
|
| 33 |
+
seed = np.random.randint(seed) # make a seed with numpy generator
|
| 34 |
+
if self.transform:
|
| 35 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 36 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 37 |
+
im1 = self.transform(im1)
|
| 38 |
+
random.seed(seed)
|
| 39 |
+
torch.manual_seed(seed)
|
| 40 |
+
im2 = self.transform(im2)
|
| 41 |
+
return im1, im2, file1, file2
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
return 485
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LOLv2DatasetFromFolder(data.Dataset):
|
| 48 |
+
def __init__(self, data_dir, transform=None):
|
| 49 |
+
super(LOLv2DatasetFromFolder, self).__init__()
|
| 50 |
+
self.data_dir = data_dir
|
| 51 |
+
self.transform = transform
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, index):
|
| 54 |
+
|
| 55 |
+
folder = self.data_dir+'/Low'
|
| 56 |
+
folder2= self.data_dir+'/Normal'
|
| 57 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 58 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 59 |
+
|
| 60 |
+
im1 = load_img(data_filenames[index])
|
| 61 |
+
im2 = load_img(data_filenames2[index])
|
| 62 |
+
_, file1 = os.path.split(data_filenames[index])
|
| 63 |
+
_, file2 = os.path.split(data_filenames2[index])
|
| 64 |
+
seed = random.randint(1, 1000000)
|
| 65 |
+
seed = np.random.randint(seed) # make a seed with numpy generator
|
| 66 |
+
if self.transform:
|
| 67 |
+
random.seed(seed) # apply this seed to img tranforms
|
| 68 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 69 |
+
im1 = self.transform(im1)
|
| 70 |
+
random.seed(seed) # apply this seed to img tranforms
|
| 71 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 72 |
+
im2 = self.transform(im2)
|
| 73 |
+
return im1, im2, file1, file2
|
| 74 |
+
|
| 75 |
+
def __len__(self):
|
| 76 |
+
return 685
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class LOLv2SynDatasetFromFolder(data.Dataset):
|
| 81 |
+
def __init__(self, data_dir, transform=None):
|
| 82 |
+
super(LOLv2SynDatasetFromFolder, self).__init__()
|
| 83 |
+
self.data_dir = data_dir
|
| 84 |
+
self.transform = transform
|
| 85 |
+
self.norm = t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 86 |
+
|
| 87 |
+
def __getitem__(self, index):
|
| 88 |
+
|
| 89 |
+
folder = self.data_dir+'/Low'
|
| 90 |
+
folder2= self.data_dir+'/Normal'
|
| 91 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 92 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
im1 = load_img(data_filenames[index])
|
| 96 |
+
im2 = load_img(data_filenames2[index])
|
| 97 |
+
_, file1 = os.path.split(data_filenames[index])
|
| 98 |
+
_, file2 = os.path.split(data_filenames2[index])
|
| 99 |
+
seed = random.randint(1, 1000000)
|
| 100 |
+
seed = np.random.randint(seed) # make a seed with numpy generator
|
| 101 |
+
if self.transform:
|
| 102 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 103 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 104 |
+
im1 = self.transform(im1)
|
| 105 |
+
random.seed(seed)
|
| 106 |
+
torch.manual_seed(seed)
|
| 107 |
+
im2 = self.transform(im2)
|
| 108 |
+
return im1, im2, file1, file2
|
| 109 |
+
|
| 110 |
+
def __len__(self):
|
| 111 |
+
return 900
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
data/SICE_blur_SID.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data as data
|
| 6 |
+
import numpy as np
|
| 7 |
+
from os import listdir
|
| 8 |
+
from os.path import join
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from data.util import *
|
| 11 |
+
from torchvision import transforms as t
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
class LOLBlurDatasetFromFolder(data.Dataset):
|
| 15 |
+
def __init__(self, data_dir, transform=None):
|
| 16 |
+
super(LOLBlurDatasetFromFolder, self).__init__()
|
| 17 |
+
self.data_dir = data_dir
|
| 18 |
+
self.transform = transform
|
| 19 |
+
|
| 20 |
+
def __getitem__(self, index):
|
| 21 |
+
while True:
|
| 22 |
+
seed = random.randint(1, 1000000)
|
| 23 |
+
random.seed(seed)
|
| 24 |
+
index = random.randint(0, 259)
|
| 25 |
+
fill_index = str(index+1).zfill(4)
|
| 26 |
+
folder = join(self.data_dir+'/low_blur', fill_index)
|
| 27 |
+
folder2 = join(self.data_dir+'/high_sharp_scaled', fill_index)
|
| 28 |
+
if not os.path.exists(folder):
|
| 29 |
+
continue
|
| 30 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 31 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 32 |
+
num = len(data_filenames)
|
| 33 |
+
if num != 0: break
|
| 34 |
+
index1 = random.randint(1,num)
|
| 35 |
+
|
| 36 |
+
im1 = load_img(data_filenames[index1-1])
|
| 37 |
+
im2 = load_img(data_filenames2[index1-1])
|
| 38 |
+
seed = random.randint(1, 1000000)
|
| 39 |
+
seed = np.random.randint(seed) # make a seed with numpy generator
|
| 40 |
+
if self.transform:
|
| 41 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 42 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 43 |
+
im1 = self.transform(im1)
|
| 44 |
+
random.seed(seed)
|
| 45 |
+
torch.manual_seed(seed)
|
| 46 |
+
im2 = self.transform(im2)
|
| 47 |
+
return im1, im2, data_filenames[index1-1], data_filenames2[index1-1]
|
| 48 |
+
|
| 49 |
+
def __len__(self):
|
| 50 |
+
return 10200
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SIDDatasetFromFolder(data.Dataset):
|
| 54 |
+
def __init__(self, data_dir, transform=None):
|
| 55 |
+
super(SIDDatasetFromFolder, self).__init__()
|
| 56 |
+
self.data_dir = data_dir
|
| 57 |
+
self.transform = transform
|
| 58 |
+
|
| 59 |
+
def __getitem__(self, index):
|
| 60 |
+
while True:
|
| 61 |
+
seed = random.randint(1, 1000000)
|
| 62 |
+
random.seed(seed)
|
| 63 |
+
index = random.randint(0, 233)
|
| 64 |
+
fill_index = str(index+1).zfill(5)
|
| 65 |
+
folder = join(self.data_dir+'/short', fill_index)
|
| 66 |
+
folder2 = join(self.data_dir+'/long', fill_index)
|
| 67 |
+
if os.path.exists(folder):
|
| 68 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 69 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 70 |
+
num = len(data_filenames)
|
| 71 |
+
break
|
| 72 |
+
else:
|
| 73 |
+
continue
|
| 74 |
+
index1 = random.randint(1,num)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
im1 = load_img(data_filenames[index1-1])
|
| 78 |
+
im2 = load_img(data_filenames2[0])
|
| 79 |
+
_, file1 = os.path.split(data_filenames[index1-1])
|
| 80 |
+
_, file2 = os.path.split(data_filenames2[0])
|
| 81 |
+
seed = np.random.randint(random.randint(1, 1000000)) # make a seed with numpy generator
|
| 82 |
+
if self.transform:
|
| 83 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 84 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 85 |
+
im1 = self.transform(im1)
|
| 86 |
+
random.seed(seed)
|
| 87 |
+
torch.manual_seed(seed)
|
| 88 |
+
im2 = self.transform(im2)
|
| 89 |
+
return im1, im2, file1, file2
|
| 90 |
+
|
| 91 |
+
def __len__(self):
|
| 92 |
+
return 2099
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SICEDatasetFromFolder(data.Dataset):
|
| 97 |
+
def __init__(self, data_dir, transform=None):
|
| 98 |
+
super(SICEDatasetFromFolder, self).__init__()
|
| 99 |
+
self.data_dir = data_dir
|
| 100 |
+
self.transform = transform
|
| 101 |
+
|
| 102 |
+
def __getitem__(self, index):
|
| 103 |
+
while True:
|
| 104 |
+
seed = random.randint(1, 1000000)
|
| 105 |
+
random.seed(seed)
|
| 106 |
+
index = random.randint(0, 590)
|
| 107 |
+
fill_index = str(index+1)
|
| 108 |
+
train, tail = os.path.split(self.data_dir)
|
| 109 |
+
folder = join(self.data_dir, fill_index)
|
| 110 |
+
data_gt = join(train+'/label', fill_index+'.JPG')
|
| 111 |
+
if os.path.exists(folder):
|
| 112 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 113 |
+
num = len(data_filenames)
|
| 114 |
+
break
|
| 115 |
+
else:
|
| 116 |
+
continue
|
| 117 |
+
index1 = random.randint(1,num)
|
| 118 |
+
|
| 119 |
+
im1 = load_img(data_filenames[index1-1])
|
| 120 |
+
im2 = load_img(data_gt)
|
| 121 |
+
_, file1 = os.path.split(data_filenames[index1-1])
|
| 122 |
+
_, file2 = os.path.split(data_gt)
|
| 123 |
+
seed = np.random.randint(random.randint(1, 1000000)) # make a seed with numpy generator
|
| 124 |
+
if self.transform:
|
| 125 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 126 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 127 |
+
im1 = self.transform(im1)
|
| 128 |
+
random.seed(seed)
|
| 129 |
+
torch.manual_seed(seed)
|
| 130 |
+
im2 = self.transform(im2)
|
| 131 |
+
return im1, im2, file1, file2
|
| 132 |
+
|
| 133 |
+
def __len__(self):
|
| 134 |
+
return 4803
|
data/data.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.transforms import Compose, ToTensor, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip
|
| 2 |
+
from data.LOLdataset import *
|
| 3 |
+
from data.eval_sets import *
|
| 4 |
+
from data.SICE_blur_SID import *
|
| 5 |
+
from data.fivek import *
|
| 6 |
+
|
| 7 |
+
def transform1(size=256):
|
| 8 |
+
return Compose([
|
| 9 |
+
RandomCrop((size, size)),
|
| 10 |
+
RandomHorizontalFlip(),
|
| 11 |
+
RandomVerticalFlip(),
|
| 12 |
+
ToTensor(),
|
| 13 |
+
])
|
| 14 |
+
|
| 15 |
+
def transform2():
|
| 16 |
+
return Compose([ToTensor()])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_lol_training_set(data_dir,size):
|
| 21 |
+
return LOLDatasetFromFolder(data_dir, transform=transform1(size))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_lol_v2_training_set(data_dir,size):
|
| 25 |
+
return LOLv2DatasetFromFolder(data_dir, transform=transform1(size))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_training_set_blur(data_dir,size):
|
| 29 |
+
return LOLBlurDatasetFromFolder(data_dir, transform=transform1(size))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_lol_v2_syn_training_set(data_dir,size):
|
| 33 |
+
return LOLv2SynDatasetFromFolder(data_dir, transform=transform1(size))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_SID_training_set(data_dir,size):
|
| 37 |
+
return SIDDatasetFromFolder(data_dir, transform=transform1(size))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_SICE_training_set(data_dir,size):
|
| 41 |
+
return SICEDatasetFromFolder(data_dir, transform=transform1(size))
|
| 42 |
+
|
| 43 |
+
def get_SICE_eval_set(data_dir):
|
| 44 |
+
return SICEDatasetFromFolderEval(data_dir, transform=transform2())
|
| 45 |
+
|
| 46 |
+
def get_eval_set(data_dir):
|
| 47 |
+
return DatasetFromFolderEval(data_dir, transform=transform2())
|
| 48 |
+
|
| 49 |
+
def get_fivek_training_set(data_dir,size):
|
| 50 |
+
return FiveKDatasetFromFolder(data_dir, transform=transform1(size))
|
| 51 |
+
|
| 52 |
+
def get_fivek_eval_set(data_dir):
|
| 53 |
+
return SICEDatasetFromFolderEval(data_dir, transform=transform2())
|
data/eval_sets.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch.utils.data as data
|
| 4 |
+
from os import listdir
|
| 5 |
+
from os.path import join
|
| 6 |
+
from data.util import *
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
class SICEDatasetFromFolderEval(data.Dataset):
|
| 10 |
+
def __init__(self, data_dir, transform=None):
|
| 11 |
+
super(SICEDatasetFromFolderEval, self).__init__()
|
| 12 |
+
data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
|
| 13 |
+
data_filenames.sort()
|
| 14 |
+
self.data_filenames = data_filenames
|
| 15 |
+
self.transform = transform
|
| 16 |
+
|
| 17 |
+
def __getitem__(self, index):
|
| 18 |
+
input = load_img(self.data_filenames[index])
|
| 19 |
+
_, file = os.path.split(self.data_filenames[index])
|
| 20 |
+
|
| 21 |
+
if self.transform:
|
| 22 |
+
input = self.transform(input)
|
| 23 |
+
factor = 8
|
| 24 |
+
h, w = input.shape[1], input.shape[2]
|
| 25 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 26 |
+
padh = H - h if h % factor != 0 else 0
|
| 27 |
+
padw = W - w if w % factor != 0 else 0
|
| 28 |
+
input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect').squeeze(0)
|
| 29 |
+
return input, file, h, w
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
return len(self.data_filenames)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DatasetFromFolderEval(data.Dataset):
|
| 36 |
+
def __init__(self, data_dir, transform=None):
|
| 37 |
+
super(DatasetFromFolderEval, self).__init__()
|
| 38 |
+
data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
|
| 39 |
+
data_filenames.sort()
|
| 40 |
+
self.data_filenames = data_filenames
|
| 41 |
+
self.transform = transform
|
| 42 |
+
|
| 43 |
+
def __getitem__(self, index):
|
| 44 |
+
input = load_img(self.data_filenames[index])
|
| 45 |
+
_, file = os.path.split(self.data_filenames[index])
|
| 46 |
+
|
| 47 |
+
if self.transform:
|
| 48 |
+
input = self.transform(input)
|
| 49 |
+
return input, file
|
| 50 |
+
|
| 51 |
+
def __len__(self):
|
| 52 |
+
return len(self.data_filenames)
|
data/fivek.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Add new fivek dataset follow Retinexformer(https://github.com/caiyuanhao1998/Retinexformer)
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
import numpy as np
|
| 8 |
+
from os import listdir
|
| 9 |
+
from os.path import join
|
| 10 |
+
from data.util import *
|
| 11 |
+
|
| 12 |
+
class FiveKDatasetFromFolder(data.Dataset):
|
| 13 |
+
def __init__(self, data_dir, transform=None):
|
| 14 |
+
super(FiveKDatasetFromFolder, self).__init__()
|
| 15 |
+
self.data_dir = data_dir
|
| 16 |
+
self.transform = transform
|
| 17 |
+
|
| 18 |
+
def __getitem__(self, index):
|
| 19 |
+
|
| 20 |
+
folder = self.data_dir+'/input'
|
| 21 |
+
folder2= self.data_dir+'/target'
|
| 22 |
+
data_filenames = [join(folder, x) for x in listdir(folder) if is_image_file(x)]
|
| 23 |
+
data_filenames2 = [join(folder2, x) for x in listdir(folder2) if is_image_file(x)]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
im1 = load_img(data_filenames[index])
|
| 27 |
+
im2 = load_img(data_filenames2[index])
|
| 28 |
+
_, file1 = os.path.split(data_filenames[index])
|
| 29 |
+
_, file2 = os.path.split(data_filenames2[index])
|
| 30 |
+
seed = random.randint(1, 1000000)
|
| 31 |
+
seed = np.random.randint(seed) # make a seed with numpy generator
|
| 32 |
+
if self.transform:
|
| 33 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
| 34 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
| 35 |
+
im1 = self.transform(im1)
|
| 36 |
+
random.seed(seed)
|
| 37 |
+
torch.manual_seed(seed)
|
| 38 |
+
im2 = self.transform(im2)
|
| 39 |
+
return im1, im2, file1, file2
|
| 40 |
+
|
| 41 |
+
def __len__(self):
|
| 42 |
+
return 4500
|
data/options.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
def _str2bool(v):
|
| 4 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 5 |
+
return True
|
| 6 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 7 |
+
return False
|
| 8 |
+
else:
|
| 9 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 10 |
+
|
| 11 |
+
def option():
|
| 12 |
+
# Training settings
|
| 13 |
+
parser = argparse.ArgumentParser(description='CIDNet')
|
| 14 |
+
parser.add_argument('--batchSize', type=int, default=8, help='training batch size')
|
| 15 |
+
parser.add_argument('--cropSize', type=int, default=256, help='image crop size (patch size)')
|
| 16 |
+
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for end')
|
| 17 |
+
parser.add_argument('--start_epoch', type=int, default=0, help='number of epochs to start, >0 is retrained a pre-trained pth')
|
| 18 |
+
parser.add_argument('--snapshots', type=int, default=10, help='Snapshots for save checkpoints pth')
|
| 19 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate')
|
| 20 |
+
parser.add_argument('--gpu_mode', type=_str2bool, default=True)
|
| 21 |
+
parser.add_argument('--shuffle', type=_str2bool, default=True)
|
| 22 |
+
parser.add_argument('--threads', type=int, default=16, help='number of threads for dataloader to use')
|
| 23 |
+
|
| 24 |
+
# choose a scheduler
|
| 25 |
+
parser.add_argument('--cos_restart_cyclic', type=_str2bool, default=False)
|
| 26 |
+
parser.add_argument('--cos_restart', type=_str2bool, default=True)
|
| 27 |
+
|
| 28 |
+
# warmup training
|
| 29 |
+
parser.add_argument('--warmup_epochs', type=int, default=3, help='warmup_epochs')
|
| 30 |
+
parser.add_argument('--start_warmup', type=_str2bool, default=True, help='turn False to train without warmup')
|
| 31 |
+
|
| 32 |
+
# train datasets
|
| 33 |
+
parser.add_argument('--data_train_lol_blur' , type=str, default='./datasets/LOL_blur/train')
|
| 34 |
+
parser.add_argument('--data_train_lol_v1' , type=str, default='./datasets/LOLdataset/our485')
|
| 35 |
+
parser.add_argument('--data_train_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Train')
|
| 36 |
+
parser.add_argument('--data_train_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Train')
|
| 37 |
+
parser.add_argument('--data_train_SID' , type=str, default='./datasets/Sony_total_dark/train')
|
| 38 |
+
parser.add_argument('--data_train_SICE' , type=str, default='./datasets/SICE/Dataset/train')
|
| 39 |
+
parser.add_argument('--data_train_fivek' , type=str, default='./datasets/FiveK/train')
|
| 40 |
+
|
| 41 |
+
# validation input
|
| 42 |
+
parser.add_argument('--data_val_lol_blur' , type=str, default='./datasets/LOL_blur/eval/low_blur')
|
| 43 |
+
parser.add_argument('--data_val_lol_v1' , type=str, default='./datasets/LOLdataset/eval15/low')
|
| 44 |
+
parser.add_argument('--data_val_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Test/Low')
|
| 45 |
+
parser.add_argument('--data_val_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Test/Low')
|
| 46 |
+
parser.add_argument('--data_val_SID' , type=str, default='./datasets/Sony_total_dark/eval/short')
|
| 47 |
+
parser.add_argument('--data_val_SICE_mix' , type=str, default='./datasets/SICE/Dataset/eval/test')
|
| 48 |
+
parser.add_argument('--data_val_SICE_grad' , type=str, default='./datasets/SICE/Dataset/eval/test')
|
| 49 |
+
parser.add_argument('--data_test_fivek' , type=str, default='./datasets/FiveK/test/input')
|
| 50 |
+
|
| 51 |
+
# validation groundtruth
|
| 52 |
+
parser.add_argument('--data_valgt_lol_blur' , type=str, default='./datasets/LOL_blur/eval/high_sharp_scaled/')
|
| 53 |
+
parser.add_argument('--data_valgt_lol_v1' , type=str, default='./datasets/LOLdataset/eval15/high/')
|
| 54 |
+
parser.add_argument('--data_valgt_lolv2_real' , type=str, default='./datasets/LOLv2/Real_captured/Test/Normal/')
|
| 55 |
+
parser.add_argument('--data_valgt_lolv2_syn' , type=str, default='./datasets/LOLv2/Synthetic/Test/Normal/')
|
| 56 |
+
parser.add_argument('--data_valgt_SID' , type=str, default='./datasets/Sony_total_dark/eval/long/')
|
| 57 |
+
parser.add_argument('--data_valgt_SICE_mix' , type=str, default='./datasets/SICE/Dataset/eval/target/')
|
| 58 |
+
parser.add_argument('--data_valgt_SICE_grad' , type=str, default='./datasets/SICE/Dataset/eval/target/')
|
| 59 |
+
parser.add_argument('--data_valgt_fivek' , type=str, default='./datasets/FiveK/test/target/')
|
| 60 |
+
|
| 61 |
+
parser.add_argument('--val_folder', default='./results/', help='Location to save validation datasets')
|
| 62 |
+
|
| 63 |
+
# loss weights
|
| 64 |
+
parser.add_argument('--HVI_weight', type=float, default=1.0)
|
| 65 |
+
parser.add_argument('--L1_weight', type=float, default=1.0)
|
| 66 |
+
parser.add_argument('--D_weight', type=float, default=0.5)
|
| 67 |
+
parser.add_argument('--E_weight', type=float, default=50.0)
|
| 68 |
+
parser.add_argument('--P_weight', type=float, default=1e-2)
|
| 69 |
+
|
| 70 |
+
# use random gamma function (enhancement curve) to improve generalization
|
| 71 |
+
parser.add_argument('--gamma', type=_str2bool, default=False)
|
| 72 |
+
parser.add_argument('--start_gamma', type=int, default=60)
|
| 73 |
+
parser.add_argument('--end_gamma', type=int, default=120)
|
| 74 |
+
|
| 75 |
+
# auto grad, turn off to speed up training
|
| 76 |
+
parser.add_argument('--grad_detect', type=_str2bool, default=False, help='if gradient explosion occurs, turn-on it')
|
| 77 |
+
parser.add_argument('--grad_clip', type=_str2bool, default=True, help='if gradient fluctuates too much, turn-on it')
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# choose which dataset you want to train
|
| 81 |
+
parser.add_argument('--dataset', type=str, default='lol_v1',
|
| 82 |
+
choices=['lol_v1',
|
| 83 |
+
'lolv2_real',
|
| 84 |
+
'lolv2_syn',
|
| 85 |
+
'lol_blur',
|
| 86 |
+
'SID',
|
| 87 |
+
'SICE_mix',
|
| 88 |
+
'SICE_grad',
|
| 89 |
+
'fivek'],
|
| 90 |
+
help='Select the dataset to train on (default: %(default)s)')
|
| 91 |
+
|
| 92 |
+
return parser
|
data/scheduler.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 2 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class GradualWarmupScheduler(_LRScheduler):
|
| 6 |
+
""" Gradually warm-up(increasing) learning rate in optimizer.
|
| 7 |
+
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
optimizer (Optimizer): Wrapped optimizer.
|
| 11 |
+
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
|
| 12 |
+
total_epoch: target learning rate is reached at total_epoch, gradually
|
| 13 |
+
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
|
| 17 |
+
self.multiplier = multiplier
|
| 18 |
+
if self.multiplier < 1.:
|
| 19 |
+
raise ValueError('multiplier should be greater thant or equal to 1.')
|
| 20 |
+
self.total_epoch = total_epoch
|
| 21 |
+
self.after_scheduler = after_scheduler
|
| 22 |
+
self.finished = False
|
| 23 |
+
super(GradualWarmupScheduler, self).__init__(optimizer)
|
| 24 |
+
|
| 25 |
+
def get_lr(self):
|
| 26 |
+
if self.last_epoch > self.total_epoch:
|
| 27 |
+
if self.after_scheduler:
|
| 28 |
+
if not self.finished:
|
| 29 |
+
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
|
| 30 |
+
self.finished = True
|
| 31 |
+
return self.after_scheduler.get_lr()
|
| 32 |
+
return [base_lr * self.multiplier for base_lr in self.base_lrs]
|
| 33 |
+
|
| 34 |
+
if self.multiplier == 1.0:
|
| 35 |
+
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
|
| 36 |
+
else:
|
| 37 |
+
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
|
| 38 |
+
|
| 39 |
+
def step_ReduceLROnPlateau(self, metrics, epoch=None):
|
| 40 |
+
if epoch is None:
|
| 41 |
+
epoch = self.last_epoch + 1
|
| 42 |
+
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
|
| 43 |
+
if self.last_epoch <= self.total_epoch:
|
| 44 |
+
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
|
| 45 |
+
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
|
| 46 |
+
param_group['lr'] = lr
|
| 47 |
+
else:
|
| 48 |
+
if epoch is None:
|
| 49 |
+
self.after_scheduler.step(metrics, None)
|
| 50 |
+
else:
|
| 51 |
+
self.after_scheduler.step(metrics, epoch - self.total_epoch)
|
| 52 |
+
|
| 53 |
+
def step(self, epoch=None, metrics=None):
|
| 54 |
+
if type(self.after_scheduler) != ReduceLROnPlateau:
|
| 55 |
+
if self.finished and self.after_scheduler:
|
| 56 |
+
if epoch is None:
|
| 57 |
+
self.after_scheduler.step(None)
|
| 58 |
+
else:
|
| 59 |
+
self.after_scheduler.step(epoch - self.total_epoch)
|
| 60 |
+
else:
|
| 61 |
+
return super(GradualWarmupScheduler, self).step(epoch)
|
| 62 |
+
else:
|
| 63 |
+
self.step_ReduceLROnPlateau(metrics, epoch)
|
| 64 |
+
|
| 65 |
+
def get_position_from_periods(iteration, cumulative_period):
|
| 66 |
+
"""Get the position from a period list.
|
| 67 |
+
|
| 68 |
+
It will return the index of the right-closest number in the period list.
|
| 69 |
+
For example, the cumulative_period = [100, 200, 300, 400],
|
| 70 |
+
if iteration == 50, return 0;
|
| 71 |
+
if iteration == 210, return 2;
|
| 72 |
+
if iteration == 300, return 2.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
iteration (int): Current iteration.
|
| 76 |
+
cumulative_period (list[int]): Cumulative period list.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
int: The position of the right-closest number in the period list.
|
| 80 |
+
"""
|
| 81 |
+
for i, period in enumerate(cumulative_period):
|
| 82 |
+
if iteration <= period:
|
| 83 |
+
return i
|
| 84 |
+
|
| 85 |
+
class CosineAnnealingRestartCyclicLR(_LRScheduler):
|
| 86 |
+
""" Cosine annealing with restarts learning rate scheme.
|
| 87 |
+
An example of config:
|
| 88 |
+
periods = [10, 10, 10, 10]
|
| 89 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
| 90 |
+
eta_min=1e-7
|
| 91 |
+
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
| 92 |
+
scheduler will restart with the weights in restart_weights.
|
| 93 |
+
Args:
|
| 94 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
| 95 |
+
periods (list): Period for each cosine anneling cycle.
|
| 96 |
+
restart_weights (list): Restart weights at each restart iteration.
|
| 97 |
+
Default: [1].
|
| 98 |
+
eta_min (float): The mimimum lr. Default: 0.
|
| 99 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def __init__(self,
|
| 103 |
+
optimizer,
|
| 104 |
+
periods,
|
| 105 |
+
restart_weights=(1, ),
|
| 106 |
+
eta_mins=(0, ),
|
| 107 |
+
last_epoch=-1):
|
| 108 |
+
self.periods = periods
|
| 109 |
+
self.restart_weights = restart_weights
|
| 110 |
+
self.eta_mins = eta_mins
|
| 111 |
+
assert (len(self.periods) == len(self.restart_weights)
|
| 112 |
+
), 'periods and restart_weights should have the same length.'
|
| 113 |
+
self.cumulative_period = [
|
| 114 |
+
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
| 115 |
+
]
|
| 116 |
+
super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
|
| 117 |
+
|
| 118 |
+
def get_lr(self):
|
| 119 |
+
idx = get_position_from_periods(self.last_epoch,
|
| 120 |
+
self.cumulative_period)
|
| 121 |
+
current_weight = self.restart_weights[idx]
|
| 122 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
| 123 |
+
current_period = self.periods[idx]
|
| 124 |
+
eta_min = self.eta_mins[idx]
|
| 125 |
+
|
| 126 |
+
return [
|
| 127 |
+
eta_min + current_weight * 0.5 * (base_lr - eta_min) *
|
| 128 |
+
(1 + math.cos(math.pi * (
|
| 129 |
+
(self.last_epoch - nearest_restart) / current_period)))
|
| 130 |
+
for base_lr in self.base_lrs
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
class CosineAnnealingRestartLR(_LRScheduler):
|
| 134 |
+
""" Cosine annealing with restarts learning rate scheme.
|
| 135 |
+
|
| 136 |
+
An example of config:
|
| 137 |
+
periods = [10, 10, 10, 10]
|
| 138 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
| 139 |
+
eta_min=1e-7
|
| 140 |
+
|
| 141 |
+
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
| 142 |
+
scheduler will restart with the weights in restart_weights.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
| 146 |
+
periods (list): Period for each cosine anneling cycle.
|
| 147 |
+
restart_weights (list): Restart weights at each restart iteration.
|
| 148 |
+
Default: [1].
|
| 149 |
+
eta_min (float): The mimimum lr. Default: 0.
|
| 150 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
|
| 154 |
+
self.periods = periods
|
| 155 |
+
self.restart_weights = restart_weights
|
| 156 |
+
self.eta_min = eta_min
|
| 157 |
+
assert (len(self.periods) == len(
|
| 158 |
+
self.restart_weights)), 'periods and restart_weights should have the same length.'
|
| 159 |
+
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
|
| 160 |
+
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
| 161 |
+
|
| 162 |
+
def get_lr(self):
|
| 163 |
+
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
|
| 164 |
+
current_weight = self.restart_weights[idx]
|
| 165 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
| 166 |
+
current_period = self.periods[idx]
|
| 167 |
+
|
| 168 |
+
return [
|
| 169 |
+
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
| 170 |
+
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
|
| 171 |
+
for base_lr in self.base_lrs
|
| 172 |
+
]
|
data/util.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
def is_image_file(filename):
|
| 5 |
+
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".bmp", ".JPG", ".jpeg"])
|
| 6 |
+
|
| 7 |
+
def load_img(filepath):
|
| 8 |
+
img = Image.open(filepath).convert('RGB')
|
| 9 |
+
return img
|