| import os |
|
|
| import multiprocessing as mp |
| import numpy as np |
| import cv2 as cv |
| from tqdm import tqdm |
|
|
| from glob import glob |
| from utils import * |
|
|
| from grayness_index import GraynessIndex |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from time import time |
|
|
| from network_raw_denoise import sc_net_1f |
| from network import MWRCANv4 as NET |
| from classes.fc4.ModelFC4 import ModelFC4 |
|
|
def load_img(img_path):
    """Read one raw PNG plus its sidecar JSON metadata into a record dict.

    Returns a dict with keys 'img_path', 'meta' (parsed JSON metadata) and
    'img' (the raw image, bit depth preserved).
    """
    record = {'img_path': img_path}

    # Metadata sits next to the image with the same stem, .json extension.
    record['meta'] = json_read(img_path.replace(".png", ".json"),
                               object_hook=fraction_from_json)

    # IMREAD_UNCHANGED keeps the original bit depth (no 8-bit conversion).
    record['img'] = cv.imread(img_path, cv.IMREAD_UNCHANGED)

    return record
|
|
|
|
|
|
|
|
def pre_process(meta_all):
    """Pack the Bayer raw into 4 channels, normalize, apply lens-shading
    correction (low-noise shots only) and seed the white-balance gains.

    Reads the module-level global ``lsc_npy`` (per-ISO lens-shading maps,
    loaded in ``__main__`` before the worker pool forks). Mutates and
    returns ``meta_all``.
    """
    img = meta_all['img']

    metadata = meta_all['meta']
    cfa_pattern = metadata['cfa_pattern']
    # Translate the numeric CFA pattern (0/1/2 -> r/g/b) into a string such
    # as "rggb" for bayer_to_offsets.
    cfa_pattern_ = ""
    for tt in cfa_pattern:
        if tt == 0:
            cfa_pattern_ += "r"
        elif tt == 1:
            cfa_pattern_ += "g"
        elif tt == 2:
            cfa_pattern_ += "b"
        else:
            # NOTE(review): bare `raise` with no active exception raises
            # RuntimeError; an explicit ValueError would be clearer.
            raise

    offsets = bayer_to_offsets(cfa_pattern_)
    img = pack_raw_to_4ch(img, offsets)

    # NOTE(review): with `and`, the resize only triggers when BOTH dims
    # differ from 768x1024; `or` may have been intended — confirm the set
    # of input sizes this pipeline actually sees.
    if img.shape[0] != 768 and img.shape[1] != 1024:
        img = cv.resize(img, (1024, 768), interpolation=cv.INTER_AREA)

    # Raise the black level slightly for noisier shots (capped at +10).
    bl_fix = np.clip((float(metadata["noise_profile"][0])-0.005) * 1000, 0, 10)
    img = normalize(img, metadata['black_level'], metadata['white_level'], bl_fix).astype(np.float32)

    noise_profile = float(metadata["noise_profile"][0])
    # Interpolation knots: measured noise-profile values corresponding to
    # the ISO 50/125/320/640/800 maps in lsc_npy (same order).
    noise_list = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]

    # Lens-shading correction is applied in raw space only for low-noise
    # shots; noisier shots get a related correction later in convert_to_rgb.
    if noise_profile < 0.005:
        if noise_profile < noise_list[0]:
            # Below the first knot: scale the first map down proportionally.
            weight1 = noise_profile / noise_list[0]
            final_lsc = lsc_npy[0] * weight1
            linear_idx1, linear_idx2 = 0, 0
        elif noise_profile > noise_list[-1]:
            final_lsc = lsc_npy[-1]
            linear_idx1, linear_idx2 = -1, -1
        else:
            # Find the bracketing knots for linear interpolation.
            for idx, nn in enumerate(noise_list):
                if noise_profile < nn:
                    linear_idx1 = idx - 1
                    linear_idx2 = idx
                    break

            # NOTE(review): weight1 grows toward the UPPER knot yet scales
            # lsc_npy[linear_idx1] (the lower map) — the weights look
            # swapped; verify the intended interpolation direction.
            weight1 = (noise_profile - noise_list[linear_idx1]) / (noise_list[linear_idx2] - noise_list[linear_idx1])
            weight2 = 1-weight1
            final_lsc = lsc_npy[linear_idx1] * weight1 + lsc_npy[linear_idx2] * weight2

        # Soften the correction (60% map / 40% identity), then mirror the
        # right half onto the left.
        # NOTE(review): the 1024:511:-1 slice relies on numpy clamping the
        # out-of-range start index (effectively 1023..512 for a width-1024
        # map) — confirm lsc map width is 1024.
        ones = np.ones_like(final_lsc)
        final_lsc = final_lsc * 0.6 + ones * 0.4
        final_lsc[:, :512, :] = final_lsc[:, 1024:511:-1, :]

        img = img * final_lsc

    img = np.clip(img, 0.0, 1.0)
    meta_all["img"] = img

    # Seed gains = inverse of the camera's as-shot neutral; the FC4 folds
    # append three more candidates each in predict_white_balance.
    rgb_gain = metadata['as_shot_neutral']
    ra, ga, ba = rgb_gain
    ra, ga, ba = 1/ra, 1/ga, 1/ba

    meta_all['r_gains'] = [ra]
    meta_all['g_gains'] = [ga]
    meta_all['b_gains'] = [ba]

    return meta_all
|
|
|
|
|
|
def raw_denoise(results):
    """Run the learned raw-domain denoiser over every record, in place.

    The network predicts a residual that is added back onto the ORIGINAL
    (pre-clip) input, so highlights clipped for inference are not lost.
    Each record's 'img' stays float in [0, 1].
    """
    device = torch.device("cuda")
    net = get_net(sc_net_1f, "checkpoint/raw_denoise.pth", device)

    for rec in tqdm(results):
        batched = np.expand_dims(rec['img'], axis=0)
        original = batched.copy()

        # Clip extreme highlights before feeding the net; ceiling adapts to
        # overall brightness but never drops below 0.9.
        ceiling = max(np.mean(batched)*3, 0.9)
        clipped = np.clip(batched, 0, ceiling)
        tensor = torch.from_numpy(clipped.transpose(0, 3, 1, 2)).cuda()

        with torch.no_grad():
            residual = net(tensor)

        residual = residual.detach().cpu().numpy().transpose(0, 2, 3, 1)

        denoised = np.clip(original + residual, 0, 1)
        rec['img'] = np.squeeze(denoised)
|
|
|
|
|
def predict_white_balance(results):
    """Append one illuminant-gain estimate per FC4 fold to every record.

    The image loop is nested INSIDE the fold loop, so after this function
    each record's r/g/b gain lists hold 4 entries: 1 camera as-shot gain
    (from pre_process) + 3 FC4 fold predictions. convert_to_rgb depends on
    exactly this count. Mutates the records in place.
    """
    model = ModelFC4()
    for model_index in [0, 1, 2]:
        # Reload the same model object with each fold's weights in turn.
        path_to_pretrained = os.path.join("./trained_models", "fc4_cwp", "fold_{}".format(model_index))
        model.load(path_to_pretrained)
        model.evaluation_mode()

        for meta_all in tqdm(results):
            img = meta_all['img'].copy()
            # Average the two green samples, drop the 4th plane: 4-ch packed
            # raw -> 3-ch RGB.
            img[:, :, 1] = (img[:, :, 1] + img[:, :, 3]) / 2
            img = img[:, :, :-1]

            img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).cuda()
            # FC4 expects gamma-encoded input; approximate with 1/2.2.
            img_tmp = torch.pow(img, 1.0 / 2.2)

            with torch.no_grad():
                pred = model.predict(img_tmp, return_steps=False)
                pred = pred.detach().cpu().squeeze(0).numpy()

            # pred is the estimated illuminant (r, g, b).
            r, g, b = pred

            # Normalize so g == 1 ...
            r /= g
            b /= g
            g /= g

            # ... then invert illuminant -> per-channel gain (g stays 1).
            r = 1./ r
            g = 1./ g
            b = 1./ b

            meta_all['r_gains'].append(r)
            meta_all['g_gains'].append(g)
            meta_all['b_gains'].append(b)
|
|
|
|
def convert_to_rgb(meta_all):
    """Convert the packed 4-channel raw to an sRGB image with WB, color
    transform, tone mapping, noise-aware finishing and orientation fix.

    Reads module-level globals ``color_matrix`` and ``lsc`` (set in
    ``__main__`` before the worker pool forks). Mutates and returns
    ``meta_all``; also records the mean brightness under key 'mm'.
    """
    img = meta_all['img']
    # Average the two greens, drop the 4th plane: 4-ch packed raw -> RGB.
    img[:, :, 1] = (img[:, :, 1] + img[:, :, 3]) / 2
    img = img[:, :, :-1]

    # Combine the 4 gain candidates (1 camera + 3 FC4 folds): after sorting,
    # average the 3 SMALLEST red gains and the 3 LARGEST blue gains.
    r_gains = sorted(meta_all['r_gains'])
    b_gains = sorted(meta_all['b_gains'])

    r_final = (r_gains[0] + r_gains[1] + r_gains[2]) / 3
    g_final = 1
    b_final = (b_gains[1] + b_gains[2] + b_gains[3]) / 3

    img[:, :, 0] *= r_final
    img[:, :, 1] *= g_final
    img[:, :, 2] *= b_final

    img = np.clip(img, 0, 1)

    # NOTE(review): the same 3x3 matrix is passed for both matrix arguments
    # — confirm apply_color_space_transform's signature (helpers like this
    # often take two distinct matrices, e.g. ColorMatrix1/ColorMatrix2).
    img = apply_color_space_transform(img, color_matrix, color_matrix)

    img = transform_xyz_to_srgb(img)

    # For noisy shots: partially UNDO lens shading in the RGB domain and
    # desaturate dark corners toward gray to hide chroma noise.
    if float(meta_all['meta']["noise_profile"][0]) > 0.005:
        # Strength of both effects scales with the noise profile.
        lsc_m = lsc ** ((float(meta_all['meta']["noise_profile"][0])-0.005) * 100)
        lsc_inv = 1 / lsc
        lsc_inv = np.mean(lsc_inv, axis=-1, keepdims=True)

        gray = cv.cvtColor(img.astype(np.float32), cv.COLOR_RGB2GRAY)
        gray = gray[:, :, np.newaxis]

        # Blend weight toward gray: stronger in dark pixels, floored at 0.4.
        lsc_inv = lsc_inv * np.clip(gray*10, 0, 1) * np.clip((2 - (float(meta_all['meta']["noise_profile"][0])-0.005) * 100), 1, 2)
        lsc_inv = np.clip(lsc_inv, 0.4, 1)

        img = img * lsc_inv + gray * (1-lsc_inv)
        img = img / lsc_m

    img = apply_tone_map(img)

    img = apply_gamma(img).astype(np.float32)
    img = np.clip(img, 0, 1)

    # Brightness/noise-dependent finishing:
    #   very dark  -> leave as is;
    #   noisy      -> luma autocontrast blend + NL-means denoise + sharpen;
    #   otherwise  -> plain autocontrast.
    mm = np.mean(img)
    meta_all['mm'] = mm
    if mm <= 0.1:
        pass
    elif float(meta_all['meta']["noise_profile"][0]) > 0.01:
        # NOTE(review): COLOR_BGR2YUV is applied although the array holds
        # RGB at this point, so Y is computed with swapped R/B weights —
        # confirm whether this is intentional.
        yuv = cv.cvtColor(img, cv.COLOR_BGR2YUV)
        y, u, v = cv.split(yuv)
        y = autocontrast_using_pil(y)
        yuv = np.stack([y, u, v], axis=-1)
        rgb = cv.cvtColor(yuv, cv.COLOR_YUV2BGR)
        rgb = np.clip(rgb, 0, 1)

        # 50/50 blend of the contrast-stretched image with the input.
        img = img * 0.5 + rgb * 0.5

        img = np.clip(img*255, 0, 255).round().astype(np.uint8)

        # Stronger NL-means filtering for the noisiest shots.
        if float(meta_all['meta']["noise_profile"][0]) > 0.02:
            noise_params = 6
        else:
            noise_params = 3

        img = cv.fastNlMeansDenoisingColored(img, None, noise_params, noise_params, 7, 21)
        img = img.astype(np.float32) / 255.

        img = usm_sharp(img)
    else:
        img = autocontrast_using_pil(img)

    # Brighten via gamma 1/1.5, weighted by the per-pixel max so highlights
    # take more of the lifted value than shadows (weight floored at 0.3).
    img = np.clip(img, 0, 1)
    img_con = img ** (1/1.5)
    gray = np.max(img_con, axis=-1, keepdims=True)
    gray = np.clip(gray, 0.3, 1)
    img = img_con * gray + img * (1-gray)

    # Grayness-index white-balance refinement, computed on the
    # channel-flipped image and flipped back afterwards; only applied when
    # the normalized red component of the estimated illuminant is < 1.
    img = img[:, :, ::-1]
    gi = GraynessIndex()
    pred_illum = gi.apply(img)
    r, g, b = pred_illum
    pred_illum = pred_illum / g
    r, g, b = pred_illum
    if r < 1:
        img = white_balance(img, pred_illum)
    img = img[:, :, ::-1]
    img = np.clip(img, 0, 1)

    # Rotate/flip according to the EXIF-style orientation tag.
    img = fix_orientation(img, meta_all['meta']["orientation"])

    meta_all['img'] = img
    return meta_all
|
|
|
|
def nn_enhancement(results):
    """Apply the MWRCAN enhancement network to every image, in place.

    Inputs are float images in [0, 1]; each record's 'img' is replaced by
    the network output as uint8 in [0, 255].
    """
    device = torch.device("cuda")
    net = get_net(NET, "checkpoint/nn_enhance.pth", device)

    for rec in tqdm(results):
        # HWC float32 -> 1xCxHxW tensor on the GPU.
        chw = rec['img'].astype(np.float32).copy().transpose(2, 0, 1)
        tensor = torch.from_numpy(chw).unsqueeze(0).to(device)

        with torch.no_grad():
            enhanced = net(tensor)

        hwc = enhanced.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
        hwc = np.clip(hwc, 0, 1)

        # Quantize to 8-bit.
        rec['img'] = (hwc * 255.).round().astype(np.uint8)
|
|
|
|
def post_process(meta_all):
    """Boost saturation, then attenuate it around two hue bands, in place.

    Steps:
      1. Increase saturation via HLS (S channel), blending the boosted
         image back using the blue channel as a soft mask so low-blue
         pixels keep their original color.
      2. In HLS space, scale down saturation for hues near green
         (~40-105) and near the magenta/red wrap (~130-180).

    Expects and returns a uint8 RGB image in meta_all['img'].

    Fixes vs. original: uses the module's `cv` alias consistently (the file
    imports `cv2 as cv`, so the bare name `cv2` only worked via star-import
    leakage from `utils`); removes dead `alpha = 1` locals; renames the
    HLS array from the misleading `hsv`.
    """
    img = meta_all['img']

    increment = 0.5
    ori_img = img.copy()

    # --- saturation boost --------------------------------------------------
    hls = cv.cvtColor(img, cv.COLOR_RGB2HLS).astype(np.float32)
    _, L, S = cv.split(hls)
    S = S / 255.

    img = img.astype(np.float32)

    # Where boosted saturation would exceed 1, fall back to the current
    # saturation so the amplification factor below stays bounded.
    temp = increment + S
    mask_2 = temp > 1
    alpha_1 = S
    alpha_2 = 1 - increment
    alpha = alpha_1 * mask_2 + alpha_2 * (1 - mask_2)
    L = L[:, :, np.newaxis]
    alpha = alpha[:, :, np.newaxis]

    # NOTE(review): alpha can be 0 where S == 0 in the mask_2 region,
    # producing inf here; np.clip below contains the result, but confirm.
    alpha = 1 / alpha - 1

    # Push each pixel away from its lightness to raise saturation.
    img = img + (img - L) * alpha

    img = np.clip(img, 0, 255)

    # Blend boosted vs. original using the blue channel as a soft mask.
    ori_img = ori_img.astype(np.float32)
    mask = ori_img[:, :, 2] / 255.
    mask = mask[:, :, np.newaxis]
    mask = np.clip(mask - 0.1, 0, 1)
    img = img * mask + ori_img * (1 - mask)
    img = np.clip(img, 0, 255).round().astype(np.uint8)

    # --- hue-targeted saturation attenuation (HLS channel 2 == S) ----------
    hls = cv.cvtColor(img, cv.COLOR_RGB2HLS).astype(np.float32)

    # Band 1: greens, hue in [40, 105]; weight -> 1 away from band center.
    mmax = 105
    mmin = 40
    beta = 4
    gamma = 0.1

    mid = mmin + ((mmax - mmin) / 2)
    band_weight = np.abs(hls[:, :, 0] - mid) / ((mmax - mmin) / 2)
    band_weight = np.clip(band_weight, 0, 1)

    band_weight = band_weight ** beta + gamma
    band_weight = np.clip(band_weight, 0, 1)

    # Smooth the weight map to avoid hard edges in the result.
    band_weight = cv.blur(band_weight, (11, 11))
    hls[:, :, 2] = hls[:, :, 2] * band_weight

    # Band 2: magenta/red wrap, hue in [130, 180]; sharper falloff.
    mmax = 180
    mmin = 130
    beta = 8
    gamma = -0.5

    mid = mmin + ((mmax - mmin) / 2)
    band_weight = np.abs(hls[:, :, 0] - mid) / ((mmax - mmin) / 2)
    band_weight = np.clip(band_weight, 0, 1)

    band_weight = (band_weight ** beta + gamma) * 2
    band_weight = np.clip(band_weight, 0, 1)

    band_weight = cv.blur(band_weight, (11, 11))
    hls[:, :, 2] = hls[:, :, 2] * band_weight

    hls = np.clip(hls, 0, 255)
    hls = hls.round().astype(np.uint8)

    img = cv.cvtColor(hls, cv.COLOR_HLS2RGB)

    img = np.clip(img, 0, 255).round().astype(np.uint8)
    meta_all['img'] = img
    return meta_all
| |
|
|
|
|
def sky_enhancement(results):
    """Detect sky with a segmentation net and recolor/blend it, in place.

    Only runs on low-noise shots (noise_profile < 0.005). When a sky region
    is found, builds a color-shifted "sky image" (bluer, or brightness/
    saturation-adjusted for non-blue skies) and pyramid-blends it into the
    original scene using a feathered mask.
    """
    # NOTE(review): torch.load of a fully pickled model object — only safe
    # for a trusted checkpoint shipped with this project.
    model_path = "sky_seg.pt"
    model = torch.load(model_path)
    model.cuda()
    model.eval()

    for meta_all in tqdm(results):
        # Skip noisy shots: the sky push would amplify noise.
        if float(meta_all['meta']["noise_profile"][0]) >= 0.005:
            continue

        ori_img = meta_all['img'].copy().astype(np.float32)  # RGB float

        img = ori_img.copy()
        # NOTE(review): bare `cv2` is not imported by this file (it imports
        # `cv2 as cv`); presumably it leaks in via `from utils import *` —
        # verify, or switch to the `cv` alias.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        scene_image = img.copy().astype(np.float32)  # BGR, blend target

        # Shift a/b channels in LAB to push colors toward blue; this
        # modified image is also what the segmentation net sees.
        img = img / 255.
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

        lab[:,:,1] = lab[:,:,1] - (lab[:,:,2] + 127) * 0.03
        lab[:,:,2] = lab[:,:,2] - (lab[:,:,2] + 127) * 0.1

        img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        img = img * 255
        img = np.clip(img, 0, 255).round().astype(np.float32)

        # Normalize to [0, 1] and resize to the net's 512x512 input.
        img_mean = 0
        img_std = 255.0
        size = (512, 512)
        img_h , img_w = img.shape[:2]

        img = cv2.resize(img, size)
        img = (img - img_mean) / img_std
        img = np.transpose(img, [2, 0, 1])
        img = np.expand_dims(img, axis=0)
        img = torch.from_numpy(img).cuda()

        with torch.no_grad():
            mask = model(img)

        # Net output is NHWC; upsample logits back to the original size and
        # take the argmax as a hard sky mask (nonzero == sky).
        mask = mask.detach().cpu()
        mask = mask.permute((0,3,1,2))
        mask = F.interpolate(mask,
                size=[img_h , img_w],
                mode='bilinear')
        mask = mask[0].permute((1,2,0))
        sky_mask = torch.argmax(mask, axis=2).numpy().astype(np.float32)

        # No sky pixels detected -> leave this image untouched.
        if sky_mask.max() < 0.1:
            continue

        # Mean "blueness" (B minus max(R, G)) over the sky region.
        img = ori_img.copy()
        mask = img[:, :, 2] - np.max(img[:, :, :2], axis=-1)
        mask[sky_mask==0]=0
        a = np.sum(mask)
        b = np.sum(sky_mask)
        ratio_blue = a/b

        if ratio_blue < 10:
            # Not a blue sky: decide by mean R/G brightness of the region.
            img = ori_img.copy()
            mask = np.mean(img[:, :, :2], axis=-1)
            mask[sky_mask==0]=0
            a = np.sum(mask)
            b = np.sum(sky_mask)
            ratio_light = a/b

            if ratio_light<50:
                # Dark sky: darken slightly.
                img = ori_img.copy()
                img = img * 0.88
                img = np.clip(img, 0, 255)

            elif ratio_light < 200:
                # Mid-brightness sky: brighten slightly.
                img = ori_img.copy()
                img = img * 1.1
                img = np.clip(img, 0, 255)
            else:
                # Bright sky: keep brightness as is.
                pass

            # Desaturate: non-blue skies look better with muted chroma.
            hsv = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
            hsv[:, :, 1] = hsv[:, :, 1]* 0.4
            hsv = np.clip(hsv, 0, 255).astype(np.uint8)
            img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.float32)

        else:
            # Blue sky: apply the same LAB blue shift to the original image.
            img = ori_img.copy()
            img = img / 255.
            lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)

            lab[:,:,1] = lab[:,:,1] - (lab[:,:,2] + 127) * 0.03
            lab[:,:,2] = lab[:,:,2] - (lab[:,:,2] + 127) * 0.1

            img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

            img = img * 255
            img = np.clip(img, 0, 255).round().astype(np.float32)

        sky_image = img.copy().astype(np.float32)
        sky_image = cv2.cvtColor(sky_image, cv2.COLOR_RGB2BGR)

        # Erode the hard mask so only confident interior sky pixels count
        # toward the brightness gate below.
        sky_mask_ori = sky_mask.copy()
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
        sky_mask_ori = cv2.erode(sky_mask_ori, kernel)

        sky_mask_ori = sky_mask_ori > 0.9

        if np.sum(sky_mask_ori) > 0:
            h, w = sky_mask.shape

            # Feather the blend mask: dilate + blur at 1/10 scale, restore
            # confident-sky values, resize back, and floor at 0.1 so the
            # sky image always contributes a little everywhere.
            sky_mask = cv2.resize(sky_mask, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            sky_mask = cv2.dilate(sky_mask, kernel)
            sky_mask_blur = cv2.blur(sky_mask, (21, 21))
            sky_mask_blur[sky_mask>0.5] = sky_mask[sky_mask>0.5]
            sky_mask = sky_mask_blur
            sky_mask = cv2.resize(sky_mask, (w, h), interpolation=cv2.INTER_LINEAR)
            sky_mask = np.clip(sky_mask, 0.1, 1)

            # Gate: skip nearly-black skies (mean gray of the confident sky
            # region must exceed 20/255).
            sky_area_img = np.zeros_like(sky_image)
            sky_area_img[sky_mask_ori] = sky_image[sky_mask_ori]
            sky_area_img = cv2.cvtColor(sky_area_img, cv2.COLOR_BGR2GRAY)

            sky_area_img_mean = np.sum(sky_area_img) / np.sum(sky_mask_ori)
            if sky_area_img_mean > 20:
                # Laplacian-pyramid blend of the recolored sky into the
                # original scene, then back to RGB uint8.
                res = pyrblend(scene_image, sky_image, sky_mask)
                res = np.clip(res, 0, 255)

                res = res.round().astype(np.uint8)
                res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
                meta_all['img'] = res
|
|
|
|
|
|
def post_process2(meta_all):
    """Autocontrast the luma channel, blended back by a brightness mask.

    Expects a uint8 RGB image in meta_all['img']; dark pixels keep the
    original values, brighter pixels take the contrast-stretched result.
    Mutates and returns meta_all.
    """
    src = meta_all['img'].copy().astype(np.float32) / 255.

    # Stretch contrast on Y only so chroma stays untouched.
    yuv = cv.cvtColor(src, cv.COLOR_RGB2YUV)
    y, u, v = cv.split(yuv)
    y = autocontrast_using_pil(y)
    stretched = cv.cvtColor(np.stack([y, u, v], axis=-1), cv.COLOR_YUV2RGB)
    stretched = np.clip(stretched, 0, 1)

    boosted = np.clip(stretched * 255, 0, 255)

    # Per-pixel blend weight from mean brightness (shifted down by 0.1).
    original = meta_all['img'].copy().astype(np.float32)
    weight = np.mean(original, axis=-1)[:, :, np.newaxis] / 255.
    weight = np.clip(weight - 0.1, 0, 1)

    blended = boosted * weight + original * (1 - weight)
    blended = np.clip(blended, 0, 255)

    meta_all['img'] = blended.round().astype(np.uint8)

    return meta_all
|
|
|
|
def save_jpg(meta_all):
    """Write the finished image as a maximum-quality JPEG.

    The output reuses the input basename with a .jpg extension under the
    module-level global `output_path` (set in __main__). Channels are
    flipped RGB -> BGR for OpenCV's writer.
    """
    basename = meta_all['img_path'].split("/")[-1].split(".")[0]
    out_path = os.path.join(output_path, basename + ".jpg")

    cv.imwrite(out_path, meta_all['img'][:, :, ::-1],
               [cv.IMWRITE_JPEG_QUALITY, 100])
| |
|
|
|
|
if __name__ == "__main__":
    # Worker count for the CPU-bound stages; the GPU stages (raw_denoise,
    # predict_white_balance, nn_enhancement, sky_enhancement) always run
    # serially in this process.
    num_worker = 4

    all_time = time()

    input_path = "/data"
    output_path = "/data"  # also read as a module global by save_jpg

    os.makedirs(output_path, exist_ok=True)

    # ---- stage 1: load raw PNGs + JSON metadata ----------------------------
    s_time = time()
    input_list = sorted(glob(os.path.join(input_path, "*.png")))

    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            results = list(tqdm(pool.imap(load_img, input_list), total=len(input_list)))
    else:
        results = []
        for p in tqdm(input_list):
            results.append(load_img(p))
    load_time = time()-s_time
    print("load_img time is: ", load_time)

    # ---- stage 2: pack/normalize/lens-shading ------------------------------
    s_time = time()
    iso_list = [50, 125, 320, 640, 800]
    # Per-ISO lens-shading maps, read by pre_process as a module global.
    # NOTE(review): worker processes only see these globals with the fork
    # start method (Linux default); this would break under spawn.
    lsc_npy = [np.load("./lsc_npy/{}.npy".format(iso)) for iso in iso_list]

    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            results = list(tqdm(pool.imap(pre_process, results), total=len(results)))
    else:
        for r in tqdm(results):
            r = pre_process(r)
    # Free the maps before the GPU stages to keep RSS down.
    del lsc_npy
    print("pre_process time is: ", time()-s_time)

    # ---- stage 3: GPU raw-domain denoising ---------------------------------
    s_time = time()
    raw_denoise(results)
    print("raw_denoise time is: ", time()-s_time)

    # ---- stage 4: FC4 white-balance predictions (3 folds) ------------------
    s_time = time()
    predict_white_balance(results)
    print("predict_white_balance time is: ", time()-s_time)

    # ---- stage 5: raw -> sRGB conversion ------------------------------------
    s_time = time()
    # Globals read by convert_to_rgb inside the forked workers.
    color_matrix = [1.06835938, -0.29882812, -0.14257812, -0.43164062, 1.35546875, 0.05078125, -0.1015625, 0.24414062, 0.5859375]
    lsc = np.load("lsc.npy")
    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            results = list(tqdm(pool.imap(convert_to_rgb, results), total=len(results)))
    else:
        for r in tqdm(results):
            r = convert_to_rgb(r)
    del lsc
    print("convert_to_rgb time is: ", time()-s_time)

    # ---- stage 6: GPU enhancement network -----------------------------------
    s_time = time()
    nn_enhancement(results)
    print("nn_enhancement time is: ", time()-s_time)

    # ---- stage 7: saturation/hue post-processing ----------------------------
    s_time = time()
    # NOTE(review): `kernel` does not appear to be used by post_process —
    # likely a leftover; confirm before removing.
    kernel = cv.getStructuringElement(cv.MORPH_RECT, (5, 5))
    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            results = list(tqdm(pool.imap(post_process, results), total=len(results)))
    else:
        for r in tqdm(results):
            r = post_process(r)
    print("post_process time is: ", time()-s_time)

    # ---- stage 8: sky segmentation + blending (GPU) -------------------------
    s_time = time()
    sky_enhancement(results)
    print("sky_enhancement time is: ", time()-s_time)

    # ---- stage 9: global autocontrast pass ----------------------------------
    s_time = time()
    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            results = list(tqdm(pool.imap(post_process2, results), total=len(results)))
    else:
        for r in tqdm(results):
            r = post_process2(r)
    print("post_process2 time is: ", time()-s_time)

    # ---- stage 10: write JPEGs ----------------------------------------------
    s_time = time()
    if num_worker > 1:
        with mp.Pool(num_worker) as pool:
            _ = list(tqdm(pool.imap(save_jpg, results), total=len(results)))
    else:
        for r in tqdm(results):
            save_jpg(r)
    save_time = time()-s_time
    print("save_jpg time is: ", save_time)

    # Report per-image time excluding disk I/O stages.
    total_time = time()-all_time
    total_time_without_load_save = total_time - load_time - save_time
    print("per image inference time (without load and save) is: ", total_time_without_load_save / len(results), "s")