| import torch |
| import argparse |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torchvision.utils import save_image as imwrite |
| import os |
| import time |
| import re |
| from torchvision import transforms |
|
|
| from test_dataset_for_testing import dehaze_test_dataset |
| from model_convnext2_hdr import fusion_net |
| import glob |
| import scipy.io |
| import torch.optim as optim |
| import cv2 |
| import matplotlib.image |
| from PIL import Image |
| import random |
| import math |
| import numpy as np |
| import sys |
| import json |
|
|
# Expose only GPU 0 to this process (takes effect because it is set before
# any CUDA initialisation happens in this script).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Directory holding the per-scene '<name>.json' metadata files.
input_dir2 = '../data/'
# Directory holding the per-scene input frames '<name>_1.png', '<name>_2.png', ...
input_dir = '../1/mid/'

# Output JPEGs are written here (note: the same directory as input_dir2).
result_dir = '../data/'
# Directory containing the 'checkpoint_gen.pth' weights.
checkpoint_dir = './result_low_light_hdr/'

# One entry per scene, identified by its first frame. Sorted so the
# processing order is deterministic (glob order is filesystem-dependent).
train_fns = sorted(glob.glob(input_dir + '*_1.png'))
train_ids = [os.path.basename(train_fn) for train_fn in train_fns]

# makedirs also creates missing parent directories and is a no-op when the
# directory already exists (no check-then-create race).
os.makedirs(result_dir, exist_ok=True)
|
|
def json_read(fname, **kwargs):
    """Load and return the JSON document stored in ``fname``.

    Extra keyword arguments (e.g. ``object_hook``) are forwarded unchanged to
    ``json.load``.

    The file is decoded as UTF-8, the encoding required for JSON interchange,
    rather than the platform's locale-dependent default text encoding.
    """
    with open(fname, encoding='utf-8') as j:
        return json.load(j, **kwargs)
|
|
def fraction_from_json(json_object):
    """``object_hook`` for ``json.load`` that revives encoded fractions.

    A decoded JSON object containing a ``'Fraction'`` key becomes
    ``Fraction(*json_object['Fraction'])``; for example, ``{"Fraction": [1, 3]}``
    becomes ``Fraction(1, 3)``. Any other object is returned unchanged.
    """
    # Fraction is not among the module-level imports, so it is imported here;
    # without it the first encoded fraction raised NameError.
    from fractions import Fraction

    if 'Fraction' in json_object:
        return Fraction(*json_object['Fraction'])
    return json_object
|
|
def fractions2floats(fractions):
    """Return a list with each fraction-like element converted to a float.

    Every element must expose ``numerator`` and ``denominator``. The result
    is ``float(numerator) / denominator`` for each element, in input order.
    """
    return [float(frac.numerator) / frac.denominator for frac in fractions]
|
|
def reprocessing(input):
    """Apply a fixed 3x3 colour matrix to the first three channels of an image.

    The script selects this matrix when the scene's ``noise_profile`` is
    below 0.02.

    Args:
        input: array indexed as ``input[row, col, channel]`` with at least three
            channels. In this script it is the network output clipped to [0, 1].

    Returns:
        A new float64 array of the same shape as ``input``. Output channel
        ``k`` (for k < 3) is ``sum_j ccm[k, j] * input[:, :, j]``. Any channels
        beyond the third are zero. Values are NOT clipped and can fall outside
        the input range.
    """
    # Row k holds the weights that produce output channel k from input
    # channels 0, 1 and 2.
    ccm = np.array([
        [1.9021, -1.1651, 0.2630],
        [-0.3189, 1.5831, -0.2643],
        [-0.0662, -0.9350, 2.0013],
    ])
    output = np.zeros(input.shape)
    # (rows, cols, 3) @ (3, 3) applies the matrix to every pixel at once.
    output[:, :, :3] = input[:, :, :3] @ ccm.T
    return output
|
|
def reprocessing1(input):
    """Apply a fixed 3x3 colour matrix to the first three channels of an image.

    The script selects this matrix when the scene's ``noise_profile`` is
    0.02 or higher.

    Args:
        input: array indexed as ``input[row, col, channel]`` with at least three
            channels. In this script it is the network output clipped to [0, 1].

    Returns:
        A new float64 array of the same shape as ``input``. Output channel
        ``k`` (for k < 3) is ``sum_j ccm[k, j] * input[:, :, j]``. Any channels
        beyond the third are zero. Values are NOT clipped and can fall outside
        the input range.
    """
    # Row k holds the weights that produce output channel k from input
    # channels 0, 1 and 2.
    ccm = np.array([
        [1.521689, -0.673763, 0.152074],
        [-0.145724, 1.266507, -0.120783],
        [-0.0397583, -0.561249, 1.60100734],
    ])
    output = np.zeros(input.shape)
    # (rows, cols, 3) @ (3, 3) applies the matrix to every pixel at once.
    output[:, :, :3] = input[:, :, :3] @ ccm.T
    return output
|
|
| |
# All inference runs on the first visible GPU.
device = torch.device("cuda:0")

model_g = fusion_net()

# DataParallel prefixes state-dict keys with 'module.'; presumably the
# checkpoint was saved from a DataParallel-wrapped model as well (TODO confirm).
model_g = nn.DataParallel(model_g)

MyEnsembleNet = model_g.to(device)

# map_location remaps every stored tensor onto `device`, so the checkpoint
# still loads if it was saved from a GPU index that is not visible here.
MyEnsembleNet.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, 'checkpoint_gen.pth'), map_location=device)
)
|
|
| |
# Run the network over every scene; no gradients are needed for inference.
with torch.no_grad():
    MyEnsembleNet.eval()

    for ind, train_id in enumerate(train_ids):
        print(ind)
        # train_id is '<name>_1.png' (matched by the '*_1.png' glob); dropping
        # the trailing '1.png' leaves the '<name>_' prefix shared by the frames.
        in_path_in = input_dir + train_id[:-5]
        in_path_in_js = input_dir2 + train_id[:-5]
        # Strip the trailing '_' to address '<name>.json'.
        metadata = json_read(in_path_in_js[:-1] + '.json', object_hook=fraction_from_json)

        # Selects which colour matrix is applied to the network output below.
        noise_profile = float(metadata['noise_profile'][0])

        # Load the three frames scaled to [0, 1] and stack them along the
        # channel axis.
        frames = [
            np.asarray(Image.open(in_path_in + suffix), np.float32) / 255.
            for suffix in ('1.png', '2.png', '3.png')
        ]
        pic_in = np.concatenate(frames, axis=2)

        h, w, _ = pic_in.shape

        # Reflect-pad H and W up to the next multiple of 32 (presumably
        # required by fusion_net's down-sampling; confirm). (-h) % 32 is 0 for
        # sizes that are already aligned, whereas `32 - h % 32` would add a
        # needless extra 32 rows/cols in that case.
        pad_h = (-h) % 32
        pad_w = (-w) % 32
        pic_in = np.pad(pic_in, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')

        # (H, W, C) -> (1, C, H, W) for the network.
        in_data = torch.from_numpy(pic_in[np.newaxis]).permute(0, 3, 1, 2).to(device)
        out_data = MyEnsembleNet(in_data)
        # Back to (H, W, C) for the single batch element, clipped to [0, 1].
        output = np.clip(out_data.cpu().detach().numpy()[0].transpose(1, 2, 0), 0, 1)

        if noise_profile < 0.02:
            output = reprocessing(output)
        else:
            output = reprocessing1(output)

        # Crop away the padding and reverse the channel order (cv2 writes BGR).
        # Convert to 8-bit explicitly: the colour matrices can push values
        # outside [0, 1], so clip after rounding rather than relying on
        # imwrite's implicit float -> uint8 conversion.
        bgr = output[0:h, 0:w, ::-1] * 255
        bgr8 = np.clip(np.round(bgr), 0, 255).astype(np.uint8)
        cv2.imwrite(result_dir + train_id[:-6] + '.jpg', bgr8, [cv2.IMWRITE_JPEG_QUALITY, 100])
| |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|