Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| from torch import nn | |
| import cv2 | |
# Inference runs on the CPU. For GPU inference when one is available, use:
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
device: str = "cpu"
print(f"Using: {device}")
def build_generator(in_channels=3, num_residual_blocks=6, gf=32):
    """Construct the image-enhancement generator network.

    Architecture: a 3x3 stem convolution, a trunk of inverted-residual blocks
    (1x1 expand -> 3x3 depthwise -> 1x1 project) wrapped by a long skip
    connection, two bilinear x2 upsampling stages, and a 3-channel ``tanh``
    head (outputs in [-1, 1]).

    The defaults reproduce the previously hard-coded ``Generator(3, 6, 32)``.
    Module attribute names double as ``state_dict`` keys, so they must stay
    unchanged for saved weights to load.

    Args:
        in_channels: Channel count of the input tensor.
        num_residual_blocks: Number of blocks in the residual trunk.
        gf: Base feature width used throughout the network.

    Returns:
        A freshly initialised ``nn.Module`` (weights are not loaded here).
    """

    class ResidualBlock(nn.Module):
        """Inverted residual block: 1x1 expand, 3x3 depthwise, 1x1 project.

        The identity is added back only when input and output shapes match
        (stride 1 and equal channel counts).
        """

        def __init__(self, in_channels, out_channels, expansion=6, stride=1, alpha=1.0):
            super().__init__()
            self.expansion = expansion
            self.stride = stride
            self.in_channels = in_channels
            self.out_channels = int(out_channels * alpha)
            # Width of the 1x1 projection, rounded to a multiple of 8.
            self.pointwise_conv_filters = self._make_divisible(self.out_channels, 8)
            hidden = in_channels * expansion
            self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn1 = nn.BatchNorm2d(hidden)
            # groups == channels makes this a depthwise convolution.
            self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=True)
            self.bn2 = nn.BatchNorm2d(hidden)
            self.conv3 = nn.Conv2d(hidden, self.pointwise_conv_filters, kernel_size=1, stride=1, padding=0, bias=True)
            self.bn3 = nn.BatchNorm2d(self.pointwise_conv_filters)
            self.relu = nn.ReLU(inplace=True)
            self.skip_add = stride == 1 and in_channels == self.pointwise_conv_filters

        @staticmethod
        def _make_divisible(v, divisor, min_value=None):
            """Round ``v`` to the nearest multiple of ``divisor``.

            The result is at least ``min_value`` (defaults to ``divisor``) and
            is bumped up by one ``divisor`` if rounding lost more than 10%.
            Must be a staticmethod: it is called as ``self._make_divisible``.
            """
            if min_value is None:
                min_value = divisor
            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
            if new_v < 0.9 * v:
                new_v += divisor
            return new_v

        def forward(self, x):
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.relu(self.bn2(self.conv2(out)))
            # No activation after the projection.
            out = self.bn3(self.conv3(out))
            if self.skip_add:
                out = out + x
            return out

    def upsample_x2(channels):
        """Bilinear x2 upsample followed by a 3x3 conv and PReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.PReLU(),
        )

    class Generator(nn.Module):
        """Stem conv -> residual trunk (long skip) -> x2 -> x2 -> tanh head."""

        def __init__(self, in_channels, num_residual_blocks, gf):
            super().__init__()
            self.num_residual_blocks = num_residual_blocks
            self.gf = gf
            self.conv1 = nn.Conv2d(in_channels, gf, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(gf)
            self.prelu1 = nn.PReLU()
            self.residual_blocks = self.make_layer(ResidualBlock, gf, num_residual_blocks)
            self.conv2 = nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(gf)
            self.upsample1 = upsample_x2(gf)
            self.upsample2 = upsample_x2(gf)
            self.conv3 = nn.Conv2d(gf, 3, kernel_size=3, stride=1, padding=1)
            self.tanh = nn.Tanh()

        def make_layer(self, block, out_channels, blocks):
            """Stack ``blocks`` instances of ``block(out_channels, out_channels)``."""
            return nn.Sequential(*[block(out_channels, out_channels) for _ in range(blocks)])

        def forward(self, x):
            head = self.prelu1(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(self.residual_blocks(head)))
            # Long skip connection around the whole residual trunk.
            out = out + head
            out = self.upsample2(self.upsample1(out))
            return self.tanh(self.conv3(out))

    return Generator(in_channels, num_residual_blocks, gf)
# Instantiate the generator on ``device`` and load the trained weights straight
# onto that same device (no CPU detour when ``device`` is an accelerator).
model = build_generator().to(device)
model.load_state_dict(torch.load('./generator_weight.pt', map_location=device))
# Inference only: BatchNorm layers must use their running statistics.
model.eval()
def numpify(imgs):
    """Convert a batch of channel-first image tensors into one NumPy array.

    Args:
        imgs: Iterable of 3-D ``(C, H, W)`` tensors, e.g. a ``(N, C, H, W)``
            tensor. They may live on any device and may require grad.

    Returns:
        ``np.ndarray`` of shape ``(N, H, W, C)`` with the tensors' dtype.

    Raises:
        ValueError: If ``imgs`` is empty (propagated from ``np.stack``).
    """
    # detach() drops autograd history and cpu() leaves any accelerator; without
    # both, the NumPy conversion raises for grad-tracking or GPU tensors.
    return np.stack([img.detach().cpu().permute(1, 2, 0).numpy() for img in imgs], axis=0)
# Pre-processing pipeline: PIL image -> float tensor laid out channel-first.
_preprocess_steps = [transforms.ToTensor()]
transform = transforms.Compose(_preprocess_steps)
def _resize_to_width(image, width):
    """Return ``image`` resized to ``width`` pixels wide, keeping its aspect ratio."""
    original_width, original_height = image.size
    # Same truncating computation as before, but never a 0-pixel height.
    height = max(1, int((original_height / original_width) * width))
    return image.resize((width, height))


def _output_to_pil(sr):
    """Convert the model's output batch to a PIL image of its first item.

    Assumes values in [-1, 1] (the generator ends in ``tanh``); they are
    mapped linearly to 0..255 and truncated to ``uint8``.
    """
    fake_imgs = numpify(sr)
    return Image.fromarray((((fake_imgs[0] + 1) / 2) * 255).astype(np.uint8))


def _sharpen(img):
    """Sharpen a PIL RGB image with a 3x3 sharpening kernel."""
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    # filter2D processes each channel independently with the same kernel, so the
    # channel order is irrelevant and no RGB<->BGR round trip is needed.
    return Image.fromarray(cv2.filter2D(np.array(img), -1, kernel))


def translate_image(image, sharpen):
    """Run ``image`` through the generator and optionally sharpen the result.

    The input is converted to RGB if needed, resized to 480 px wide (aspect
    ratio kept), and passed through ``model`` on ``device``.

    Side effects: prints ``Translating!`` and writes the returned image to
    ``super_resolved_image.png`` in the working directory (overwritten on
    every call).

    Args:
        image: Input PIL image, or ``None`` (e.g. submitted without an upload).
        sharpen: If truthy, apply a 3x3 sharpening filter to the output.

    Returns:
        The output PIL image, or ``None`` when ``image`` is ``None``.
    """
    print('Translating!')
    if image is None:
        return None
    desired_width = 480
    # The network's stem expects 3 channels; RGBA, greyscale or palette images
    # would otherwise produce a tensor with the wrong channel count.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    resized_image = _resize_to_width(image, desired_width)
    low_res = transform(resized_image).unsqueeze(dim=0).to(device)
    model.eval()
    with torch.no_grad():
        sr = model(low_res)
    sr_img = _output_to_pil(sr)
    if sharpen:
        sr_img = _sharpen(sr_img)
    sr_img.save('super_resolved_image.png')
    return sr_img
# Gradio UI: a PIL image and a "sharpen" toggle in, the processed PIL image out.
interface = gr.Interface(
    fn=translate_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Checkbox(label="Sharpen Image"),
    ],
    outputs=gr.Image(type="pil", label="Translated Image"),
    title="Correction App",
    description="Upload an image and get the translated version. Some images may be blurry, you can tick the checkbox to sharpen them.",
    # None does NOT disable flagging: it falls back to Gradio's default
    # ("manual"), which shows a Flag button that stores user samples on disk.
    allow_flagging="never",
)

# Launch the Gradio app.
interface.launch()