File size: 6,041 Bytes
9f00155
 
 
 
 
 
 
 
2b419d6
 
 
9f00155
 
2b419d6
9f00155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b419d6
00ec9ab
2b419d6
 
 
 
 
 
 
 
9f00155
2b419d6
 
 
9f00155
 
2b419d6
9f00155
2b419d6
9f00155
2b419d6
9f00155
 
2b419d6
9f00155
2b419d6
 
 
9f00155
2b419d6
9f00155
2b419d6
9f00155
2b419d6
9f00155
 
 
 
 
 
 
 
 
 
2b419d6
 
9f00155
 
2b419d6
 
 
9f00155
 
2b419d6
 
9f00155
 
2b419d6
 
9f00155
2b419d6
9f00155
2b419d6
 
9f00155
 
2b419d6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
import cv2

# Inference device. Pinned to the CPU on purpose; to run on a GPU when one is
# available, change this to:
#     'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
print(f'Using: {device}')

def build_generator(in_channels=3, num_residual_blocks=6, gf=32):
    """Build the super-resolution generator network.

    The defaults reproduce the architecture whose weights are loaded from
    ``generator_weight.pt``. Submodule attribute names and these
    hyper-parameters determine the ``state_dict`` keys and tensor shapes, so
    changing either breaks ``load_state_dict`` for that checkpoint.

    Args:
        in_channels: Channel count of the input image tensor.
        num_residual_blocks: Number of inverted-residual blocks in the trunk.
        gf: Feature width (channels) used throughout the trunk.

    Returns:
        A ``Generator`` module that maps ``(N, in_channels, H, W)`` to
        ``(N, 3, 4H, 4W)``. The two 2x upsampling stages give the 4x size,
        and the final tanh gives values in ``[-1, 1]``.
    """

    class ResidualBlock(nn.Module):
        """Inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

        There is no activation after the projection (linear bottleneck). The
        input is added back only when the output has the same shape, i.e.
        stride 1 and an unchanged channel count.
        """

        def __init__(self, in_channels, out_channels, expansion=6, stride=1, alpha=1.0):
            super().__init__()
            self.expansion = expansion
            self.stride = stride
            self.in_channels = in_channels
            # Width multiplier, then rounded to a multiple of 8 channels.
            self.out_channels = int(out_channels * alpha)
            self.pointwise_conv_filters = self._make_divisible(self.out_channels, 8)
            hidden = in_channels * expansion
            # NOTE: attribute names and registration order define the
            # state_dict keys; keep them in sync with the checkpoint.
            self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn1 = nn.BatchNorm2d(hidden)
            # groups == channels -> depthwise convolution.
            self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=True)
            self.bn2 = nn.BatchNorm2d(hidden)
            self.conv3 = nn.Conv2d(hidden, self.pointwise_conv_filters, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn3 = nn.BatchNorm2d(self.pointwise_conv_filters)
            self.relu = nn.ReLU(inplace=True)
            self.skip_add = (stride == 1 and in_channels == self.pointwise_conv_filters)

        def forward(self, x):
            identity = x

            out = self.relu(self.bn1(self.conv1(x)))
            out = self.relu(self.bn2(self.conv2(out)))
            out = self.bn3(self.conv3(out))

            if self.skip_add:
                out = out + identity
            return out

        @staticmethod
        def _make_divisible(v, divisor, min_value=None):
            """Round *v* to the nearest multiple of *divisor* (>= min_value).

            If rounding would drop more than 10% below *v*, one more
            *divisor* is added.
            """
            if min_value is None:
                min_value = divisor
            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
            if new_v < 0.9 * v:
                new_v += divisor
            return new_v

    class Generator(nn.Module):
        """Conv stem -> residual trunk with a long skip -> 2x2 upsampling -> RGB head."""

        def __init__(self, in_channels, num_residual_blocks, gf):
            super().__init__()
            self.num_residual_blocks = num_residual_blocks
            self.gf = gf

            self.conv1 = nn.Conv2d(in_channels, gf, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(gf)
            self.prelu1 = nn.PReLU()

            self.residual_blocks = self.make_layer(ResidualBlock, gf, num_residual_blocks)

            self.conv2 = nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(gf)

            # Each stage doubles H and W.
            self.upsample1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
                nn.PReLU()
            )

            self.upsample2 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
                nn.PReLU()
            )

            self.conv3 = nn.Conv2d(gf, 3, kernel_size=3, stride=1, padding=1)
            self.tanh = nn.Tanh()

        def make_layer(self, block, out_channels, blocks):
            """Stack *blocks* instances of *block* with equal in/out channels."""
            return nn.Sequential(*[block(out_channels, out_channels) for _ in range(blocks)])

        def forward(self, x):
            out1 = self.prelu1(self.bn1(self.conv1(x)))
            out = self.residual_blocks(out1)
            out = self.bn2(self.conv2(out))
            # Long skip connection around the whole residual trunk.
            out = out + out1
            out = self.upsample1(out)
            out = self.upsample2(out)
            return self.tanh(self.conv3(out))

    return Generator(in_channels, num_residual_blocks, gf)

# Build the generator and load its trained weights.
model = build_generator().to(device)
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
# The tensors are mapped straight onto `device` so loading follows the device
# setting instead of always staging through the CPU.
model.load_state_dict(torch.load('./generator_weight.pt', map_location=device))
# The model is only used for inference, so put BatchNorm in eval mode now.
model.eval()


def numpify(imgs):
    """Convert a batch of CHW image tensors into one NHWC NumPy array.

    Args:
        imgs: Iterable of 3-D ``(C, H, W)`` tensors. Typically this is a 4-D
            ``(N, C, H, W)`` batch, which iterates over its first dimension.

    Returns:
        ``np.ndarray`` of shape ``(N, H, W, C)``.

    Raises:
        ValueError: If *imgs* is empty (raised by ``np.stack``).
    """
    # detach() drops autograd history and cpu() moves GPU tensors to host
    # memory. Without both, converting to NumPy raises for tensors that
    # require grad or live on a CUDA device.
    return np.stack(
        [img.permute(1, 2, 0).detach().cpu().numpy() for img in imgs],
        axis=0,
    )

# Preprocessing applied to the resized PIL image before it enters the model.
# torchvision's ToTensor turns it into a (C, H, W) float tensor.
_to_tensor_steps = [transforms.ToTensor()]
transform = transforms.Compose(_to_tensor_steps)


# Function to translate the image
def translate_image(image, sharpen):
    """Super-resolve an uploaded image, optionally sharpening the result.

    Processing steps:
      1. Convert the input to RGB.
      2. Resize it to 480 px wide, keeping the aspect ratio.
      3. Run it through the generator ``model``.
      4. Map the output from tanh range ``[-1, 1]`` back to 8-bit pixels.

    The result is also written to ``super_resolved_image.png`` in the current
    working directory, overwriting any previous result.

    Args:
        image: Uploaded ``PIL.Image.Image``, or ``None`` when nothing was
            uploaded.
        sharpen: If true, apply a 3x3 sharpening kernel to the output.

    Returns:
        The super-resolved (and possibly sharpened) ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was provided.
    """
    print('Translating!')
    if image is None:
        # Shown to the user in the UI instead of an AttributeError traceback.
        raise gr.Error('Please upload an image first.')

    desired_width = 480
    # Uploads may be RGBA, greyscale or palette images. The generator's first
    # conv expects exactly 3 input channels.
    image = image.convert('RGB')
    original_width, original_height = image.size
    # Very wide images could otherwise round down to a zero height, which
    # resize() rejects.
    desired_height = max(1, int((original_height / original_width) * desired_width))

    resized_image = image.resize((desired_width, desired_height))
    low_res = transform(resized_image).unsqueeze(dim=0).to(device)

    model.eval()
    with torch.no_grad():
        sr = model(low_res)

    fake_imgs = numpify(sr)
    # Map [-1, 1] -> [0, 255]. Clip so float error can never wrap around in
    # the uint8 cast.
    pixels = np.clip((fake_imgs[0] + 1) / 2 * 255, 0, 255).astype(np.uint8)
    sr_img = Image.fromarray(pixels)

    if sharpen:
        # filter2D applies the same kernel to every channel independently, so
        # channel order (RGB vs BGR) does not affect the result and no colour
        # conversion is needed. OpenCV expects a floating-point kernel.
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)
        sr_img = Image.fromarray(cv2.filter2D(np.array(sr_img), -1, kernel))

    # NOTE(review): a fixed output path is shared by all concurrent requests.
    sr_img.save('super_resolved_image.png')
    return sr_img

# Set up the Gradio interface
interface = gr.Interface(
    fn=translate_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Checkbox(label="Sharpen Image"),
    ],
    outputs=gr.Image(type="pil", label="Translated Image"),
    title="Correction App",
    description="Upload an image and get the translated version. Some images may be blurry, you can tick the checkbox to sharpen them.",
    # None means "unspecified". Gradio then falls back to the
    # GRADIO_ALLOW_FLAGGING env var or "manual", which shows a Flag button
    # that saves inputs/outputs to disk. "never" actually disables flagging.
    allow_flagging="never",
)

# Launch the Gradio app
# Start the web server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch()