# model-sd-multi/pipelines/upscaler.py
# Uploaded by jayparmr in commit 4adca93 ("Upload 18 files").
import os
from pathlib import Path
from typing import Union
import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image
from realesrgan import RealESRGANer
from util.commons import read_url
class Upscaler:
    """4x image upscaler backed by Real-ESRGAN.

    ``load()`` makes sure the general-purpose and anime model weights are
    present under ``~/.cache/realesrgan`` (downloading them if missing).
    The ``upscale*`` methods call it lazily if it has not been called yet.
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"

    # Local file paths of the downloaded weights; filled in by load().
    # Class-level None defaults let the upscale methods detect a missing load().
    __model_path: Union[str, None] = None
    __model_path_anime: Union[str, None] = None

    def load(self):
        """Ensure both weight files exist locally, downloading any that are missing."""
        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)
        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )

    def upscale(self, image: Union[str, bytes]) -> bytes:
        """Upscale an image 4x with the general-purpose RealESRGAN_x4plus model.

        Args:
            image: Encoded image bytes, or a str which is fetched with ``read_url``.

        Returns:
            The upscaled image encoded as PNG bytes.

        Raises:
            ValueError: If the input cannot be decoded as an image.
            RuntimeError: If the result cannot be encoded as PNG.
        """
        if self.__model_path is None:
            self.load()
        # RealESRGAN_x4plus is a 23-block RRDB network.
        model = self.__build_model(num_block=23)
        return self.__internal_upscale(image, self.__model_path, model)

    def upscale_anime(self, image: Union[str, bytes]) -> bytes:
        """Upscale an image 4x with the anime-tuned RealESRGAN_x4plus_anime_6B model.

        Args:
            image: Encoded image bytes, or a str which is fetched with ``read_url``.

        Returns:
            The upscaled image encoded as PNG bytes.

        Raises:
            ValueError: If the input cannot be decoded as an image.
            RuntimeError: If the result cannot be encoded as PNG.
        """
        if self.__model_path_anime is None:
            self.load()
        # The "_6B" checkpoint has only 6 RRDB blocks; building 23 blocks would
        # not match its state dict and the weight load would fail.
        model = self.__build_model(num_block=6)
        return self.__internal_upscale(image, self.__model_path_anime, model)

    @staticmethod
    def __build_model(num_block: int) -> RRDBNet:
        """Build an RGB-in/RGB-out x4 RRDBNet with ``num_block`` RRDB blocks."""
        return RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=4,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of the file at ``url``, downloading it if absent.

        The cached file name is the last path segment of ``url``.
        """
        cached_file = download_dir / url.rsplit("/", 1)[-1]
        if cached_file.exists():
            return str(cached_file)
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    def __internal_upscale(
        self,
        image: Union[str, bytes],
        model_path: str,
        rrbdnet: RRDBNet,
    ) -> bytes:
        """Decode ``image``, run it through Real-ESRGAN at 4x and return PNG bytes.

        Args:
            image: Encoded image bytes, or a str which is fetched with ``read_url``.
            model_path: Local path of the weights matching ``rrbdnet``.
            rrbdnet: Network architecture the weights are loaded into.

        Raises:
            ValueError: If the input cannot be decoded as an image.
            RuntimeError: If the result cannot be encoded as PNG.
        """
        if isinstance(image, str):
            # NOTE(review): assumes read_url returns the raw bytes of the resource.
            image = read_url(image)

        # Decode first so bad input fails fast, before weights are loaded.
        image_array = np.frombuffer(image, dtype=np.uint8)
        input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        if input_image is None:
            # cv2.imdecode signals failure by returning None rather than raising.
            raise ValueError("Could not decode the input data as an image")

        # half=True runs the model in fp16; gpu_id="0" requests the first GPU.
        upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=rrbdnet, half=True, gpu_id="0"
        )
        output, _ = upsampler.enhance(input_image, outscale=4)

        ok, encoded = cv2.imencode(".png", output)
        if not ok:
            raise RuntimeError("Failed to encode the upscaled image as PNG")
        return encoded.tobytes()