# Hugging Face Spaces status: Build error
| import imageio | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageSequence | |
| from tempfile import NamedTemporaryFile | |
| from pathlib import Path | |
| from loguru import logger | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer, torch | |
# Maximum input width in pixels; inputs wider than this are scaled down
# before upscaling whenever UNLIMITED is False.
BASE_WIDTH = 256

# Pick the RealESRGANer tile size and the width-cap policy from the device.
# NOTE: "DEFUALT" is a misspelling of "DEFAULT"; the name is kept because
# other code in this file refers to it.
if torch.cuda.is_available():
    logger.info("CUDA available, Using Device: GPU")
    DEFUALT_TILE, UNLIMITED = 256, True
else:
    logger.info("CUDA not found, Using Device: CPU")
    DEFUALT_TILE, UNLIMITED = 128, False
# Release download locations for the pretrained weight files, keyed by the
# file name expected next to this script.
_WEIGHT_URLS = {
    "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
}


def _build_upsampler(weights_name: str, num_block: int) -> RealESRGANer:
    """Build a 4x RealESRGANer around an RRDBNet with `num_block` RRDB blocks.

    Args:
        weights_name: File name of the weights, resolved relative to this
            script's directory. If the file is missing it is downloaded once
            from its entry in ``_WEIGHT_URLS``.
        num_block: Number of RRDB blocks; must match the weights file.

    Returns:
        A RealESRGANer using ``DEFUALT_TILE`` tiles, 10 px tile padding,
        no pre-padding and full (fp32) precision.
    """
    weights_path = Path(__file__).parent.joinpath(weights_name)
    if not weights_path.exists():
        # Download into the same directory model_path points at, so a fresh
        # checkout without the .pth files can still start.
        logger.info(f"Weights not found, downloading {weights_name}")
        torch.hub.download_url_to_file(_WEIGHT_URLS[weights_name], str(weights_path))
    return RealESRGANer(
        scale=4,
        model_path=str(weights_path),
        model=RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=4,
        ),
        tile=DEFUALT_TILE,
        tile_pad=10,
        pre_pad=0,
        half=False,
    )


# 6-block anime model and 23-block base x4plus model, loaded once at import.
upsampler_anime = _build_upsampler("RealESRGAN_x4plus_anime_6B.pth", num_block=6)
upsampler_base = _build_upsampler("RealESRGAN_x4plus.pth", num_block=23)
def _fit_width(image: "Image.Image") -> "Image.Image":
    """Shrink `image` to BASE_WIDTH px wide, keeping its aspect ratio.

    The image is returned unchanged when UNLIMITED is set or when it is
    already at most BASE_WIDTH pixels wide.
    """
    width, height = image.size
    if UNLIMITED or width <= BASE_WIDTH:
        return image
    # max(1, ...) keeps extremely wide images from collapsing to zero height.
    new_height = max(1, int(height * (BASE_WIDTH / width)))
    return image.resize((BASE_WIDTH, new_height))


def _to_enhance_mode(image: "Image.Image") -> "Image.Image":
    """Convert image modes whose numpy arrays are not plain grey/RGB/RGBA.

    np.array() of a palette ("P"/"PA") image holds palette indices rather
    than colours. Bilevel, grey+alpha and CMYK arrays likewise are not plain
    grey/RGB/RGBA. Such images are converted to RGBA when they carry
    transparency and to RGB otherwise. All other modes pass through unchanged.
    """
    if image.mode not in ("1", "P", "PA", "LA", "La", "CMYK"):
        return image
    has_alpha = image.mode in ("PA", "LA", "La") or "transparency" in image.info
    return image.convert("RGBA" if has_alpha else "RGB")


def _upscale(upsampler, image: "Image.Image") -> np.ndarray:
    """Run `upsampler.enhance` on `image` with outscale 4 and return the array."""
    output, _ = upsampler.enhance(np.array(image), 4)
    return output


def inference(image_path: str, mode: str) -> str:
    """Upscale the image at `image_path` 4x and return the result's file path.

    Args:
        image_path: Path of the uploaded image (the input component uses
            gradio ``type="filepath"``).
        mode: ``"anime"`` selects the anime model. Any other value, including
            ``None`` when no option is chosen, selects the base model.

    Returns:
        Path of a temporary ``.gif`` (for GIF input, every frame upscaled) or
        ``.png`` file. It is created with ``delete=False`` and is not removed
        here.
    """
    upsampler = upsampler_anime if mode == "anime" else upsampler_base
    with Image.open(image_path) as image:
        # Read the format before any resize: Image.resize() returns a new
        # image whose `format` is None, so a later check would miss GIFs.
        is_gif = image.format == "GIF"
        if not UNLIMITED and image.size[0] > BASE_WIDTH:
            logger.info(f"Image is too large, resizing to {BASE_WIDTH}px")
        if is_gif:
            # First frame's delay in ms; GIFs that declare none get 100 ms.
            duration_ms = image.info.get("duration", 100)
            # Every frame, including the first, is converted to RGB (GIF frames
            # are typically palette-indexed), width-capped, then upscaled.
            frames = [
                _upscale(upsampler, _fit_width(frame.convert("RGB")))
                for frame in ImageSequence.Iterator(image)
            ]
        else:
            result = _upscale(upsampler, _fit_width(_to_enhance_mode(image)))

    # Give the output file an extension matching its content, and close the
    # handle before writing to it by path.
    with NamedTemporaryFile(suffix=".gif" if is_gif else ".png", delete=False) as tmp:
        out_path = tmp.name
    logger.debug("TempFile path: " + out_path)
    if is_gif:
        # NOTE(review): `duration` is passed in seconds, as the original code
        # did; some imageio releases expect milliseconds for GIF — confirm
        # against the pinned imageio version.
        imageio.mimsave(out_path, frames, format="gif", duration=duration_ms / 1000)
    else:
        Image.fromarray(result).save(out_path, format="PNG")
    return out_path
# Text shown around the demo UI.
title = "Real-ESRGAN"
description = "Gradio demo for Real-ESRGAN. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click submit only once"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.10833'>Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data</a> | <a href='https://github.com/xinntao/Real-ESRGAN'>Github Repo</a></p>"

demo = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        # "base" is also what `inference` falls back to when nothing is chosen.
        gr.Radio(["anime", "base"], value="base", type="value", label="model type"),
    ],
    # `inference` returns a file path, so the output must be type="filepath";
    # "file" is not an accepted gr.Image type (numpy / pil / filepath).
    gr.Image(type="filepath", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["bear.jpg", "base"], ["anime.png", "anime"]],
)
demo.launch()