| import glob | |
| import io | |
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import wget | |
| from torchvision.transforms import Compose, ToTensor | |
| from model import decoder, encoder | |
# Local location of the pretrained checkpoint loaded by Model.load_model.
WEIGHT_PATH = './weights/best_weight.pth'
# Use the current CUDA device when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Model(object):
    """Inference wrapper around the USR-DA RRDB encoder and SR decoder.

    On construction both networks are built on ``DEVICE`` and their weights
    are loaded via :meth:`load_model`; :meth:`predict` then super-resolves a
    single BGR image.
    """

    # Source of the default checkpoint, fetched when the local file is missing.
    WEIGHT_URL = (
        "https://raw.githubusercontent.com/hungnguyen2611/super-resolution/"
        "master/weights/best_weight.pth"
    )

    def __init__(self) -> None:
        # The decoder consumes the encoder's features (see predict), so the
        # encoder's num_feat and the decoder's num_in_ch must agree.
        self.model_Enc = encoder.Encoder_RRDB(num_feat=64).to(device=DEVICE)
        self.model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=64).to(device=DEVICE)
        # ToTensor: H x W x C uint8 array -> C x H x W float tensor in [0, 1].
        self.preprocess = Compose([ToTensor()])
        self.load_model()

    def load_model(self, weight_path=WEIGHT_PATH):
        """Load encoder/decoder weights from ``weight_path`` and switch to eval mode.

        If ``weight_path`` does not exist yet, the checkpoint is first
        downloaded from :attr:`WEIGHT_URL` into that location, creating its
        parent directory when needed.

        Args:
            weight_path: Checkpoint file containing a dict with the
                ``'model_Enc'`` and ``'model_Dec_SR'`` state dicts.
        """
        # Check (and, if absent, download) the same file that is loaded below.
        if not os.path.isfile(weight_path):
            weight_dir = os.path.dirname(weight_path)
            if weight_dir:
                # Make sure the download target directory exists.
                os.makedirs(weight_dir, exist_ok=True)
            wget.download(self.WEIGHT_URL, weight_path)
        # NOTE(review): torch.load unpickles the checkpoint; since the file may
        # come from the network, keep WEIGHT_URL pointed at a trusted source.
        weight = torch.load(weight_path, map_location=DEVICE)
        print("[LOADING] Loading encoder...")
        self.model_Enc.load_state_dict(weight['model_Enc'])
        print("[LOADING] Loading decoder...")
        self.model_Dec_SR.load_state_dict(weight['model_Dec_SR'])
        print("[LOADING] Loading done!")
        # Inference only: put both networks in eval mode.
        self.model_Enc.eval()
        self.model_Dec_SR.eval()

    def predict(self, img):
        """Super-resolve one image.

        Args:
            img: H x W x 3 image array in BGR channel order, expected to be
                ``uint8`` (e.g. as returned by
                ``cv2.imread(path, cv2.IMREAD_COLOR)``).

        Returns:
            The super-resolved image as a ``uint8`` array of shape
            H' x W' x 3 in BGR order (ready for ``cv2.imwrite``); the output
            size is determined by the decoder.
        """
        with torch.no_grad():
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            batch = self.preprocess(rgb).unsqueeze(0).to(DEVICE)
            out = self.model_Dec_SR(self.model_Enc(batch))
        # Take the single batch item and clamp to [0, 1]. Indexing [0] (not
        # squeeze()) keeps the C x H x W layout even when H or W is 1.
        out = out[0].detach().float().cpu().clamp_(0, 1).numpy()
        # CHW with RGB channels -> HWC with BGR channels, as OpenCV expects.
        out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
        return (out * 255.0).round().astype(np.uint8)
# One shared instance, created at import time: building Model loads the weights
# (downloading them if missing), so every request reuses the initialised networks.
model: Model = Model()
def predict(img):
    """Gradio callback: super-resolve an uploaded image and return its file path.

    Args:
        img: The uploaded image as a ``PIL.Image.Image``, or ``None`` when the
            user submits without providing an image.

    Returns:
        Path of the PNG written to ``images_uploaded/1.png``, or ``None`` when
        no image was provided.

    Raises:
        RuntimeError: If the result image cannot be written to disk.
    """
    if img is None:
        return None
    # Convert in memory instead of round-tripping through "test/1.png": no
    # dependency on that directory existing, and modes PNG cannot store (e.g.
    # CMYK) still work. convert("RGB") yields 3 channels, matching what
    # cv2.IMREAD_COLOR produced; Model.predict expects BGR order.
    bgr = cv2.cvtColor(np.asarray(img.convert("RGB")), cv2.COLOR_RGB2BGR)
    out = model.predict(img=bgr)
    # NOTE(review): a fixed output path means concurrent requests overwrite
    # each other's result; only safe while the queue runs one job at a time.
    out_path = "images_uploaded/1.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(out_path, out):
        raise RuntimeError(f"failed to write super-resolved image to {out_path}")
    return out_path
if __name__ == '__main__':
    # Static page text for the Gradio demo.
    title = "Super-Resolution Demo USR-DA Unofficial 🚀🚀🔥"
    description = '''
<br>
**This Demo expects low-quality and low-resolution images, better visual on real-world images**
<br>
'''
    article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Unsupervised_Real-World_Super-Resolution_A_Domain_Adaptation_Perspective_ICCV_2021_paper.pdf' target='_blank'>Unsupervised Real-World Super-Resolution: A Domain Adaptation Perspective</a> | <a href='https://github.com/hungnguyen2611/super-resolution.git' target='_blank'>Github Repo</a></p>"
    # Sorted for a stable example order; fall back to "no examples" rather
    # than handing Gradio an empty list when the folder has no PNGs.
    examples = sorted(glob.glob("testsets/*.png")) or None
    gr.Interface(
        predict,
        gr.inputs.Image(type="pil", label="Input").style(height=260),
        gr.inputs.Image(type="pil", label="Output").style(height=240),
        title=title,
        description=description,
        article=article,
        examples=examples,
    ).launch(enable_queue=True)