File size: 3,490 Bytes
d6a6b07
784930d
d6a6b07
 
0a85d0e
d6a6b07
784930d
d6a6b07
784930d
 
f0ebdad
784930d
 
 
 
 
 
 
 
 
 
 
 
d6a6b07
 
 
 
 
 
 
 
 
 
784930d
d6a6b07
 
 
 
f0ebdad
784930d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0ebdad
d6a6b07
f0ebdad
784930d
f0ebdad
 
784930d
 
 
d6a6b07
0a85d0e
784930d
 
 
 
 
 
 
d6a6b07
 
 
 
 
 
 
 
 
 
 
 
784930d
 
d6a6b07
 
 
f0ebdad
 
 
 
 
 
 
d6a6b07
 
 
784930d
d6a6b07
 
f0ebdad
 
 
d6a6b07
8404571
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import imageio
import gradio as gr
import numpy as np
from PIL import Image, ImageSequence
from tempfile import NamedTemporaryFile
from pathlib import Path
from loguru import logger
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer, torch


# Runtime configuration, starting from the CPU defaults.
# DEFUALT_TILE [sic]: tile size passed to the RealESRGANer upsamplers.
# UNLIMITED: when False, inference() shrinks inputs wider than BASE_WIDTH.
# BASE_WIDTH: width in pixels that oversized inputs are shrunk to.
DEFUALT_TILE = 128
UNLIMITED = False
BASE_WIDTH = 256

if not torch.cuda.is_available():
    logger.info("CUDA not found, Using Device: CPU")
else:
    # With a GPU available: use a larger tile and lift the input width limit.
    logger.info("CUDA available, Using Device: GPU")
    DEFUALT_TILE = 256
    UNLIMITED = True

# Upsampler selected by inference() for the "anime" mode: a 6-block RRDBNet
# whose weights are loaded from the directory containing this file.
upsampler_anime = RealESRGANer(
    model=RRDBNet(
        scale=4,
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=6,
        num_grow_ch=32,
    ),
    model_path=str(Path(__file__).parent / "RealESRGAN_x4plus_anime_6B.pth"),
    scale=4,
    half=False,
    pre_pad=0,
    tile_pad=10,
    tile=DEFUALT_TILE,
)

# Upsampler selected by inference() for every mode other than "anime": a
# 23-block RRDBNet whose weights are loaded from the directory containing
# this file.
upsampler_base = RealESRGANer(
    model=RRDBNet(
        scale=4,
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
    ),
    model_path=str(Path(__file__).parent / "RealESRGAN_x4plus.pth"),
    scale=4,
    half=False,
    pre_pad=0,
    tile_pad=10,
    tile=DEFUALT_TILE,
)

# The model weights are loaded from the directory containing this file. If they
# are missing, they can be fetched once (run from that directory) with:
# torch.hub.download_url_to_file('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 'RealESRGAN_x4plus.pth')
# torch.hub.download_url_to_file('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth', 'RealESRGAN_x4plus_anime_6B.pth')


def inference(image_path: str, mode: str) -> str:
    """Upscale the image at ``image_path`` and return the path of the result.

    The image is run through ``upsampler_anime`` when ``mode`` is ``"anime"``
    and through ``upsampler_base`` otherwise, requesting a 4x output scale.
    When ``UNLIMITED`` is false, inputs wider than ``BASE_WIDTH`` are first
    shrunk to that width, keeping the aspect ratio.

    GIF inputs are processed frame by frame and written back as a GIF; any
    other input is written as a PNG.

    Args:
        image_path: Path of the image file to upscale.
        mode: ``"anime"`` to use the anime model; any other value uses the
            base model.

    Returns:
        Path of a temporary file holding the upscaled image. The file is
        created with ``delete=False``, so the caller owns its cleanup.
    """
    upsampler = upsampler_anime if mode == "anime" else upsampler_base

    with Image.open(image_path) as image:
        # Read the format before any resize: Image.resize() returns a new
        # image whose ``format`` is None, which would make a downsized GIF
        # look like a still image.
        is_gif = image.format == "GIF"

        target_size = None
        width, height = image.size
        if not UNLIMITED and width > BASE_WIDTH:
            logger.info("Image is too large, resizing to 256px")
            ratio = BASE_WIDTH / float(width)
            # Clamp to 1 so extremely wide images never get a zero height.
            target_size = (BASE_WIDTH, max(1, int(float(height) * ratio)))

        # The context manager flushes and closes the temp file before its
        # path is returned, so readers see the complete output.
        with NamedTemporaryFile(delete=False) as fd:
            logger.debug("TempFile path: " + fd.name)
            if is_gif:
                outputs = []
                for frame in ImageSequence.Iterator(image):
                    if target_size is not None:
                        frame = frame.resize(target_size)
                    output, _ = upsampler.enhance(np.array(frame), 4)
                    outputs.append(output)
                # NOTE(review): the original code skips the first frame
                # (possibly a workaround for its palette mode) - confirm
                # the intent. The skip is kept for multi-frame GIFs; a
                # single-frame GIF falls back to its only frame so that
                # the frame list is never empty.
                frames = outputs[1:] or outputs
                imageio.mimsave(
                    fd,
                    frames,
                    format="gif",
                    # Fall back to 100 ms when the GIF has no duration info.
                    duration=image.info.get("duration", 100) / 1000,
                )
            else:
                still = image if target_size is None else image.resize(target_size)
                output, _ = upsampler.enhance(np.array(still), 4)
                Image.fromarray(output).save(fd, format="PNG")  # format: PNG / JPEG
    return fd.name


# Static text shown on the demo page.
title = "Real-ESRGAN"
description = "Gradio demo for Real-ESRGAN. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click submit only once"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.10833'>Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data</a> | <a href='https://github.com/xinntao/Real-ESRGAN'>Github Repo</a></p>"

# Build and start the web UI: an image upload plus a model selector feed
# inference(), which returns the path of the upscaled image.
gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        gr.Radio(["anime", "base"], type="value", label="model type"),
    ],
    # inference() returns a file path, so the output uses the same "filepath"
    # type as the input component rather than "file".
    gr.Image(type="filepath", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["bear.jpg", "base"], ["anime.png", "anime"]],
).launch()