| import os |
| import time |
| import argparse |
| import torch |
| import torch.backends.cudnn as cudnn |
| from utils_ours.util import setup_logger, print_args |
| from torch.utils.data import DataLoader |
| from dataloader.dataset import imageSet |
| from models.archs.NAF_arch import NAF_Video |
| from torch.nn.parallel import DistributedDataParallel |
| import numpy as np |
| import torch.nn.functional as F |
| from collections import OrderedDict |
| import torch.nn as nn |
| from models.utils import chunkV3 |
| import pdb |
| from ISP_pipeline import process_pngs_isp |
| import os |
| import json |
| import cv2 |
| from skimage import io |
|
|
# Per-ISO calibration samples of the sensor noise model.
# ``main`` inverts the linear fit of ``a`` against ISO to estimate a scene's ISO
# from the first entry of its ``noise_profile`` metadata. ``b`` is presumably the
# second noise-profile term (TODO confirm); ``coeff_b`` is not read by ``main``.
ISO = [50, 125, 320, 640, 800]
a = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]
b = [2.32350645e-06, 3.1125155625e-06, 8.328992952e-06, 3.3315971808e-05, 5.205620595e-05]

# Least-squares polynomial fits against ISO, coefficients highest power first
# (np.polyfit convention): ``a`` as a line, ``b`` as a quadratic.
coeff_a, coeff_b = np.polyfit(ISO, a, deg=1), np.polyfit(ISO, b, deg=2)
|
|
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with any leading ``'module.'`` removed from its keys.

    The prefix is presumably left by saving a (Distributed)DataParallel-wrapped
    model (TODO confirm against the training script); the bare ``NAF_Video``
    network built in ``main`` is loaded with ``strict=True`` and so needs the
    unprefixed names. Keys without the prefix are kept as-is, order is preserved.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _select_iso_model(iso):
    """Map an estimated ISO to the name of the denoising checkpoint to use.

    Bands (upper bound exclusive): < 900 -> 'low', < 1800 -> 'mid',
    < 5600 -> 'high_mid', anything else (including NaN) -> 'high'.
    """
    if iso < 900:
        return 'low'
    if iso < 1800:
        return 'mid'
    if iso < 5600:
        return 'high_mid'
    return 'high'


def main():
    """Denoise every scene yielded by ``imageSet`` and render it to a JPEG file.

    For each sample, the ISO is estimated by inverting the linear fit ``coeff_a``
    on the first ``noise_profile`` entry of the scene's JSON metadata. One of
    four ``NAF_Video`` checkpoints is then selected by ISO band. The network runs
    tiled through ``chunkV3`` (1024x1024 patches). Its output is clamped to
    [0, 1], rendered by ``process_pngs_isp.isp_night_imaging``, converted
    RGB->BGR and written to ``<save_folder>/<scene>.jpg``. Per-scene times and
    the average time are printed; each covers the network plus the ISP stage.

    Raises:
        OSError: if OpenCV fails to write an output image.
    """
    parser = argparse.ArgumentParser(description='imageTest')
    parser.add_argument('--frame', default=1, type=int)
    parser.add_argument('--test_dir', default="/data/", type=str)
    parser.add_argument('--model_type', type=str, default='NAF_Video')
    parser.add_argument('--save_folder', default='/data/', type=str)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--testoption', default='image', type=str)
    parser.add_argument('--chunk', action='store_true')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    # Honour --save_folder. Its default is '/data/', the folder that used to be
    # hard-coded here, so default runs write to the same place.
    args.src_save_folder = args.save_folder
    print(args.src_save_folder)
    if args.src_save_folder:  # '' means the current directory; nothing to create
        os.makedirs(args.src_save_folder, exist_ok=True)

    checkpoint_paths = {
        'low': "denoise_model/low_iso.pth",
        'mid': "denoise_model/mid_iso.pth",
        'high_mid': "denoise_model/high_mid_iso.pth",
        'high': "denoise_model/high_iso.pth",
    }

    network = NAF_Video(args).cuda()
    # load_state_dict does not change train/eval mode, so one call suffices.
    network.eval()

    # Load every checkpoint once, up front, to the CPU. load_state_dict copies
    # the tensors into the CUDA parameters. All four are prefix-stripped so that
    # strict loading works regardless of how each checkpoint was saved.
    state_dicts = {
        name: _strip_module_prefix(torch.load(path, map_location=torch.device('cpu')))
        for name, path in checkpoint_paths.items()
    }

    cudnn.benchmark = True

    test_dataloader = DataLoader(imageSet(args), batch_size=1, num_workers=0, shuffle=False)
    inference_time = []
    active_model = None  # name of the checkpoint currently loaded into `network`
    with torch.no_grad():
        for data in test_dataloader:
            noise = data['input'].cuda()
            json_path = data['json_path'][0]
            scene_name = os.path.splitext(os.path.basename(json_path))[0]

            json_cfa = process_pngs_isp.readjson(json_path)
            num_k = json_cfa['noise_profile']
            # Invert the linear calibration  noise_k0 = coeff_a[0] * iso + coeff_a[1].
            iso = (num_k[0] - coeff_a[1]) / coeff_a[0]

            # Swap weights only when the ISO band differs from the previous scene.
            model_name = _select_iso_model(iso)
            if model_name != active_model:
                network.load_state_dict(state_dicts[model_name], strict=True)
                active_model = model_name

            t0 = time.perf_counter()
            out = chunkV3(network, noise, args.testoption, patch_h=1024, patch_w=1024)
            out = torch.clamp(out, 0., 1.)[0]
            del noise
            torch.cuda.empty_cache()

            img_pro = process_pngs_isp.isp_night_imaging(
                out, json_cfa, iso,
                do_demosaic=True,
                do_channel_gain_white_balance=True,
                do_xyz_transform=True,
                do_srgb_transform=True,
                do_gamma_correct=True,
                do_refinement=True,
                do_to_uint8=True,
                do_resize_using_pil=True,
                do_fix_orientation=True,
            )
            t1 = time.perf_counter()
            inference_time.append(t1 - t0)

            img_pro = cv2.cvtColor(img_pro, cv2.COLOR_RGB2BGR)
            name_rgb = os.path.join(args.src_save_folder, scene_name + '.jpg')
            # The output is JPEG, so a PNG compression flag would be ignored.
            # Quality 95 is OpenCV's JPEG default, i.e. the quality the script
            # already produced.
            if not cv2.imwrite(name_rgb, img_pro, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                raise OSError(f"cv2.imwrite failed to write {name_rgb}")

            print("Inference {} in {:.3f}s".format(scene_name, t1 - t0))

    if inference_time:  # np.mean([]) would warn and print nan
        print(f"Average inference time: {np.mean(inference_time)} seconds")
|
if __name__ == "__main__":
    # Run the inference pipeline only when executed as a script, never on import.
    main()