import os
import random
import copy

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor

from utils.image_utils import random_augmentation, crop_img
from utils.degradation_utils import Degradation


class PromptTrainDataset(Dataset):
    def __init__(self, args):
        super(PromptTrainDataset, self).__init__()
        self.args = args
        self.rs_ids = []
        self.hazy_ids = []
        self.D = Degradation(args)
        self.de_temp = 0
        self.de_type = self.args.de_type
        print(self.de_type)

        self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4, 'deblur': 5}

        self._init_ids()
        self._merge_ids()

        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

        self.toTensor = ToTensor()

    def _init_ids(self):
        if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
            self._init_clean_ids()
        if 'derain' in self.de_type:
            self._init_rs_ids()
        if 'dehaze' in self.de_type:
            self._init_hazy_ids()

        random.shuffle(self.de_type)

    def _init_clean_ids(self):
        ref_file = self.args.data_file_dir + "noisy/denoise_airnet.txt"
        temp_ids = []
        temp_ids += [id_.strip() for id_ in open(ref_file)]
        clean_ids = []
        name_list = os.listdir(self.args.denoise_dir)
        clean_ids += [self.args.denoise_dir + id_ for id_ in name_list if id_.strip() in temp_ids]

        if 'denoise_15' in self.de_type:
            self.s15_ids = [{"clean_id": x, "de_type": 0} for x in clean_ids]
            self.s15_ids = self.s15_ids * 3
            random.shuffle(self.s15_ids)
            self.s15_counter = 0
        if 'denoise_25' in self.de_type:
            self.s25_ids = [{"clean_id": x, "de_type": 1} for x in clean_ids]
            self.s25_ids = self.s25_ids * 3
            random.shuffle(self.s25_ids)
            self.s25_counter = 0
        if 'denoise_50' in self.de_type:
            self.s50_ids = [{"clean_id": x, "de_type": 2} for x in clean_ids]
            self.s50_ids = self.s50_ids * 3
            random.shuffle(self.s50_ids)
            self.s50_counter = 0

        self.num_clean = len(clean_ids)
        print("Total Denoise Ids : {}".format(self.num_clean))

    def _init_hazy_ids(self):
        temp_ids = []
        hazy = self.args.data_file_dir + "hazy/hazy_outside.txt"
        temp_ids += [self.args.dehaze_dir + id_.strip() for id_ in open(hazy)]
        self.hazy_ids = [{"clean_id": x, "de_type": 4} for x in temp_ids]

        self.hazy_counter = 0
        self.num_hazy = len(self.hazy_ids)
        print("Total Hazy Ids : {}".format(self.num_hazy))

    def _init_rs_ids(self):
        temp_ids = []
        rs = self.args.data_file_dir + "rainy/rainTrain.txt"
        temp_ids += [self.args.derain_dir + id_.strip() for id_ in open(rs)]
        self.rs_ids = [{"clean_id": x, "de_type": 3} for x in temp_ids]
        self.rs_ids = self.rs_ids * 120

        self.rl_counter = 0
        self.num_rl = len(self.rs_ids)
        print("Total Rainy Ids : {}".format(self.num_rl))

    def _crop_patch(self, img_1, img_2):
        H = img_1.shape[0]
        W = img_1.shape[1]
        ind_H = random.randint(0, H - self.args.patch_size)
        ind_W = random.randint(0, W - self.args.patch_size)

        patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
        patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]

        return patch_1, patch_2

    def _get_gt_name(self, rainy_name):
        gt_name = rainy_name.split("rainy")[0] + 'gt/norain-' + rainy_name.split('rain-')[-1]
        return gt_name

    def _get_nonhazy_name(self, hazy_name):
        dir_name = hazy_name.split("synthetic")[0] + 'original/'
        name = hazy_name.split('/')[-1].split('_')[0]
        suffix = '.' + hazy_name.split('.')[-1]
        nonhazy_name = dir_name + name + suffix
        return nonhazy_name

    def _merge_ids(self):
        self.sample_ids = []
        if "denoise_15" in self.de_type:
            self.sample_ids += self.s15_ids
        if "denoise_25" in self.de_type:
            self.sample_ids += self.s25_ids
        if "denoise_50" in self.de_type:
            self.sample_ids += self.s50_ids
        if "derain" in self.de_type:
            self.sample_ids += self.rs_ids
        if "dehaze" in self.de_type:
            self.sample_ids += self.hazy_ids
        print(len(self.sample_ids))

    def __getitem__(self, idx):
        sample = self.sample_ids[idx]
        de_id = sample["de_type"]

        if de_id < 3:
            # de_id 0/1/2 are the sigma-15/25/50 denoising tasks; they all read the same
            # clean image and synthesize the degraded input on the fly.
            clean_id = sample["clean_id"]

            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
            clean_patch = self.crop_transform(clean_img)
            clean_patch = np.array(clean_patch)

            clean_name = clean_id.split("/")[-1].split('.')[0]

            clean_patch = random_augmentation(clean_patch)[0]
            degrad_patch = self.D.single_degrade(clean_patch, de_id)
        else:
            if de_id == 3:
                # Rain Streak Removal
                degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
                clean_name = self._get_gt_name(sample["clean_id"])
                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
            elif de_id == 4:
                # Dehazing with SOTS outdoor training set
                degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
                clean_name = self._get_nonhazy_name(sample["clean_id"])
                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)

            degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))

        clean_patch = self.toTensor(clean_patch)
        degrad_patch = self.toTensor(degrad_patch)

        return [clean_name, de_id], degrad_patch, clean_patch

    def __len__(self):
        return len(self.sample_ids)


class DenoiseTestDataset(Dataset):
    def __init__(self, args):
        super(DenoiseTestDataset, self).__init__()
        self.args = args
        self.clean_ids = []
        self.sigma = 15

        self._init_clean_ids()

        self.toTensor = ToTensor()

    def _init_clean_ids(self):
        name_list = os.listdir(self.args.denoise_path)
        self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]

        self.num_clean = len(self.clean_ids)

    def _add_gaussian_noise(self, clean_patch):
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
        return noisy_patch, clean_patch

    def set_sigma(self, sigma):
        self.sigma = sigma

    def __getitem__(self, clean_id):
        clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
        clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]

        noisy_img, _ = self._add_gaussian_noise(clean_img)
        clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)

        return [clean_name], noisy_img, clean_img

    @staticmethod
    def tile_degrad(input_, tile=128, tile_overlap=0):
        # Tiled pass over the input; the model call is commented out, so this currently
        # reassembles the (clamped) input from overlapping tiles.
        sigma_dict = {0: 0, 1: 15, 2: 25, 3: 50}
        b, c, h, w = input_.shape
        tile = min(tile, h, w)
        assert tile % 8 == 0, "tile size should be multiple of 8"

        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        E = torch.zeros(b, c, h, w).type_as(input_)
        W = torch.zeros_like(E)
        s = 0
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = input_[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
                out_patch = in_patch
                # out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(in_patch)

                E[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch)
                W[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch_mask)

        restored = E.div_(W)  # average the accumulated tiles by their overlap count
        restored = torch.clamp(restored, 0, 1)
        return restored

    def __len__(self):
        return self.num_clean


class DerainDehazeDataset(Dataset):
    def __init__(self, args, task="derain", addnoise=False, sigma=None):
        super(DerainDehazeDataset, self).__init__()
        self.ids = []
        self.task_idx = 0
        self.args = args

        self.task_dict = {'derain': 0, 'dehaze': 1}
        self.toTensor = ToTensor()
        self.addnoise = addnoise
        self.sigma = sigma

        self.set_dataset(task)

    def _add_gaussian_noise(self, clean_patch):
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
        return noisy_patch, clean_patch

    def _init_input_ids(self):
        if self.task_idx == 0:
            self.ids = []
            name_list = os.listdir(self.args.derain_path + 'input/')
            # print(name_list)
            print(self.args.derain_path)
            self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
        elif self.task_idx == 1:
            self.ids = []
            name_list = os.listdir(self.args.dehaze_path + 'input/')
            self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]

        self.length = len(self.ids)

    def _get_gt_path(self, degraded_name):
        if self.task_idx == 0:
            gt_name = degraded_name.replace("input", "target")
        elif self.task_idx == 1:
            dir_name = degraded_name.split("input")[0] + 'target/'
            name = degraded_name.split('/')[-1].split('_')[0] + '.png'
            gt_name = dir_name + name
        return gt_name

    def set_dataset(self, task):
        self.task_idx = self.task_dict[task]
        self._init_input_ids()

    def __getitem__(self, idx):
        degraded_path = self.ids[idx]
        clean_path = self._get_gt_path(degraded_path)

        degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
        if self.addnoise:
            degraded_img, _ = self._add_gaussian_noise(degraded_img)
        clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)

        clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
        degraded_name = degraded_path.split('/')[-1][:-4]

        return [degraded_name], degraded_img, clean_img

    def __len__(self):
        return self.length


class TestSpecificDataset(Dataset):
    def __init__(self, args):
        super(TestSpecificDataset, self).__init__()
        self.args = args
        self.degraded_ids = []
        self._init_clean_ids(args.test_path)

        self.toTensor = ToTensor()

    def _init_clean_ids(self, root):
        extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']
        if os.path.isdir(root):
            name_list = []
            for image_file in os.listdir(root):
                if any([image_file.endswith(ext) for ext in extensions]):
                    name_list.append(image_file)
            if len(name_list) == 0:
                raise Exception('The input directory does not contain any image files')
            self.degraded_ids += [root + id_ for id_ in name_list]
        else:
            if any([root.endswith(ext) for ext in extensions]):
                name_list = [root]
            else:
                raise Exception('Please pass an Image file')
            self.degraded_ids = name_list

        self.num_img = len(self.degraded_ids)
        print("Total Images : {}".format(self.num_img))

    def __getitem__(self, idx):
        degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
        name = self.degraded_ids[idx].split('/')[-1][:-4]

        degraded_img = self.toTensor(degraded_img)

        return [name], degraded_img

    def __len__(self):
        return self.num_img
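

# Minimal usage sketch: build the mixed-degradation training set and draw one batch.
# The Namespace fields mirror the attributes read above (de_type, patch_size,
# data_file_dir, denoise_dir, derain_dir, dehaze_dir); the paths are placeholders
# and should point at your own data layout (Degradation may read further fields).
if __name__ == "__main__":
    from argparse import Namespace

    from torch.utils.data import DataLoader

    demo_args = Namespace(
        de_type=['denoise_15', 'denoise_25', 'denoise_50', 'derain', 'dehaze'],
        patch_size=128,
        data_file_dir='data_dir/',          # directory holding the *.txt split files
        denoise_dir='data/Train/Denoise/',  # clean images for the three denoise tasks
        derain_dir='data/Train/Derain/',    # rainy/ and gt/ pairs
        dehaze_dir='data/Train/Dehaze/',    # synthetic/ and original/ pairs
    )

    trainset = PromptTrainDataset(demo_args)
    loader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)

    (clean_name, de_id), degrad_patch, clean_patch = next(iter(loader))
    print(de_id, degrad_patch.shape, clean_patch.shape)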