| import argparse
|
| from fileinput import filename
|
| from locale import locale_encoding_alias
|
| import torch
|
| import torch.nn as nn
|
| from Deraining.network.Math_Module import P, Q
|
| from Deraining.network.decom import Decom
|
| import os
|
| import torchvision
|
| import torchvision.transforms as transforms
|
| from PIL import Image
|
| import time
|
| from utils import *
|
| import glob
|
|
|
| """
|
| Different illumination adjustment ratios will produce
|
| different enhanced results. You can certainly tune the ratio yourself
|
| to get the best results.
|
| To get better results, we use the illumination of the normal-light image
|
| to adaptively generate the ratio.
|
| Note that KinD and KinD++ also use a ratio to guide the illumination adjustment;
|
| for a fair comparison, the ratio for their methods is also generated from the illumination
|
| of the normal-light image.
|
| """
|
|
|
def one2three(x):
    """Return three copies of ``x`` concatenated along dimension 1.

    The concatenated tensor is passed through ``.to(x)``, so the result
    carries the dtype and device of ``x``.
    """
    copies = (x, x, x)
    tripled = torch.cat(copies, dim=1)
    return tripled.to(x)
|
|
|
class Inference(nn.Module):
    """Run the enhancement pipeline on paired low/normal-light images.

    ``forward`` computes ``adjust(L, ratio) * R`` where ``R, L`` come from
    ``unfolding`` on the low-light input and ``ratio`` is derived from the
    second output of ``model_Decom_high`` on the normal-light input.
    ``evaluate`` applies ``forward`` to every ``*.png`` in ``opts.low_dir``.
    """

    def __init__(self, opts):
        """Load all sub-models referenced by ``opts`` and build the transform.

        Args:
            opts: Options object. This class reads ``Decom_model_low_path``,
                ``Decom_model_high_path``, ``unfolding_model_path``,
                ``adjust_model_path``, ``low_dir``, ``high_dir`` and
                ``output`` from it.
        """
        super().__init__()
        self.opts = opts

        self.model_Decom_low = Decom()
        self.model_Decom_high = Decom()
        self.model_Decom_low = load_initialize(self.model_Decom_low, self.opts.Decom_model_low_path)
        self.model_Decom_high = load_initialize(self.model_Decom_high, self.opts.Decom_model_high_path)

        # load_unfolding yields the unfolding options plus the two models
        # used inside the unfolding loop.
        self.unfolding_opts, self.model_R, self.model_L = load_unfolding(self.opts.unfolding_model_path)

        self.adjust_model = load_adjustment(self.opts.adjust_model_path)
        self.P = P()
        self.Q = Q()
        self.transform = transforms.Compose([transforms.ToTensor()])
        print(self.model_Decom_low)
        print(self.model_R)
        print(self.model_L)
        print(self.adjust_model)

    def get_ratio(self, high_l, low_l):
        """Build a constant ratio map from two illumination tensors.

        Args:
            high_l: Tensor whose shape, dtype and device the result takes.
            low_l: Tensor divided element-wise by ``high_l + 0.0001``.

        Returns:
            A tensor shaped like ``high_l`` where every element equals
            ``1 / (mean(low_l / (high_l + 0.0001)) + 0.0001)``.
        """
        ratio = (low_l / (high_l + 0.0001)).mean()
        # ones_like keeps the result on high_l's device instead of forcing
        # CUDA, so the method also works on CPU-only machines.
        return torch.ones_like(high_l) * (1 / (ratio + 0.0001))

    def unfolding(self, input_low_img):
        """Run ``unfolding_opts.round`` rounds of the unfolding loop.

        Round 0 initialises the two auxiliary variables with
        ``model_Decom_low``; later rounds update them with ``self.P`` and
        ``self.Q`` using weights that grow linearly with the round index.
        ``model_R`` and ``model_L`` are applied at the end of every round.

        Args:
            input_low_img: Input passed to ``model_Decom_low``, ``self.P``
                and ``self.Q``.

        Returns:
            Tuple ``(R, L)`` produced by the last round.

        Raises:
            ValueError: If ``unfolding_opts.round`` is smaller than 1, since
                no ``R``/``L`` would ever be computed.
        """
        rounds = self.unfolding_opts.round
        if rounds < 1:
            raise ValueError("unfolding_opts.round must be >= 1, got %r" % (rounds,))
        for t in range(rounds):
            if t == 0:
                # Local names avoid shadowing the imported P/Q classes.
                p_aux, q_aux = self.model_Decom_low(input_low_img)
            else:
                w_p = self.unfolding_opts.gamma + self.unfolding_opts.Roffset * t
                w_q = self.unfolding_opts.lamda + self.unfolding_opts.Loffset * t
                p_aux = self.P(I=input_low_img, Q=q_aux, R=R, gamma=w_p)
                q_aux = self.Q(I=input_low_img, P=p_aux, L=L, lamda=w_q)
            R = self.model_R(r=p_aux, l=q_aux)
            L = self.model_L(l=q_aux)
        return R, L

    def lllumination_adjust(self, L, ratio):
        """Apply ``adjust_model`` to ``L`` with ``ratio`` expanded to ``L``'s shape.

        Args:
            L: Tensor passed to ``adjust_model`` as ``l``.
            ratio: Scalar or tensor broadcastable against ``L``.

        Returns:
            Whatever ``adjust_model(l=L, alpha=...)`` returns.
        """
        # ones_like keeps alpha on L's device rather than forcing CUDA.
        alpha = torch.ones_like(L) * ratio
        return self.adjust_model(l=L, alpha=alpha)

    def forward(self, input_low_img, input_high_img):
        """Enhance ``input_low_img`` using ``input_high_img`` to derive the ratio.

        Both inputs are moved to CUDA when it is available. All computation
        runs under ``torch.no_grad()``.

        Returns:
            Tuple ``(I_enhance, p_time)`` where ``p_time`` is the elapsed
            ``time.time()`` difference, in seconds, around the computation.
        """
        if torch.cuda.is_available():
            input_low_img = input_low_img.cuda()
            input_high_img = input_high_img.cuda()
        with torch.no_grad():
            start = time.time()
            R, L = self.unfolding(input_low_img)

            _, high_L = self.model_Decom_high(input_high_img)
            ratio = self.get_ratio(high_L, L)
            High_L = self.lllumination_adjust(L, ratio)
            I_enhance = High_L * R
            p_time = time.time() - start
        return I_enhance, p_time

    def evaluate(self):
        """Enhance every ``*.png`` in ``opts.low_dir`` and save the results.

        For each low-light file, the file with the same name in
        ``opts.high_dir`` is used as the reference. The result is written to
        ``opts.output`` as ``<stem>_URetinexNet<ext>`` and the processing
        time is printed.

        Raises:
            FileNotFoundError: If the matching file in ``opts.high_dir``
                does not exist (raised by ``Image.open``).
        """
        low_files = sorted(glob.glob(os.path.join(self.opts.low_dir, "*.png")))
        # Create the output directory once instead of checking per image.
        os.makedirs(self.opts.output, exist_ok=True)
        for file in low_files:
            file_name = os.path.basename(file)
            # splitext avoids str.replace corrupting names in which the stem
            # also occurs elsewhere (e.g. "p.png").
            stem, ext = os.path.splitext(file_name)
            high_file = os.path.join(self.opts.high_dir, file_name)
            with Image.open(file) as low_pil, Image.open(high_file) as high_pil:
                low_img = self.transform(low_pil).unsqueeze(0)
                high_img = self.transform(high_pil).unsqueeze(0)
            enhance, p_time = self.forward(low_img, high_img)
            save_path = os.path.join(self.opts.output, "%s_URetinexNet%s" % (stem, ext))
            np_save_TensorImg(enhance, save_path)
            print("================================= time for %s: %f============================"%(file_name, p_time))
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, build the model, run evaluation.
    parser = argparse.ArgumentParser(description='Configure')

    # Input/output locations. The low-light default previously pointed at
    # "./test_daat/..." (typo), which never matched the "./test_data/..."
    # tree used by --high_dir.
    parser.add_argument('--low_dir', type=str, default="./test_data/LOLdataset/eval15/low")
    parser.add_argument('--high_dir', type=str, default="./test_data/LOLdataset/eval15/high")
    parser.add_argument('--output', type=str, default="./demo/output/LOL")

    # Checkpoint paths.
    parser.add_argument('--Decom_model_low_path', type=str, default="./ckpt/init_low.pth")
    parser.add_argument('--Decom_model_high_path', type=str, default="./ckpt/init_high.pth")
    parser.add_argument('--unfolding_model_path', type=str, default="./ckpt/unfolding.pth")
    parser.add_argument('--adjust_model_path', type=str, default="./ckpt/L_adjust.pth")
    parser.add_argument('--gpu_id', type=int, default=0)

    opts = parser.parse_args()
    for k, v in vars(opts).items():
        print(k, v)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    model = Inference(opts)
    # Only move to the GPU when one is present, matching the availability
    # check Inference.forward applies to its inputs.
    if torch.cuda.is_available():
        model = model.cuda()
    model.evaluate()
|
|
|