# NPRC24 / DH-AISP / 2 / test.py
# (header residue from the hosting page — author "Neverlios", repo "dh-aisp",
#  commit bd1c686 verified — commented out so the file is valid Python)
import argparse
import glob
import json
import math
import os
import random
import re
import sys
import time
from fractions import Fraction

import cv2
import matplotlib.image
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image as imwrite

from model_convnext2_hdr import fusion_net
from test_dataset_for_testing import dehaze_test_dataset
# Make only GPU 0 visible to this process; must be set before CUDA initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#run python test_05_hdr.py ./data/ ./result/ ./daylight_isp_03/ 1 2 4
input_dir2 = '../data/'                      # per-scene .json metadata lives here
input_dir = '../1/mid/'                      # intermediate '*_1.png' frames
result_dir = '../data/'                      # final .jpg outputs are written here
checkpoint_dir = './result_low_light_hdr/'   # trained generator checkpoint

# get train IDs: one ID per scene, keyed off the first frame '*_1.png'
train_fns = glob.glob(input_dir + '*_1.png')
train_ids = [os.path.basename(train_fn) for train_fn in train_fns]

# makedirs(exist_ok=True) is race-free and creates intermediate directories,
# unlike the original `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(result_dir, exist_ok=True)
def json_read(fname, **kwargs):
    """Load and return the JSON content of *fname*.

    Any extra keyword arguments (e.g. ``object_hook``) are forwarded
    verbatim to ``json.load``.
    """
    with open(fname) as handle:
        return json.load(handle, **kwargs)
def fraction_from_json(json_object):
    """``object_hook`` for ``json.load``: revive serialized fractions.

    A dict of the form ``{"Fraction": [numerator, denominator]}`` is turned
    into a ``fractions.Fraction``; every other object passes through
    unchanged.

    NOTE(review): the original file never imported ``Fraction``, so this
    raised ``NameError`` the first time a fraction appeared in the metadata;
    the module-level ``from fractions import Fraction`` fixes that.
    """
    if 'Fraction' in json_object:
        return Fraction(*json_object['Fraction'])
    return json_object
def fractions2floats(fractions):
    """Convert a sequence of Fraction-like objects to plain floats.

    Each element only needs ``.numerator`` and ``.denominator`` attributes.
    """
    return [float(frac.numerator) / frac.denominator for frac in fractions]
def reprocessing(input):
    """Color-correct an (H, W, 3) image with a fixed 3x3 channel-mixing matrix.

    Used by the caller for low-noise scenes (noise profile < 0.02). Each
    output channel is a fixed linear combination of the three input channels.

    Args:
        input: float array of shape (H, W, 3); the caller passes values
            clipped to [0, 1].

    Returns:
        float64 array of the same shape, deliberately NOT clipped — the
        caller multiplies by 255 and the image writer handles clipping.
        (The original also computed a clipped uint8 copy that was never
        used; that dead code has been removed.)
    """
    output = np.zeros(input.shape)
    output[:, :, 0] = input[:, :, 0] * 1.9021 - input[:, :, 1] * 1.1651 + input[:, :, 2] * 0.2630
    output[:, :, 1] = input[:, :, 0] * (-0.3189) + input[:, :, 1] * 1.5831 - input[:, :, 2] * 0.2643
    output[:, :, 2] = input[:, :, 0] * (-0.0662) - input[:, :, 1] * 0.9350 + input[:, :, 2] * 2.0013
    return output
def reprocessing1(input):
    """Color-correct an (H, W, 3) image with an alternative fixed 3x3 matrix.

    Used by the caller for noisier scenes (noise profile >= 0.02); same
    structure as ``reprocessing`` but with milder mixing coefficients.

    Args:
        input: float array of shape (H, W, 3); the caller passes values
            clipped to [0, 1].

    Returns:
        float64 array of the same shape, deliberately NOT clipped — the
        caller multiplies by 255 and the image writer handles clipping.
        (The original also computed a clipped uint8 copy that was never
        used; that dead code has been removed.)
    """
    output = np.zeros(input.shape)
    output[:, :, 0] = input[:, :, 0] * 1.521689 - input[:, :, 1] * 0.673763 + input[:, :, 2] * 0.152074
    output[:, :, 1] = input[:, :, 0] * (-0.145724) + input[:, :, 1] * 1.266507 - input[:, :, 2] * 0.120783
    output[:, :, 2] = input[:, :, 0] * (-0.0397583) - input[:, :, 1] * 0.561249 + input[:, :, 2] * 1.60100734
    return output
# --- Gpu device --- #
device = torch.device("cuda:0")
# --- Define the network --- #
model_g = fusion_net()
# Wrap in DataParallel before loading: presumably the checkpoint was saved
# from a DataParallel model, so its keys carry the 'module.' prefix — confirm.
model_g = nn.DataParallel(model_g)
MyEnsembleNet = model_g.to(device)
MyEnsembleNet.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'checkpoint_gen.pth')))
# --- Test --- #
with torch.no_grad():
    MyEnsembleNet.eval()
    for ind in range(len(train_ids)):
        print(ind)
        train_id = train_ids[ind]
        # train_id looks like '<scene>_1.png'; [:-5] drops '1.png' -> '<scene>_'.
        in_path_in = input_dir + train_id[:-5]
        in_path_in_js = input_dir2 + train_id[:-5]
        # [:-1] also strips the trailing '_' -> metadata file '<scene>.json'.
        metadata = json_read(in_path_in_js[:-1] + '.json', object_hook=fraction_from_json)
        noise_profile = float(metadata['noise_profile'][0])
        # Load the scene's three frames, normalized to [0, 1].
        pic_in1 = np.asarray(Image.open(in_path_in + '1.png'), np.float32) / 255.
        pic_in2 = np.asarray(Image.open(in_path_in + '2.png'), np.float32) / 255.
        pic_in3 = np.asarray(Image.open(in_path_in + '3.png'), np.float32) / 255.
        # Stack the frames along the channel axis as the network input.
        pic_in = np.concatenate([pic_in1, pic_in2, pic_in3],axis=2)
        #pic_in = cv2.resize(pic_in, None, fx = 0.5, fy = 0.5, interpolation=cv2.INTER_CUBIC )
        [h,w,c] = pic_in.shape
        # Reflect-pad so H and W become multiples of 32 (presumably the
        # network's total downsampling factor — confirm against the model).
        # NOTE(review): when h % 32 == 0 this still pads a full extra 32 rows;
        # harmless, since the output is cropped back to (h, w) on write.
        pad_h = 32 - h % 32
        pad_w = 32 - w % 32
        pic_in = np.expand_dims(np.pad(pic_in, ((0, pad_h), (0, pad_w),(0,0)), mode='reflect'),axis = 0)
        # NHWC -> NCHW for the torch model.
        in_data = torch.from_numpy(pic_in).permute(0,3,1,2).to(device)
        out_data = MyEnsembleNet(in_data)
        # Back to NHWC on CPU for post-processing.
        out_datass = out_data.cpu().detach().numpy().transpose((0, 2, 3, 1))
        output = np.clip(out_datass[0,:,:,:], 0, 1)
        # Choose the color-correction matrix by the scene's noise level.
        if noise_profile < 0.02:
            output = reprocessing(output)
        else:
            output = reprocessing1(output)
        #cv2.imwrite(result_dir + train_id[:-6] + '.png', output[0:h,0:w,::-1] * 255)
        # Crop the padding, flip RGB->BGR for cv2, save as max-quality JPEG
        # named '<scene>.jpg' ([:-6] strips '_1.png' from the ID).
        cv2.imwrite(result_dir + train_id[:-6] + '.jpg', output[0:h,0:w,::-1] * 255, [cv2.IMWRITE_JPEG_QUALITY, 100])