File size: 3,799 Bytes
b77fd1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca337fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from models.base_model import UNET

def find_padding(img, depth=2**4):
    """Compute how much bottom/right padding makes H and W multiples of ``depth``.

    Args:
        img: Any object exposing ``shape`` whose last two entries are the
            height and width, e.g. a ``(B, C, H, W)`` or ``(C, H, W)`` tensor.
        depth: Divisor both spatial sizes must be a multiple of. Defaults to
            16 (``2**4``) -- presumably the network's total downsampling
            factor; confirm against the UNET definition.

    Returns:
        Tuple ``(h_pad, w_pad)``; each value lies in ``[0, depth)`` and is 0
        when the corresponding dimension is already a multiple of ``depth``.
    """
    # Only the trailing two dimensions matter, so do not insist on exactly 4-D.
    height, width = img.shape[-2:]

    # ``-x % depth`` is the distance from x up to the next multiple of depth.
    h_pad = -height % depth
    w_pad = -width % depth
    return h_pad, w_pad

def get_pretrained_path(model_name):
    """Map a model name to the checkpoint file stored beside this module.

    Recognised names: 'SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234'.
    The 'SRUNET_interpolation' and 'SRUNET_x234_interpolation' variants have
    no checkpoint mapped and are rejected like any other unknown name.

    Returns:
        Path string (forward slashes only) under this file's ``pretrained``
        directory.

    Raises:
        Exception: with message 'Model not found' for an unmapped name.
    """
    # Directory containing this file, with Windows separators normalised.
    module_dir = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")

    # Ordered (name, path suffix) pairs, compared with == one after another.
    known_checkpoints = (
        ('SRUNET_x2', '/pretrained/SRUNET_scale_x2.pt'),
        ('SRUNET_x3', '/pretrained/SRUNET_scale_x3.pt'),
        ('SRUNET_x4', '/pretrained/SRUNET_scale_x4.pt'),
        ('SRUNET_x234', '/pretrained/SRUNET_scale_x234.pt'),
    )
    for known_name, suffix in known_checkpoints:
        if model_name == known_name:
            return module_dir + suffix

    raise Exception('Model not found')


def upscale_image(img, model_name, scale_factor):
    """Bicubically enlarge a PIL image, then run the pretrained UNET over it.

    Args:
        img: PIL image; non-RGB inputs are converted to RGB for the model and
            converted back to their original mode on return.
        model_name: Name accepted by ``get_pretrained_path``.
        scale_factor: Multiplier applied to both width and height. Fractional
            results are rounded to the nearest integer pixel size.

    Returns:
        PIL image of size ``(width * scale_factor, height * scale_factor)``
        in the input's original mode (assuming the model preserves the
        spatial size of its input -- confirm against UNET).

    Raises:
        Exception: propagated from ``get_pretrained_path`` for an unknown
            ``model_name``.
    """
    width, height = img.size
    img_mode = img.mode
    if img_mode != "RGB":
        img = img.convert("RGB")

    # Resize needs integer sizes; rounding leaves integer factors unchanged
    # while also accepting values such as 2.0 or 1.5.
    target_size = (int(round(height * scale_factor)),
                   int(round(width * scale_factor)))
    transform = transforms.Compose([
        transforms.Resize(target_size,
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    # Load model.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(get_pretrained_path(
        model_name), map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()

    data = transform(img).clamp(0, 1).unsqueeze(0)
    h_pad, w_pad = find_padding(data)
    in_h, in_w = data.shape[-2:]
    # Reflect padding requires pad < dimension; very small images would make
    # F.pad raise, so fall back to edge replication for them.
    pad_mode = 'reflect' if (h_pad < in_h and w_pad < in_w) else 'replicate'
    data = F.pad(data, (0, w_pad, 0, h_pad), mode=pad_mode)

    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)

    # Trim the padding off the bottom/right edges. Explicit end indices also
    # cover the zero-padding case, where ``:-0`` would yield an empty slice.
    keep_h = img_scale_pred.shape[-2] - h_pad
    keep_w = img_scale_pred.shape[-1] - w_pad
    img_scale_pred = img_scale_pred[..., :keep_h, :keep_w].squeeze(0)

    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)

def enhanced_image(img, model_name):
    """Run the pretrained UNET over a PIL image at its native resolution.

    Args:
        img: PIL image; non-RGB inputs are converted to RGB for the model and
            converted back to their original mode on return.
        model_name: Name accepted by ``get_pretrained_path``.

    Returns:
        PIL image in the input's original mode, with the padding that was
        added for the model trimmed off again.

    Raises:
        Exception: propagated from ``get_pretrained_path`` for an unknown
            ``model_name``.
    """
    img_mode = img.mode
    if img_mode != "RGB":
        img = img.convert("RGB")

    transform = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    # Load model.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(get_pretrained_path(
        model_name), map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()

    data = transform(img).clamp(0, 1).unsqueeze(0)
    h_pad, w_pad = find_padding(data)
    in_h, in_w = data.shape[-2:]
    # Reflect padding requires pad < dimension; very small images would make
    # F.pad raise, so fall back to edge replication for them.
    pad_mode = 'reflect' if (h_pad < in_h and w_pad < in_w) else 'replicate'
    data = F.pad(data, (0, w_pad, 0, h_pad), mode=pad_mode)

    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)

    # Trim the padding off the bottom/right edges. Explicit end indices also
    # cover the zero-padding case, where ``:-0`` would yield an empty slice.
    keep_h = img_scale_pred.shape[-2] - h_pad
    keep_w = img_scale_pred.shape[-1] - w_pad
    img_scale_pred = img_scale_pred[..., :keep_h, :keep_w].squeeze(0)

    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)