# Real-ESRGAN / app.py
# Author: Marne — commit 784930d ("Update All")
import imageio
import gradio as gr
import numpy as np
from PIL import Image, ImageSequence
from tempfile import NamedTemporaryFile
from pathlib import Path
from loguru import logger
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer, torch
# Runtime configuration picked from the available hardware:
#   DEFUALT_TILE - tile size handed to RealESRGANer for tiled inference
#                  (name kept as-is; other code refers to it).
#   UNLIMITED    - when False, inputs wider than BASE_WIDTH are downscaled
#                  before upscaling to keep CPU inference tractable.
#   BASE_WIDTH   - target width for that downscale.
BASE_WIDTH = 256
if torch.cuda.is_available():
    logger.info("CUDA available, Using Device: GPU")
    DEFUALT_TILE, UNLIMITED = 256, True
else:
    logger.info("CUDA not found, Using Device: CPU")
    DEFUALT_TILE, UNLIMITED = 128, False
def _load_upsampler(weights_name, num_block):
    """Build a 4x RealESRGANer wrapping an RRDBNet generator.

    Args:
        weights_name: file name of the ``.pth`` weights, resolved relative to
            the directory containing this script.
        num_block: number of RRDB blocks in the generator; it must match the
            architecture the weights were trained with.
    """
    generator = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=num_block,
        num_grow_ch=32,
        scale=4,
    )
    return RealESRGANer(
        scale=4,
        model_path=str(Path(__file__).parent / weights_name),
        model=generator,
        tile=DEFUALT_TILE,
        tile_pad=10,
        pre_pad=0,
        half=False,
    )


# Weights are expected next to this file. If missing, they can be fetched with:
#   torch.hub.download_url_to_file('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 'RealESRGAN_x4plus.pth')
#   torch.hub.download_url_to_file('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth', 'RealESRGAN_x4plus_anime_6B.pth')
upsampler_anime = _load_upsampler("RealESRGAN_x4plus_anime_6B.pth", num_block=6)
upsampler_base = _load_upsampler("RealESRGAN_x4plus.pth", num_block=23)
# PIL modes whose numpy layout RealESRGANer.enhance already interprets
# correctly: 2-D grayscale (8/16-bit integer) or 3/4-channel colour.
_PASSTHROUGH_MODES = ("RGB", "RGBA", "L", "I", "I;16")


def _normalize_mode(image):
    """Return *image* in a mode whose array layout ``enhance`` understands.

    Palette ("P"), bilevel ("1"), "LA", "CMYK", ... images would otherwise
    reach the upsampler as palette indices, booleans or 2-/4-channel arrays
    with the wrong meaning. Images carrying transparency become RGBA, so the
    alpha channel survives; everything else becomes RGB.
    """
    if image.mode in _PASSTHROUGH_MODES:
        return image
    if image.mode in ("LA", "PA") or "transparency" in image.info:
        return image.convert("RGBA")
    return image.convert("RGB")


def _swap_red_blue(array):
    """Swap channels 0 and 2 (RGB(A) <-> BGR(A)); 2-D arrays pass through."""
    if array.ndim != 3 or array.shape[2] < 3:
        return array
    swapped = array.copy()
    swapped[..., :3] = array[..., 2::-1]
    return swapped


def _upscale(upsampler, image):
    """Upscale a PIL image 4x and return an RGB(A)/grayscale numpy array."""
    # RealESRGANer.enhance follows OpenCV's BGR(A) convention: it converts its
    # input BGR->RGB before the network and the result back to BGR. PIL arrays
    # are RGB(A), so swap on the way in and on the way out.
    output, _ = upsampler.enhance(_swap_red_blue(np.array(image)), 4)
    return _swap_red_blue(output)


def inference(image_path: str, mode: str) -> str:
    """Upscale the image (or animated GIF) at *image_path* 4x.

    Args:
        image_path: path of the uploaded image.
        mode: "anime" selects the anime model; any other value (including
            None) selects the general-purpose base model.

    Returns:
        Path of a temporary ``.gif``/``.png`` file holding the result. The
        file is created with ``delete=False``, so it outlives this call.

    When ``UNLIMITED`` is False, inputs wider than ``BASE_WIDTH`` are first
    downscaled to that width, keeping their aspect ratio.
    """
    upsampler = upsampler_anime if mode == "anime" else upsampler_base
    with Image.open(image_path) as image:
        # Decide on the freshly opened image: resize() returns a new image
        # whose ``format`` is None, so checking afterwards misses GIFs.
        is_gif = image.format == "GIF"

        target_size = None
        if not UNLIMITED and image.size[0] > BASE_WIDTH:
            logger.info("Image is too large, resizing to 256px")
            wpercent = BASE_WIDTH / float(image.size[0])
            # max(1, ...): very wide, very short images must not collapse to 0 px.
            hsize = max(1, int(float(image.size[1]) * float(wpercent)))
            target_size = (BASE_WIDTH, hsize)

        if is_gif:
            # Not every GIF carries timing info; fall back to 100 ms per frame.
            duration = image.info.get("duration", 100)
            outputs = []
            for frame in ImageSequence.Iterator(image):
                # Convert every frame to RGB so palette indices are not treated
                # as grayscale and all frames share one channel layout.
                rgb_frame = frame.convert("RGB")
                if target_size is not None:
                    rgb_frame = rgb_frame.resize(target_size)
                outputs.append(_upscale(upsampler, rgb_frame))
        else:
            still = _normalize_mode(image)
            if target_size is not None:
                still = still.resize(target_size)
            output = _upscale(upsampler, still)

    # The temporary file is written only after upscaling succeeded, and its
    # handle is closed by the with-block; delete=False keeps the file itself.
    suffix = ".gif" if is_gif else ".png"
    with NamedTemporaryFile(delete=False, suffix=suffix) as fd:
        logger.debug("TempFile path: " + fd.name)
        if is_gif:
            imageio.mimsave(
                fd,
                outputs,  # every frame, including the first
                format="gif",
                duration=duration / 1000,
            )
        else:
            Image.fromarray(output).save(fd, format="PNG")  # format: PNG / JPEG
    return fd.name
title = "Real-ESRGAN"
description = "Gradio demo for Real-ESRGAN. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click submit only once"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.10833'>Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data</a> | <a href='https://github.com/xinntao/Real-ESRGAN'>Github Repo</a></p>"
# Module-level handle (conventionally named `demo`) so the Interface can be
# reused or inspected, not just launched.
demo = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        # Show the base model as selected by default; inference() already maps
        # any value other than "anime" (including None) to the base model.
        gr.Radio(["anime", "base"], type="value", value="base", label="model type"),
    ],
    # inference() returns the path of the file it wrote, so declare the output
    # as "filepath"; "file" is not a valid/current gr.Image type.
    gr.Image(type="filepath", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["bear.jpg", "base"], ["anime.png", "anime"]],
)
demo.launch()