# vtuber-maker / handler.py
# Author: Alexmikupro — "Update handler.py" (commit f041201)
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from diffusers import StableDiffusionPipeline
import base64
from PIL import Image
from io import BytesIO
import torch
from torch.cuda.amp import autocast
from typing import Dict, Any
import numpy as np
# Pick the compute device once at import time: CUDA when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class EndpointHandler():
    """Text-to-image inference endpoint.

    Generates an image with Stable Diffusion, upscales it 4x with a
    Real-ESRGAN anime model, and returns the result as a base64 PNG.
    """

    def __init__(self, path=""):
        """Load the diffusion pipeline and the ESRGAN upscaler.

        path: model directory passed to StableDiffusionPipeline.from_pretrained.
        """
        # Load the StableDiffusionPipeline onto the module-level device
        # (no need to re-detect the device here).
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
        self.pipe = self.pipe.to(device)
        # Load the ESRGAN checkpoint. map_location keeps this working on
        # CPU-only hosts even if the checkpoint was saved from a GPU.
        checkpoint = torch.load("/repository/RealESRGAN_x4plus_anime_6B.pth", map_location=device)
        # Prefer the EMA weights when the checkpoint wraps them under 'params_ema'.
        if "params_ema" in checkpoint:
            state_dict = checkpoint["params_ema"]
        else:
            state_dict = checkpoint
        # RealESRGAN_x4plus_anime_6B architecture: 6 RRDB blocks, 4x scale.
        self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_grow_ch=32, num_block=6, scale=4)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        # Kept for API compatibility; __call__ currently runs self.model directly
        # rather than going through the RealESRGANer tiling helper.
        self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0, model_path="/repository/RealESRGAN_x4plus_anime_6B.pth")

    def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
        """Run generation + upscaling.

        data: {"inputs": <text prompt>, "negative_prompt": <optional text>}.
        output_size: unused; retained for backward compatibility with callers.
        Returns {"image": <base64-encoded PNG string>}.
        Raises ValueError when the prompt is missing.
        """
        inputs = data.get("inputs")
        if inputs is None:
            # Fail fast with a clear message instead of an opaque pipeline error.
            raise ValueError("Missing 'inputs' (text prompt) in request data.")
        negative_prompt = data.get("negative_prompt", None)
        # Run StableDiffusionPipeline under autocast (mixed precision on CUDA).
        with autocast():
            output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
        image = output['images'][0]  # PIL image, pixel values in [0, 255]
        # Convert once to a float array normalized to [0, 1] for the ESRGAN model.
        image = np.clip(np.asarray(image, dtype=np.float32), 0, 255) / 255.0
        # HWC -> NCHW tensor on the model's device.
        tensor_image = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(device)
        # Upscale with ESRGAN; no gradients needed for inference.
        with torch.no_grad():
            esrgan_output = self.model(tensor_image)
        # NCHW -> HWC numpy, clamp to the valid [0, 1] range, then to 8-bit PIL.
        esrgan_output = esrgan_output.squeeze().permute(1, 2, 0).cpu().numpy()
        esrgan_output = np.clip(esrgan_output, 0, 1)
        esrgan_image = Image.fromarray((esrgan_output * 255).astype('uint8'))
        # Encode the upscaled image as a base64 PNG string.
        buffered = BytesIO()
        esrgan_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}