|
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
|
from realesrgan import RealESRGANer |
|
|
from diffusers import StableDiffusionPipeline |
|
|
import base64 |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
import torch |
|
|
from torch.cuda.amp import autocast |
|
|
from typing import Dict, Any |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def _select_device() -> torch.device:
    """Return the default CUDA device when CUDA is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


# Module-wide compute device, resolved once at import time.
device = _select_device()
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that generates an image with Stable Diffusion,
    upscales it with a Real-ESRGAN RRDBNet (scale=4), and returns it as a
    base64-encoded PNG.
    """

    # Path of the Real-ESRGAN weights inside the deployed repository.
    ESRGAN_WEIGHTS = "/repository/RealESRGAN_x4plus_anime_6B.pth"

    def __init__(self, path=""):
        """Load the diffusion pipeline and the Real-ESRGAN upscaler.

        Args:
            path: Model directory or hub id passed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        # Stored on the instance so __call__ does not depend on a module global.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
        self.pipe = self.pipe.to(self.device)

        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; the weights are moved to self.device below.
        checkpoint = torch.load(self.ESRGAN_WEIGHTS, map_location="cpu")
        state_dict = self._extract_state_dict(checkpoint)

        # num_block=6 presumably corresponds to the "6B" in the weight file name.
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
        )
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # NOTE(review): __call__ runs self.model directly and never uses
        # self.upsampler; it is kept so existing attribute access keeps working.
        self.upsampler = RealESRGANer(
            scale=4, model=self.model, tile=0, model_path=self.ESRGAN_WEIGHTS
        )

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weights stored in a Real-ESRGAN checkpoint.

        Prefers the EMA weights under ``"params_ema"``, then ``"params"``, and
        otherwise treats the checkpoint itself as a bare state dict.
        """
        for key in ("params_ema", "params"):
            if key in checkpoint:
                return checkpoint[key]
        return checkpoint

    @staticmethod
    def _image_to_tensor(image, device):
        """Convert an HxWxC image with 0-255 values to a 1xCxHxW float32
        tensor with values in [0, 1], placed on *device*.
        """
        array = np.clip(np.asarray(image, dtype=np.float32), 0.0, 255.0) / 255.0
        return torch.from_numpy(array).float().permute(2, 0, 1).unsqueeze(0).to(device)

    @staticmethod
    def _tensor_to_uint8(tensor):
        """Convert a 1xCxHxW tensor with values in [0, 1] to an HxWxC uint8 array.

        Values are clipped to [0, 1] and rounded (not truncated) to 0-255.
        Only the batch dimension is squeezed, so 1-pixel-high or 1-pixel-wide
        images keep their shape.
        """
        array = tensor.detach().squeeze(0).permute(1, 2, 0).float().cpu().numpy()
        return (np.clip(array, 0.0, 1.0) * 255.0).round().astype(np.uint8)

    def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
        """Generate an image for ``data["inputs"]`` and upscale it with Real-ESRGAN.

        Args:
            data: Request payload. ``"inputs"`` is the prompt (required);
                ``"negative_prompt"`` is optional.
            output_size: Unused; kept for backward compatibility with callers.

        Returns:
            ``{"image": <base64-encoded PNG of the upscaled image>}``.

        Raises:
            ValueError: if ``data`` has no ``"inputs"`` entry.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError('Request payload must contain an "inputs" prompt.')
        negative_prompt = data.get("negative_prompt", None)

        # Mixed precision only on CUDA; on CPU the pipeline runs in float32.
        with torch.autocast(device_type=self.device.type, enabled=self.device.type == "cuda"):
            output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
        # Assumes the pipeline's default PIL/uint8 image output — confirm if
        # output_type is ever changed.
        image = output['images'][0]

        tensor_image = self._image_to_tensor(image, self.device)
        with torch.no_grad():
            esrgan_output = self.model(tensor_image)
        esrgan_image = Image.fromarray(self._tensor_to_uint8(esrgan_output))

        buffered = BytesIO()
        esrgan_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode("ascii")}