File size: 6,238 Bytes
4336727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
from fileinput import filename
from locale import locale_encoding_alias
import torch
import torch.nn as nn
from Deraining.network.Math_Module import P, Q
from Deraining.network.decom import Decom
import os
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import time
from utils import *
import glob

"""
Different illumination adjustment ratios produce different enhanced results,
so you can tune the ratio yourself to get the best result.

To get better results, we use the illumination of the normal-light image to
generate the ratio adaptively.

Note that KinD and KinD++ also use a ratio to guide the illumination
adjustment; for a fair comparison, the ratio for those methods is also
generated from the illumination of the normal-light image.
"""

def one2three(x):
    """Concatenate three copies of ``x`` along dim 1.

    The result is passed through ``.to(x)`` so it carries the same dtype and
    device as the input.
    """
    copies = [x] * 3
    stacked = torch.cat(copies, dim=1)
    return stacked.to(x)

class Inference(nn.Module):
    """Run low-light enhancement on image pairs.

    Steps performed by :meth:`forward`:
      1. decompose the low-light input into ``R`` and ``L`` with the
         iterative :meth:`unfolding` loop;
      2. decompose the reference (normal-light) image with
         ``model_Decom_high`` and derive an adjustment ratio from its
         illumination (:meth:`get_ratio`);
      3. adjust ``L`` with ``adjust_model`` and return ``adjusted_L * R``.

    ``opts`` must provide ``Decom_model_low_path``, ``Decom_model_high_path``,
    ``unfolding_model_path`` and ``adjust_model_path`` (read in ``__init__``),
    plus ``low_dir``, ``high_dir`` and ``output`` (read in :meth:`evaluate`).
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # loading decomposition model
        # NOTE(review): load_initialize / load_unfolding / load_adjustment are
        # not defined in this file; presumably they come from `utils`.
        self.model_Decom_low = Decom()
        self.model_Decom_high = Decom()
        self.model_Decom_low = load_initialize(self.model_Decom_low, self.opts.Decom_model_low_path)
        self.model_Decom_high = load_initialize(self.model_Decom_high, self.opts.Decom_model_high_path)
        # loading R; old_model_opts; and L model
        self.unfolding_opts, self.model_R, self.model_L = load_unfolding(self.opts.unfolding_model_path)
        # loading adjustment model
        self.adjust_model = load_adjustment(self.opts.adjust_model_path)
        self.P = P()
        self.Q = Q()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        print(self.model_Decom_low)
        print(self.model_R)
        print(self.model_L)
        print(self.adjust_model)

    @staticmethod
    def _output_name(file_name):
        """Return ``file_name`` with ``_URetinexNet`` inserted after its stem.

        The stem is everything before the first ``.``. Only that leading stem
        is suffixed; a plain ``str.replace`` would also rewrite later
        occurrences of the stem (e.g. ``"p.png"`` inside its own extension).
        """
        stem = file_name.split('.')[0]
        return "%s_URetinexNet" % (stem) + file_name[len(stem):]

    def get_ratio(self, high_l, low_l):
        """Return a tensor shaped like ``high_l`` holding the adjustment ratio.

        Every element equals ``1 / (mean(low_l / (high_l + 1e-4)) + 1e-4)``.
        The result lives on the same device as ``high_l``, so this also works
        without CUDA.
        """
        ratio = (low_l / (high_l + 0.0001)).mean()
        # ones_like keeps the tensor on high_l's device instead of forcing CUDA.
        return torch.ones_like(high_l) * (1 / (ratio + 0.0001))

    def unfolding(self, input_low_img):
        """Run ``unfolding_opts.round`` unfolding rounds and return ``(R, L)``.

        Round 0 initialises ``P`` and ``Q`` from ``model_Decom_low``; later
        rounds update them with ``self.P`` / ``self.Q`` using weights that grow
        linearly with the round index.

        Raises:
            ValueError: if ``unfolding_opts.round`` is smaller than 1, since
                no ``R`` / ``L`` would be produced.
        """
        opts = self.unfolding_opts
        if opts.round < 1:
            raise ValueError("unfolding_opts.round must be >= 1, got %r" % (opts.round,))
        for t in range(opts.round):
            if t == 0:  # initialize R0, L0
                p_t, q_t = self.model_Decom_low(input_low_img)
            else:  # update P and Q
                w_p = opts.gamma + opts.Roffset * t
                w_q = opts.lamda + opts.Loffset * t
                p_t = self.P(I=input_low_img, Q=q_t, R=R, gamma=w_p)
                q_t = self.Q(I=input_low_img, P=p_t, L=L, lamda=w_q)
            R = self.model_R(r=p_t, l=q_t)
            L = self.model_L(l=q_t)
        return R, L

    def lllumination_adjust(self, L, ratio):
        """Apply ``adjust_model`` to ``L`` with ``ratio`` broadcast to ``L``'s shape.

        The (misspelled) method name is kept for existing callers.
        """
        # ones_like keeps the ratio map on L's device instead of forcing CUDA.
        ratio = torch.ones_like(L) * ratio
        return self.adjust_model(l=L, alpha=ratio)

    def forward(self, input_low_img, input_high_img):
        """Enhance ``input_low_img`` using ``input_high_img`` as the reference.

        Returns:
            ``(I_enhance, p_time)``: the enhanced image tensor and the elapsed
            processing time in seconds (host-to-device transfer excluded).
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            input_low_img = input_low_img.cuda()
            input_high_img = input_high_img.cuda()
        with torch.no_grad():
            # CUDA kernels launch asynchronously; synchronise so the timer
            # measures the computation rather than just the launch.
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            R, L = self.unfolding(input_low_img)
            # the ratio is calculated using the decomposed normal illumination
            _, high_L = self.model_Decom_high(input_high_img)
            ratio = self.get_ratio(high_L, L)
            High_L = self.lllumination_adjust(L, ratio)
            I_enhance = High_L * R
            if use_cuda:
                torch.cuda.synchronize()
            p_time = time.perf_counter() - start
        return I_enhance, p_time

    def evaluate(self):
        """Enhance every ``*.png`` in ``opts.low_dir`` and save it to ``opts.output``.

        For each low-light file, the reference image with the same file name
        is read from ``opts.high_dir``. Results are written as
        ``<stem>_URetinexNet<rest-of-name>`` and the per-image time is printed.
        """
        low_files = glob.glob(os.path.join(self.opts.low_dir, "*.png"))
        if low_files:
            # Create the output directory once; exist_ok avoids the
            # check-then-create race.
            os.makedirs(self.opts.output, exist_ok=True)
        for file in low_files:
            file_name = os.path.basename(file)
            high_file = os.path.join(self.opts.high_dir, file_name)
            # Context managers close the image files once they are decoded.
            with Image.open(file) as low_pil:
                low_img = self.transform(low_pil).unsqueeze(0)
            with Image.open(high_file) as high_pil:
                high_img = self.transform(high_pil).unsqueeze(0)
            enhance, p_time = self.forward(low_img, high_img)
            save_path = os.path.join(self.opts.output, self._output_name(file_name))
            np_save_TensorImg(enhance, save_path)
            print("================================= time for %s: %f============================"%(file_name, p_time))


def unfolding(self, input_low_img):
    """Module-level copy of the ``Inference.unfolding`` loop.

    ``self`` is any object exposing ``unfolding_opts`` (with ``round``,
    ``gamma``, ``Roffset``, ``lamda``, ``Loffset``), ``model_Decom_low``,
    ``P``, ``Q``, ``model_R`` and ``model_L``. Returns the ``(R, L)`` pair
    produced by the last round.
    """
    cfg = self.unfolding_opts
    for step in range(cfg.round):
        if step > 0:
            # Later rounds: refresh P and Q with step-dependent weights.
            weight_p = cfg.gamma + cfg.Roffset * step
            weight_q = cfg.lamda + cfg.Loffset * step
            p_cur = self.P(I=input_low_img, Q=q_cur, R=refl, gamma=weight_p)
            q_cur = self.Q(I=input_low_img, P=p_cur, L=illum, lamda=weight_q)
        else:
            # First round: initialise from the decomposition network.
            p_cur, q_cur = self.model_Decom_low(input_low_img)
        refl = self.model_R(r=p_cur, l=q_cur)
        illum = self.model_L(l=q_cur)
    return refl, illum
    
def build_parser():
    """Build the command-line parser for the inference script.

    Options cover the input/output directories, the four checkpoint paths and
    the GPU index exported through ``CUDA_VISIBLE_DEVICES``.
    """
    parser = argparse.ArgumentParser(description='Configure')
    # specify your data path here!
    # The low/high defaults must point into the same dataset root; the low
    # default previously read "./test_daat/...", which was a typo.
    parser.add_argument('--low_dir', type=str, default="./test_data/LOLdataset/eval15/low")
    parser.add_argument('--high_dir', type=str, default="./test_data/LOLdataset/eval15/high")
    parser.add_argument('--output', type=str, default="./demo/output/LOL")
    # NOTE: there is no --ratio option; the adjustment ratio is derived from
    # the reference image in Inference.forward.
    # model path
    parser.add_argument('--Decom_model_low_path', type=str, default="./ckpt/init_low.pth")
    parser.add_argument('--Decom_model_high_path', type=str, default="./ckpt/init_high.pth")
    parser.add_argument('--unfolding_model_path', type=str, default="./ckpt/unfolding.pth")
    parser.add_argument('--adjust_model_path', type=str, default="./ckpt/L_adjust.pth")
    parser.add_argument('--gpu_id', type=int, default=0)
    return parser


if __name__ == "__main__":
    opts = build_parser().parse_args()
    for k, v in vars(opts).items():
        print(k, v)

    # Restrict the visible GPUs before any CUDA work happens in this script.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    model = Inference(opts)
    # Mirror the availability guard used in Inference.forward instead of
    # unconditionally calling .cuda().
    if torch.cuda.is_available():
        model = model.cuda()
    model.evaluate()