Spaces:
Running
Running
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from torchvision import transforms | |
| from PIL import Image | |
# Texts and example inputs shown by the Gradio interface.
title = "Super Resolution with CNN"
description = """
Your low-resolution image will be upscaled by a factor of 2 with a convolutional neural network!<br>
Details on training and the dataset can be found in my [github repo](https://github.com/susuhu/super-resolution).<br>
"""
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📄 <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>📦 Dataset from <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""
# One-click example inputs for the interface. The paths are relative; presumably
# the images ship next to this script (TODO confirm).
examples = [
    ["peperoni.png"],
    ["barbara.png"],
]
class SRCNNModel(nn.Module):
    """Three-layer single-channel SRCNN network.

    A 9x9 convolution to 64 feature maps and a 1x1 convolution to 32 maps are
    each followed by ReLU. A final 5x5 convolution maps back to one channel.
    The padding keeps height and width unchanged, so the input must already be
    at the target resolution (the app feeds it a bicubic upscale).

    The attribute names ``conv1``/``conv2``/``conv3`` are the state-dict keys.
    Keep them stable so saved checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 64 channels, 9x9 kernel; padding 4 preserves spatial size
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        # 64 -> 32 channels, 1x1 kernel (per-pixel channel mixing)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        # 32 -> 1 channel, 5x5 kernel; padding 2 preserves spatial size
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x):
        """Map an (N, 1, H, W) tensor to an (N, 1, H, W) tensor."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        return self.conv3(mapped)
def _bicubic_resize(img, size):
    """Resize a PIL image to ``size`` given as (height, width), using bicubic interpolation."""
    return transforms.Resize(
        size, interpolation=transforms.InterpolationMode.BICUBIC
    )(img)


def pred_SRCNN(model, image, device, scale_factor=2):
    """Upscale ``image`` by ``scale_factor``, refining the luma channel with SRCNN.

    Steps:
      1. Convert the image to YCbCr and bicubic-upscale all three channels.
      2. Refine the Y channel with ``model``.
      3. Merge the channels and convert back to RGB.

    Args:
        model: SRCNN model. It takes a (1, 1, H, W) Y-channel tensor with values
            in [0, 1] and returns a tensor of the same shape.
        image: low-resolution image. Either a numpy array (what gradio passes)
            or a PIL image.
        device: torch device (cuda or cpu) to run the model on.
        scale_factor: integer upscaling factor applied to height and width.

    Returns:
        A pair ``(out_final, image_bicubic)`` of RGB PIL images, each of size
        (width * scale_factor, height * scale_factor):
          - ``out_final``: the SRCNN result.
          - ``image_bicubic``: the plain bicubic upscale, for comparison.
    """
    model.to(device)
    model.eval()

    # gradio hands images over as numpy arrays; PIL images are accepted as-is
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise grayscale/RGBA inputs so the YCbCr split/merge below always works
    # on three plain channels, and the bicubic baseline is comparable to the
    # RGB SRCNN output.
    image = image.convert("RGB")

    width, height = image.size
    # torchvision's Resize expects (height, width)
    target_size = (height * scale_factor, width * scale_factor)

    y, cb, cr = image.convert("YCbCr").split()
    y_bicubic, cb_bicubic, cr_bicubic = (
        _bicubic_resize(channel, target_size) for channel in (y, cb, cr)
    )

    # ToTensor yields a (1, H, W) float tensor in [0, 1]; add the batch dimension
    y_tensor = transforms.ToTensor()(y_bicubic).unsqueeze(0).to(device)
    with torch.no_grad():
        y_pred = model(y_tensor)

    # Back to 8-bit pixel values. Round, then clip so out-of-range predictions
    # saturate at 0/255 instead of wrapping around in the uint8 cast.
    y_pred = y_pred[0, 0].cpu().numpy() * 255.0
    y_pred = np.clip(np.rint(y_pred), 0, 255).astype(np.uint8)
    y_pred_PIL = Image.fromarray(y_pred)  # 2-D uint8 array -> mode "L"

    # merge the SRCNN Y channel with the bicubic Cb/Cr channels
    out_final = Image.merge("YCbCr", [y_pred_PIL, cb_bicubic, cr_bicubic]).convert("RGB")
    image_bicubic = _bicubic_resize(image, target_size)
    return out_final, image_bicubic
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build SRCNN on that device and restore the trained weights.
# map_location remaps tensors saved on another device, e.g. a GPU checkpoint
# loaded on a CPU-only machine.
model = SRCNNModel().to(device)
model.load_state_dict(torch.load("SRCNNmodel_trained.pt", map_location=device))
model.eval()  # inference mode
| # def image_grid(imgs, rows, cols): | |
| # ''' | |
| # imgs:list of PILImage | |
| # ''' | |
| # assert len(imgs) == rows*cols | |
| # w, h = imgs[0].size | |
| # grid = Image.new('RGB', size=(cols*w, rows*h)) | |
| # grid_w, grid_h = grid.size | |
| # for i, img in enumerate(imgs): | |
| # grid.paste(img, box=(i%cols*w, i//cols*h)) | |
| # return grid | |
def super_reso(input_image):
    """Gradio callback: upscale the uploaded image with SRCNN and with bicubic.

    Args:
        input_image: the uploaded image as delivered by ``gr.Image`` (a numpy
            array with default settings), or None when nothing was uploaded.

    Returns:
        A pair ``(out_final, image_bicubic)`` as produced by ``pred_SRCNN``.
        The UI shows the SRCNN result and the bicubic baseline side by side.

    Raises:
        gr.Error: if no image was provided. The UI then shows a readable
            message instead of a PIL stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    # inference only: no autograd graph is needed
    with torch.no_grad():
        out_final, image_bicubic = pred_SRCNN(
            model=model, image=input_image, device=device
        )
    return out_final, image_bicubic
# Single image in; SRCNN result and bicubic baseline out, shown side by side.
demo = gr.Interface(
    fn=super_reso,
    inputs=gr.Image(label="Upload image"),
    outputs=[
        gr.Image(label="Convolutional neural network"),
        gr.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()
# NOTE(review): launching has been seen to fail with
# "TypeError: AsyncConnectionPool.__init__() got an unexpected keyword argument
# 'socket_options'". This presumably comes from mismatched httpx/httpcore
# versions under gradio. TODO: confirm and pin compatible versions in the
# requirements.