| from __future__ import division |
|
|
| import numpy as np |
| import cv2 |
| import random |
| import torch |
| import glob |
| import os |
| from random import choices |
| from scipy.stats import poisson |
|
|
def Rawread(path, low=0):
    """Load a raw frame, choosing the reader from the file suffix.

    Args:
        path: string, path of the file to read. The suffix is compared
            case-sensitively against '.raw', '.npy' and '.png'.
        low: lower clipping bound, forwarded unchanged to the reader.
    Returns:
        the result of read_img, read_npy or read_png respectively, or None
        when the path ends with none of the three suffixes.
    """
    if path.endswith('.raw'):
        reader = read_img
    elif path.endswith('.npy'):
        reader = read_npy
    elif path.endswith('.png'):
        reader = read_png
    else:
        # Unknown suffix: nothing is read and the caller receives None.
        return None
    return reader(path, low)
| |
def read_img(path, low):
    """Read a headerless uint16 dump and pack it into four colour planes.

    The file content is interpreted as 3000 rows of 4000 uint16 samples.
    64 is subtracted from every sample (presumably the sensor black level -
    TODO confirm), the mosaic is split with rggb_raw and the result is
    clipped to [low, 959].

    Args:
        path: path of the .raw file.
        low: lower clipping bound.
    Returns:
        float32 array produced by rggb_raw, clipped to [low, 959].
    """
    height, width = 3000, 4000
    mosaic = np.fromfile(path, np.uint16).reshape((height, width))
    planes = rggb_raw(mosaic.astype(np.float32) - 64)
    return np.clip(planes, low, 959)
|
|
|
|
def read_npy(path, low):
    """Load a frame from a .npy file.

    Two layouts are handled:
      * first dimension equal to 4: the array is returned multiplied by 959,
        without clipping (presumably an already packed frame stored in
        [0, 1] - TODO confirm);
      * anything else: the array is cast to float32, 64 is subtracted, the
        mosaic is split with rggb_raw and the result is clipped to
        [low, 959].

    Args:
        path: path of the .npy file.
        low: lower clipping bound (only used for the second layout).
    Returns:
        the frame as described above.
    """
    frame = np.load(path)
    if frame.shape[0] == 4:
        return frame * 959
    shifted = frame.astype(np.float32) - 64
    return np.clip(rggb_raw(shifted), low, 959)
|
|
def read_rawpng(path, metadata):
    """Read a raw mosaic stored as a PNG and return it packed in [0, 1].

    The decoded image is mapped with (x - 256) / (4095 - 256), clipped to
    [0, 1], split into four planes by bayer2raw according to
    metadata['cfa_pattern'] and clipped to [0, 1] once more.

    Args:
        path: path of the PNG file; converted with str() before reading.
        metadata: mapping forwarded to bayer2raw.
    Returns:
        float32 array of four planes with values in [0, 1].
    Raises:
        IOError: if OpenCV cannot read the file.
    """
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports a missing/undecodable file by returning None
        # instead of raising; fail here with a message that names the file.
        raise IOError('cv2.imread could not read image: {}'.format(path))

    raw = ((raw.astype(np.float32) - 256.) / (4095. - 256.)).clip(0, 1)

    raw = bayer2raw(raw, metadata)
    raw = np.clip(raw, 0., 1.)
    return raw
|
|
def read_png(path, low):
    """Read a raw mosaic stored as a PNG and pack it into four planes.

    Args:
        path: path of the PNG file; converted with str() before reading.
        low: lower clipping bound applied to the packed planes.
    Returns:
        float32 array produced by rggb_raw from the decoded image minus 256,
        clipped to [low, 4095]. When the first dimension of the decoded
        image is 4, the decoded image multiplied by 959 is returned instead.
    Raises:
        IOError: if OpenCV cannot read the file.
    """
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports a missing/undecodable file by returning None
        # instead of raising; fail here with a message that names the file.
        raise IOError('cv2.imread could not read image: {}'.format(path))

    # NOTE(review): for an image decoded by OpenCV, shape[0] is the number of
    # rows, so this shortcut only triggers for 4-row images - confirm that
    # it is intended.
    if raw.shape[0] == 4:
        return raw * 959
    raw = raw.astype(np.float32) - 256
    raw = rggb_raw(raw)
    raw = np.clip(raw, low, 4095)
    return raw
|
|
def random_crop(frames_0, frames_1=None, crop_size=128):
    """Cut the same random square window out of one or two frame stacks.

    Args:
        frames_0: 4-D array whose last two axes are height and width.
        frames_1: optional second array, cropped with the same window.
        crop_size: side length of the square window.
    Returns:
        the crop of frames_0, or, when frames_1 is given, both crops
        concatenated along axis 0 (frames_0 first).
    """
    _, _, height, width = frames_0.shape

    # The horizontal offset is drawn before the vertical one.
    left = random.randint(0, width - crop_size)
    top = random.randint(0, height - crop_size)
    window = (Ellipsis,
              slice(top, top + crop_size),
              slice(left, left + crop_size))

    cropped = frames_0[window]
    if frames_1 is None:
        return cropped
    return np.concatenate([cropped, frames_1[window]], axis=0)
|
|
def rggb_raw(raw):
    """Split a 2-D mosaic into its four 2x2 phase planes.

    Args:
        raw: 2-D array of shape [H, W].
    Returns:
        array of shape [4, H/2, W/2]; the planes are ordered
        (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col).
    """
    _rows, _cols = raw.shape  # unpacking rejects anything that is not 2-D
    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    return np.stack([raw[r::2, c::2] for r, c in phases], axis=0)
|
|
def bayer2raw(raw, metadata):
    """Split a 2-D mosaic into four planes, ordered by the CFA pattern.

    Args:
        raw: 2-D array of shape [H, W].
        metadata: mapping with a 'cfa_pattern' sequence; only its first
            entry is inspected.
    Returns:
        array of shape [4, H/2, W/2]. When metadata['cfa_pattern'][0] is 0
        the planes are (top-left, top-right, bottom-left, bottom-right) of
        each 2x2 cell; otherwise the first and last plane are swapped.
    """
    _rows, _cols = raw.shape  # unpacking rejects anything that is not 2-D
    top_left = raw[0::2, 0::2]
    top_right = raw[0::2, 1::2]
    bottom_left = raw[1::2, 0::2]
    bottom_right = raw[1::2, 1::2]

    if metadata['cfa_pattern'][0] == 0:
        planes = (top_left, top_right, bottom_left, bottom_right)
    else:
        planes = (bottom_right, top_right, bottom_left, top_left)
    return np.stack(planes, axis=0)
|
|
def raw_rggb(raws):
    """Interleave four packed planes back into a uint16 mosaic.

    Args:
        raws: array of shape [C, H, W]; planes 0..3 are written to the
            (even row, even col), (even row, odd col), (odd row, even col)
            and (odd row, odd col) positions respectively.
    Returns:
        uint16 array of shape [2H, 2W]; values are cast on assignment.
    """
    _, H, W = raws.shape
    mosaic = np.zeros((H * 2, W * 2), dtype=np.uint16)

    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (row, col) in enumerate(phases):
        mosaic[row::2, col::2] = raws[plane:plane + 1, :, :]

    return mosaic
|
|
|
|
def raw_rggb_float32(raws):
    """Interleave four packed planes back into a float32 mosaic.

    Args:
        raws: array of shape [C, H, W]; planes 0..3 are written to the
            (even row, even col), (even row, odd col), (odd row, even col)
            and (odd row, odd col) positions respectively.
    Returns:
        float32 array of shape [2H, 2W].
    """
    _, H, W = raws.shape
    mosaic = np.zeros((H * 2, W * 2), dtype=np.float32)

    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (row, col) in enumerate(phases):
        mosaic[row::2, col::2] = raws[plane:plane + 1, :, :]

    return mosaic
|
|
|
|
def depack_rggb_raws(raws):
    """Interleave a batch of packed planes back into single-channel mosaics.

    Args:
        raws: tensor of shape [N, C, H, W]; channels 0..3 are written to the
            (even row, even col), (even row, odd col), (odd row, even col)
            and (odd row, odd col) positions respectively.
    Returns:
        tensor of shape [N, 1, 2H, 2W]. It is allocated with torch.zeros, so
        it has the default dtype and device regardless of those of raws.
    """
    N, _, H, W = raws.shape
    mosaic = torch.zeros((N, 1, H * 2, W * 2))

    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for channel, (row, col) in enumerate(phases):
        mosaic[:, :, row::2, col::2] = raws[:, channel:channel + 1, :, :]

    return mosaic
|
|
|
|
|
|
| |
# Glob patterns of the files that make up a sequence folder.
IMAGETYPES = ('*.npy', '*.raw',)


def _frame_sort_key(path):
    """Return the sort key of a frame file: the number spelled by the digits
    of its base name, then the base name itself as a tie-breaker."""
    name = os.path.basename(path)
    digits = ''.join(ch for ch in name if ch.isdigit())
    # Names without any digit sort first instead of crashing in int('').
    return (int(digits) if digits else -1, name)


def get_imagenames(seq_dir, pattern=None):
    """ Get ordered list of filenames

    Args:
        seq_dir: string, folder that is searched (non-recursively) for the
            patterns listed in IMAGETYPES.
        pattern: optional string; when given, only files whose base name
            contains it are kept.
    Returns:
        list of paths sorted by the number formed by the digits of each base
        name. Only the base name is inspected, so digits in directory names
        cannot disturb the order of zero-padded frame numbers.
    """
    files = []
    for typ in IMAGETYPES:
        files.extend(glob.glob(os.path.join(seq_dir, typ)))

    if pattern is not None:
        files = [f for f in files if pattern in os.path.basename(f)]

    files.sort(key=_frame_sort_key)
    return files
|
|
def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=100):
    r""" Opens a folder of raw frames together with their noisy counterparts
    Args:
        seq_dir: string, folder listed with get_imagenames
        gray_mode: forwarded unchanged to open_image
        expand_if_needed: forwarded unchanged to open_image
        max_num_fr: maximum number of frames to load
    Returns:
        seq_raw: frames returned by open_image, each packed with rggb_raw and
            stacked along a new first axis.
        seq_raw_noise: the matching noisy frames, packed and stacked the
            same way.
        expanded_h: value reported by open_image for the last frame.
        expanded_w: value reported by open_image for the last frame.
    """
    frame_paths = get_imagenames(seq_dir)

    print("\tOpen sequence in folder: ", seq_dir)
    clean_frames, noisy_frames = [], []
    for frame_path in frame_paths[:max_num_fr]:
        clean, noisy, expanded_h, expanded_w = open_image(
            frame_path,
            gray_mode=gray_mode,
            expand_if_needed=expand_if_needed,
            expand_axis0=False)

        # Pack both mosaics into four planes before collecting them.
        clean_frames.append(rggb_raw(clean))
        noisy_frames.append(rggb_raw(noisy))

    seq_raw = np.stack(clean_frames, axis=0)
    seq_raw_noise = np.stack(noisy_frames, axis=0)
    return seq_raw, seq_raw_noise, expanded_h, expanded_w
|
|
def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True):
    r""" Opens a clean raw frame and its noisy counterpart
    Args:
        fpath: string, path of the clean frame. The noisy frame is read from
            the same path with 'onlyraw_test_clean_raw' replaced by
            'onlyraw_test_noise_raw'.
        gray_mode: not used by this function
        expand_if_needed: not used by this function
        expand_axis0: not used by this function
        normalize_data: if True, both frames are passed through normalize
    Returns:
        raw_img: clean frame of dims 3000x4000, float32, offset by -64 and
            clipped to [0, 959] before the optional normalization.
        raw_img_noise: noisy frame, processed the same way.
        expanded_h: always False.
        expanded_w: always False.
    """
    width, height = 4000, 3000

    def _load_mosaic(path):
        # Read exactly one frame worth of uint16 samples from the file.
        samples = np.fromfile(path, dtype=np.uint16, count=width * height)
        shifted = samples.reshape((height, width)).astype(np.float32) - 64
        return np.clip(shifted, 0, 959)

    raw_img = _load_mosaic(fpath)

    noise_fpath = fpath.replace('onlyraw_test_clean_raw', 'onlyraw_test_noise_raw')
    raw_img_noise = _load_mosaic(noise_fpath)

    if normalize_data:
        raw_img = normalize(raw_img)
        raw_img_noise = normalize(raw_img_noise)

    # No padding is performed, so neither dimension is ever expanded.
    return raw_img, raw_img_noise, False, False
|
|
|
|
def normalize(data):
    r"""Divides the input by 959 and converts the result to float32

    Args:
        data: a number or numpy array to scale; a value of 959 maps to 1
    """
    scaled = data / 959
    return np.float32(scaled)
|
|
|
|
def augment_cuda(batches, args, spynet=None):
    """Randomly flip/rotate a batch and split it into clean and noisy parts.

    Args:
        batches: tensor whose last two axes are flipped and rotated; it must
            yield a 5-D noisy tensor (B, F, C, H, W) in either mode.
        args: namespace providing `pair` and `frame`, plus `consistent_loss`
            and whatever Noise_simulation reads when `pair` is false.
        spynet: not used by this function.
    Returns:
        clean: target frame(s).
        noise: noisy frames with the frame and channel axes merged,
            shape (B, F*C, H, W).
        None: placeholder third element.
    """
    # Same draw order as before: two random.random() calls for the flips,
    # then one numpy draw for the number of quarter turns.
    mirror_w = random.random() < 0.5
    mirror_h = random.random() < 0.5
    quarter_turns = np.random.randint(0, 4)

    batches_aug = batches
    if mirror_w:
        batches_aug = torch.flip(batches_aug, dims=[-1])
    if mirror_h:
        batches_aug = torch.flip(batches_aug, dims=[-2])
    batches_aug = torch.rot90(batches_aug, k=quarter_turns, dims=[-2, -1])

    # NOTE(review): the nesting below was inferred from the statement order
    # (the consistent_loss selection only applies to simulated noise, the
    # reshape applies to both modes) - confirm against the original layout.
    if args.pair:
        # Paired data: the first `frame` entries are noisy, the next one is
        # the clean target; both are scaled by 1/959.
        noise = batches_aug[:, :args.frame, ...] / 959
        clean = batches_aug[:, args.frame, ...] / 959
    else:
        clean, noise = Noise_simulation(batches_aug, args)
        if not args.consistent_loss:
            clean = clean[:, args.frame // 2, ...]

    B, F, C, H, W = noise.shape
    noise = noise.reshape(B, F * C, H, W)

    return clean, noise, None
| |
|
|
def Noise_simulation(batches_aug, args):
    """Darken a clean batch (optionally) and synthesize its noisy version.

    Args:
        batches_aug: 5-D tensor; it is divided by 959 and clamped to [0, 1].
        args: namespace providing `need_Scaling`, `sample_gain`,
            `local_rank` and whatever add_noise reads.
    Returns:
        (clean, noisy) as float tensors; the noisy one is clamped to
        [-0.1, 1].
    """
    frames = torch.clamp(batches_aug / 959, 0, 1)
    B = frames.shape[0]
    per_sample_mean = frames.mean(dim=(1, 2, 3, 4))

    if args.need_Scaling:
        # Draw the brightness each sample is rescaled to.
        if args.sample_gain == 'type1':
            rand_avg = (torch.rand((B)) * 0.12 + 0.001).cuda(args.local_rank)
        elif args.sample_gain == 'type2':
            rand_avg = Gain_Sampler(B).cuda(args.local_rank)

        # One divisor per sample, broadcast over the four trailing axes.
        coef = (per_sample_mean / rand_avg)[..., None, None, None, None]
        dark = torch.clamp(frames / coef, 0, 1)
    else:
        dark = frames

    a, b, again, dgain = random_noise_levels_nightimaging(B, args)
    rank = args.local_rank
    clean, noisy = add_noise(args, dark, a.cuda(rank), b.cuda(rank), dgain.cuda(rank))

    noisy = torch.clamp(noisy, -0.1, 1)

    return clean.float(), noisy.float()
|
|
# Location of the noise-profile table used when no path is passed explicitly.
_NIGHT_NOISE_PROFILE_PATH = '/data1/chengqihua/02_code/03_night_photogrphy/nightimage_v1/dataloader/json_all_2nd.npy'
# path -> tensor, so the table is read from disk once instead of once per call.
_NIGHT_NOISE_PROFILE_CACHE = {}


def _load_noise_profile(path):
    """Return the noise-profile tensor stored at `path`, loading it only once."""
    if path not in _NIGHT_NOISE_PROFILE_CACHE:
        _NIGHT_NOISE_PROFILE_CACHE[path] = torch.from_numpy(np.load(path))
    return _NIGHT_NOISE_PROFILE_CACHE[path]


def random_noise_levels_nightimaging(B, args, profile_path=None):
    """Sample B rows of the noise-profile table.

    Args:
        B: number of samples to draw.
        args: not used by this function.
        profile_path: optional path of the .npy table; defaults to the
            built-in location. The table is cached per path after the first
            load, so later changes to the file are not picked up.
    Returns:
        a: column 0 of the sampled rows, shape (B,).
        b: column 1 of the sampled rows, shape (B,).
        1: constant third element.
        dgain: tensor of shape (1,) holding 1.
    """
    # Row indices: uniform draw from [0, 125) truncated to integers.
    g = torch.FloatTensor(B).uniform_(0, 125).int().long()
    noise_profile = _load_noise_profile(profile_path or _NIGHT_NOISE_PROFILE_PATH)

    # Advanced indexing copies, so the cached table is never modified.
    a = noise_profile[g, 0]
    b = noise_profile[g, 1]

    return a, b, 1, 1 * torch.ones(1)
|
|
def random_noise_levels(B, args):
    """Draw B random gains and derive the noise parameters from them.

    Args:
        B: number of samples to draw.
        args: namespace providing `min_gain` and `max_gain`, the bounds of
            the uniform gain draw.
    Returns:
        a: 0.05244803 * again + 0.01498041
        b: 0.00648923 * again^2 + 0.05899386 * again + 0.21520193
        again: the drawn gain capped at 16.
        dgain: the drawn gain divided by 16, but never below 1.
    """
    g = torch.FloatTensor(B).uniform_(args.min_gain, args.max_gain)

    # Gains up to 16 are kept as-is; the excess above 16 goes into dgain.
    again = torch.clamp(g, max=16)
    dgain = torch.clamp(g / 16, min=1)

    a = 0.05244803 * again + 0.01498041
    b = 0.00648923 * again * again + 0.05899386 * again + 0.21520193

    return a, b, again, dgain
|
|
def add_noise(args, image, a, b, dgain):
    """Add Poisson and Gaussian noise to a batch of frames.

    Args:
        args: namespace providing `usedgain`.
        image: 5-D tensor (B, F, C, H, W) of non-negative values.
        a: 1-D tensor, per-sample scale of the Poisson component.
        b: 1-D tensor, per-sample variance of the Gaussian component.
        dgain: 1-D tensor, per-sample (or single) gain; the image is divided
            by it before the noise is generated.
    Returns:
        image: the image divided by dgain, multiplied back when
            args.usedgain is set.
        noiseimg: poisson(image / a) * a + sqrt(b) * N(0, 1), scaled the same
            way as `image`.
    """
    # One value per sample, broadcast over the four trailing axes.
    dgain = dgain.reshape(-1, 1, 1, 1, 1)
    a = a.reshape(-1, 1, 1, 1, 1)
    b = b.reshape(-1, 1, 1, 1, 1)

    B, F, C, H, W = image.size()

    image = image / dgain

    poisson_noisy_img = torch.poisson(image / a) * a

    # Draw the Gaussian component directly on the device that holds the
    # image instead of on the CPU followed by a copy to a fixed GPU.
    gaussian_noise = torch.sqrt(b) * torch.randn(B, F, C, H, W, device=image.device)

    noiseimg = poisson_noisy_img + gaussian_noise

    # NOTE(review): both rescalings are taken to belong to the `usedgain`
    # branch so that clean and noisy frames stay on the same scale - confirm
    # against the original layout.
    if args.usedgain:
        noiseimg = noiseimg * dgain
        image = image * dgain
    return image, noiseimg
|
|
|
|
|
|
def normalize_augment(datain):
    '''Scales an input patch of dim [N, num_frames, C, H, W] by 1/255 and
    applies one randomly chosen augmentation to all frames at once. The
    augmented patch is returned twice, as (input, ground truth).
    '''
    def transform(sample):
        def rotate_then_flip(turns, flip_rows):
            # Rotation by `turns` quarter turns in the (2, 3) plane, followed
            # by an optional flip along axis 2.
            def apply(x):
                out = x if turns == 0 else torch.rot90(x, k=turns, dims=[2, 3])
                return torch.flip(out, dims=[2]) if flip_rows else out
            return apply

        def add_csnt(x):
            # One random offset per batch element, shared by every pixel.
            offset = torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1),
                                  std=(5 / 255.))
            return x + offset.expand_as(x).to(x.device)

        labelled = [
            ('do_nothing', rotate_then_flip(0, False)),
            ('flipup', rotate_then_flip(0, True)),
            ('rot90', rotate_then_flip(1, False)),
            ('rot90_flipud', rotate_then_flip(1, True)),
            ('rot180', rotate_then_flip(2, False)),
            ('rot180_flipud', rotate_then_flip(2, True)),
            ('rot270', rotate_then_flip(3, False)),
            ('rot270_flipud', rotate_then_flip(3, True)),
            ('add_csnt', add_csnt),
        ]
        aug_list = []
        for label, fn in labelled:
            fn.__name__ = label
            aug_list.append(fn)

        # The identity is drawn with weight 32, every other entry with 12.
        w_aug = [32] + [12] * 8
        chosen = choices(aug_list, w_aug)[0]
        return chosen(sample)

    N, F, C, H, W = datain.shape
    # Merge frames and channels so that one transform hits every frame.
    stacked = datain.view(N, -1, H, W) / 255.

    augmented = transform(stacked).view(N, F, C, H, W)

    return augmented, augmented
|
|
def Gain_Sampler(B):
    """Sample B integer gains from three weighted ranges.

    A level is drawn for every sample with probabilities 0.7 ('low'),
    0.2 ('mid') and 0.1 ('high'); the gain is then drawn uniformly from the
    integer range of that level (upper bound exclusive).

    Args:
        B: number of gains to draw.
    Returns:
        float32 tensor of shape (B,).
    """
    gain_dict = {
        'low': (5, 35),
        'mid': (35, 60),
        'high': (60, 100),
    }
    level = ['low', 'mid', 'high']

    # The probabilities must be passed as `p`: the third positional argument
    # of np.random.choice is `replace`, which left the draw uniform.
    sampled = np.random.choice(level, B, p=[0.7, 0.2, 0.1])

    gains = []
    for name in sampled:
        lower, upper = gain_dict[name]
        gains.append(int(torch.randint(lower, upper, (1,))))

    return torch.tensor(gains, dtype=torch.float32)
|
|
def path_replace(path, args):
    """Apply a list of substring replacements to a path, one after another.

    Args:
        path: string to rewrite.
        args: namespace providing `replace_left` (substrings to look for)
            and `replace_right` (their replacements, matched by position).
    Returns:
        the rewritten string; each replacement sees the result of the
        previous one.
    """
    for idx, old in enumerate(args.replace_left):
        path = path.replace(old, args.replace_right[idx])
    return path