File size: 3,837 Bytes
defb4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import os
import argparse
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.models import create_model
from basicsr.utils.options import parse
from skimage import img_as_ubyte
from torch.cuda.amp import autocast, GradScaler
def load_img(filepath):
    """Load an image from *filepath* and return it as an RGB uint8 array.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded
            (``cv2.imread`` returns None in both cases, which would
            otherwise surface as a cryptic ``cvtColor`` assertion).
    """
    img = cv2.imread(filepath)
    if img is None:
        # imread fails silently; turn that into an explicit error.
        raise FileNotFoundError(f"Could not read image: {filepath}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def save_img(filepath, img):
    """Write an RGB image array to *filepath*, converting to BGR for OpenCV."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr)
def self_ensemble(x, model):
    """Geometric self-ensemble: average the model output over the 8
    flip/rotation variants of *x*.

    Each variant is produced by some combination of horizontal flip,
    vertical flip, and a 90-degree rotation; the model prediction is
    mapped back to the original orientation (inverse ops, reverse order)
    before averaging.

    Args:
        x: input tensor of shape (..., H, W).
        model: callable mapping a tensor to a same-shaped tensor.

    Returns:
        Mean of the 8 de-transformed predictions, same shape as ``model(x)``.
    """
    def forward_transformed(x, hflip, vflip, rotate, model):
        # Apply the transform, run the model, then undo the transform in
        # reverse order so the output aligns with the untransformed input.
        if hflip:
            x = torch.flip(x, (-2,))
        if vflip:
            x = torch.flip(x, (-1,))
        if rotate:
            x = torch.rot90(x, dims=(-2, -1))
        x = model(x)
        if rotate:
            x = torch.rot90(x, dims=(-2, -1), k=3)
        if vflip:
            x = torch.flip(x, (-1,))
        if hflip:
            x = torch.flip(x, (-2,))
        return x

    # Accumulate a running sum instead of stacking all 8 outputs:
    # identical result, ~8x lower peak memory for large images.
    total = None
    count = 0
    for hflip in (False, True):
        for vflip in (False, True):
            for rot in (False, True):
                out = forward_transformed(x, hflip, vflip, rot, model)
                total = out if total is None else total + out
                count += 1
    return total / count


# Set GPU
# NOTE(review): joining over the string '0' iterates its characters and
# just yields '0'; kept as-is since behavior is correct for one GPU.
gpu_list = ','.join(str(x) for x in '0')
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

# Load YAML configuration
opt = parse("RetinexFormer_FiveK.yml", is_train=False)
opt['dist'] = False  # single-process run; distributed mode off
print(opt)

# Load model
model_restoration = create_model(opt).net_g
# Checkpoint is assumed to store the network weights under the 'params'
# key — TODO confirm against how the checkpoint was saved.
checkpoint = torch.load("pretrained_weights/FiveK.pth")
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ", "pretrained_weights/FiveK.pth")
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()


def process_image(inp_path, model_restoration, out_dir, factor=4,
                  use_self_ensemble=False):
    """Restore one image with *model_restoration* and save it as a PNG.

    Pipeline: load → resize to 1024 px height (aspect preserved) →
    reflect-pad to a multiple of *factor* → run the model under autocast
    (splitting even/odd columns for very large inputs) → unpad → clamp →
    save to *out_dir* as "<basename>.png".

    Args:
        inp_path: path of the input image file.
        model_restoration: callable (the network) mapping a (1, C, H, W)
            CUDA float tensor to a same-shaped tensor.
        out_dir: directory that receives the restored PNG.
        factor: both spatial dims are padded up to a multiple of this.
        use_self_ensemble: average predictions over the 8 flip/rotation
            variants (slower, usually slightly better). Defaults to
            False, matching the old behavior where the ensemble branches
            were dead code (`if 1 == 0:`).
    """
    # Release cached GPU memory left over from the previous image.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    img = np.float32(load_img(inp_path)) / 255.

    # Resize to height 1024 px while maintaining the aspect ratio.
    max_height = 1024
    aspect_ratio = img.shape[1] / img.shape[0]
    new_width = int(max_height * aspect_ratio)
    img = cv2.resize(img, (new_width, max_height), interpolation=cv2.INTER_AREA)

    input_ = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).cuda()

    # Reflect-pad so height and width are multiples of `factor`.
    b, c, h, w = input_.shape
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

    # Optional self-ensemble; plain forward pass by default.
    if use_self_ensemble:
        infer = lambda t: self_ensemble(t, model_restoration)
    else:
        infer = model_restoration

    # Pure inference: no_grad avoids building the autograd graph (the
    # original code skipped this and also created an unused GradScaler).
    with torch.no_grad(), autocast():  # mixed precision
        if h < 3000 and w < 3000:
            restored = infer(input_)
        else:
            # Very large input: restore even/odd columns separately and
            # re-interleave, to limit peak GPU memory.
            restored = torch.zeros_like(input_)
            restored[:, :, :, 1::2] = infer(input_[:, :, :, 1::2])
            restored[:, :, :, 0::2] = infer(input_[:, :, :, 0::2])

        # Unpad back to the pre-padding dimensions.
        restored = restored[:, :, :h, :w]

    restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

    out_name = os.path.splitext(os.path.basename(inp_path))[0] + '.png'
    save_img(os.path.join(out_dir, out_name), img_as_ubyte(restored))