"""Inference endpoint: Stable Diffusion text-to-image followed by 4x Real-ESRGAN upscaling."""

import base64
from io import BytesIO
from typing import Any, Dict

import numpy as np
import torch
from torch.cuda.amp import autocast

from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import StableDiffusionPipeline
from PIL import Image
from realesrgan import RealESRGANer

# Single module-level device, used by both model loading and inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ESRGAN checkpoint path (was duplicated as two identical literals).
_ESRGAN_WEIGHTS = "/repository/RealESRGAN_x4plus_anime_6B.pth"


class EndpointHandler():
    def __init__(self, path: str = ""):
        """Load the Stable Diffusion pipeline and the Real-ESRGAN upscaler.

        Args:
            path: Local directory (or hub id) of the Stable Diffusion model
                to pass to ``StableDiffusionPipeline.from_pretrained``.
        """
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
        self.pipe = self.pipe.to(device)

        # FIX: map_location so a checkpoint saved on GPU still loads on a
        # CPU-only host (the original call would raise RuntimeError there).
        checkpoint = torch.load(_ESRGAN_WEIGHTS, map_location=device)

        # Prefer the EMA weights when the checkpoint wraps them.
        if "params_ema" in checkpoint:
            state_dict = checkpoint["params_ema"]
        else:
            state_dict = checkpoint

        # RealESRGAN_x4plus_anime_6B architecture: 6 RRDB blocks, 4x scale.
        self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                             num_block=6, num_grow_ch=32, scale=4)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

        # NOTE(review): __call__ runs self.model directly and never uses this
        # upsampler; it is kept so the public attribute (and potential tiled
        # inference via RealESRGANer.enhance) remains available to callers.
        self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0,
                                      model_path=_ESRGAN_WEIGHTS)

    def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
        """Generate an image from a prompt and return it 4x-upscaled as base64 PNG.

        Args:
            data: Request payload; reads ``data["inputs"]`` (the prompt) and
                optionally ``data["negative_prompt"]``.
            output_size: Unused; retained for backward compatibility with
                existing callers.

        Returns:
            ``{"image": <base64-encoded PNG string>}``.
        """
        inputs = data.get("inputs")
        negative_prompt = data.get("negative_prompt", None)

        # Text-to-image. autocast only takes effect on CUDA; on CPU it is a no-op.
        with autocast():
            output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
        image = output['images'][0]  # PIL.Image (uint8 RGB) — TODO confirm against pipeline version

        # PIL -> float array scaled to [0, 1] for the ESRGAN model.
        image = np.clip(image, 0, 255) / 255.0

        # HWC float array -> NCHW float tensor on the model's device.
        tensor_image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(device)

        # Upscale without tracking gradients.
        with torch.no_grad():
            esrgan_output = self.model(tensor_image)

        # NCHW tensor -> HWC array, clamp to the valid range, back to uint8 PIL.
        esrgan_output = esrgan_output.squeeze().permute(1, 2, 0).cpu().numpy()
        esrgan_output = np.clip(esrgan_output, 0, 1)
        esrgan_image = Image.fromarray((esrgan_output * 255).astype('uint8'))

        # Serialize as PNG and base64-encode for the JSON response.
        buffered = BytesIO()
        esrgan_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}