Spaces:
Runtime error
Runtime error
| import io | |
| import numpy as np | |
| import onnxruntime | |
| from torch import nn | |
| import torch.utils.model_zoo as model_zoo | |
| import torch.onnx | |
| import torch.nn as nn | |
| import torch.nn.init as init | |
| import matplotlib.pyplot as plt | |
| import json | |
| from PIL import Image, ImageDraw, ImageFont | |
| from resizeimage import resizeimage | |
| import numpy as np | |
| import pdb | |
| import onnx | |
| import gradio as gr | |
| import os | |
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network that upscales a single-channel image.

    Three convolution + ReLU stages extract features (1 -> 64 -> 64 -> 32
    channels, padded so height and width are preserved). A fourth convolution
    emits ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle``
    rearranges into one channel that is ``upscale_factor`` times larger in
    each spatial dimension.

    Args:
        upscale_factor: Integer spatial upscaling factor.
        inplace: Forwarded to ``nn.ReLU``; when True the activation
            overwrites its input tensor.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        # Registration order (relu, conv1..conv4, pixel_shuffle) is kept so the
        # module layout and state_dict keys match the pretrained checkpoint.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # One output channel per sub-pixel position of the upscaled grid.
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Map a (N, 1, H, W) tensor to (N, 1, H * r, W * r), r = upscale_factor."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        # No activation before the shuffle: conv4 produces the sub-pixel values.
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise every convolution weight."""
        relu_gain = init.calculate_gain('relu')
        # The layers followed by ReLU get the ReLU gain, applied in layer order.
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        # conv4 feeds PixelShuffle directly, so it keeps the default gain of 1.
        init.orthogonal_(self.conv4.weight)
# Create the super-resolution model from the definition above and load the
# pretrained PyTorch weights.
# NOTE(review): inference() below runs through the ONNX session, not through
# torch_model; torch_model and the dummy input x look like leftovers from the
# PyTorch ONNX-export tutorial. They are kept so module state is unchanged.
torch_model = SuperResolutionNet(upscale_factor=3)
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1  # just a random number
# Keep checkpoint tensors on the CPU unless CUDA is available, in which case
# map_location=None restores them to the devices they were saved on.
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
x = torch.randn(1, 1, 224, 224, requires_grad=True)
torch_model.eval()

# Fetch the ONNX model once. The previous `os.system("wget ...")` ignored its
# exit status, needed an external binary, and re-downloaded on every start.
# Downloading to a temporary name and renaming means an interrupted download
# never leaves a truncated model file that later starts would reuse.
import urllib.request

onnx_model_url = "https://github.com/AK391/models/raw/main/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx"
onnx_model_path = "super-resolution-10.onnx"
if not os.path.exists(onnx_model_path):
    _partial_path = onnx_model_path + ".part"
    try:
        # Raises URLError/HTTPError on failure instead of failing silently.
        urllib.request.urlretrieve(onnx_model_url, _partial_path)
        os.replace(_partial_path, onnx_model_path)
    finally:
        if os.path.exists(_partial_path):
            os.remove(_partial_path)

# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# Pinning the CPU provider explicitly keeps GPU-enabled onnxruntime builds from
# rejecting a session created without `providers`. To use the GPU instead, pass e.g.
# providers=['CUDAExecutionProvider', 'CPUExecutionProvider'].
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
def inference(img):
    """Super-resolve an image with the ONNX sub-pixel CNN session.

    The image is resized with ``resizeimage.resize_cover`` to 224x224 and
    converted to YCbCr. Only the luma (Y) plane is passed through the network.
    The chroma planes (Cb, Cr) are resized bicubically to the network's output
    size and merged back in.

    Args:
        img: Path to the input image file (the Gradio input component is
            configured with ``type="filepath"``).

    Returns:
        PIL.Image.Image: The upscaled image, in RGB mode.

    Raises:
        ValueError: If no image was supplied (Gradio passes ``None``).
    """
    if img is None:
        raise ValueError("No input image was provided.")

    # Read the file and release its handle right away. convert() returns an
    # independent, fully loaded image, so nothing below depends on the handle.
    # Normalizing to RGB first matters because RGB -> YCbCr is always
    # available, whereas some uploaded modes (e.g. "LA") cannot be converted
    # to YCbCr directly.
    with Image.open(img) as orig_img:
        rgb_img = orig_img.convert('RGB')
    resized = resizeimage.resize_cover(rgb_img, [224, 224], validate=False)

    img_y, img_cb, img_cr = resized.convert('YCbCr').split()

    # Network input: float32 luma scaled to [0, 1], shaped (1, 1, H, W).
    y_input = np.asarray(img_y, dtype=np.float32)[np.newaxis, np.newaxis] / 255.0
    ort_inputs = {ort_session.get_inputs()[0].name: y_input}
    y_output = ort_session.run(None, ort_inputs)[0]

    # First batch item, first channel -> 2-D uint8 plane. PIL infers mode "L"
    # from a 2-D uint8 array, so the deprecated `mode=` argument is not needed.
    img_out_y = Image.fromarray(np.uint8((y_output[0] * 255.0).clip(0, 255)[0]))

    out_size = img_out_y.size
    final_img = Image.merge(
        "YCbCr",
        [
            img_out_y,
            img_cb.resize(out_size, Image.BICUBIC),
            img_cr.resize(out_size, Image.BICUBIC),
        ],
    ).convert("RGB")
    return final_img
title = "sub_pixel_cnn_2016"
description = "The Super Resolution machine learning model sharpens and upscales the input image to refine the details and improve quality."

# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4.
# Referencing them there raises AttributeError before the app can start.
# gr.Image serves as both the input component (file path handed to inference)
# and the output component (PIL image returned by inference).
demo = gr.Interface(
    inference,
    gr.Image(type="filepath"),
    gr.Image(type="pil"),
    title=title,
    description=description,
)
demo.launch()