import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    StableDiffusionXLPipeline,
)
from PIL import Image


class EndpointHandler():
    """Text-to-image handler: SDXL base pipeline followed by the SDXL refiner.

    The base pipeline denoises up to ``HIGH_NOISE_FRACTION`` of the schedule
    and returns latents; the refiner picks up from the same fraction and
    produces the final image, which is returned as a base64-encoded JPEG.
    """

    # Sampling settings shared by every request.
    GUIDANCE_SCALE = 7.0
    IMAGE_SIZE = 1024
    NUM_INFERENCE_STEPS = 25
    # Point at which the base pipeline stops and the refiner starts.  The
    # same value MUST be used for ``denoising_end`` (base) and
    # ``denoising_start`` (refiner), hence a single constant.
    HIGH_NOISE_FRACTION = 0.8
    REFINER_STRENGTH = 0.3

    def __init__(self, path=""):
        """Load the VAE, the base pipeline and the refiner pipeline.

        Args:
            path: Unused by this handler; kept in the signature for callers
                that pass it.
        """
        self.model_base = "AIhgenerator/nsfwxxl2"
        self.v_autoencoder = "madebyollin/sdxl-vae-fp16-fix"
        self.model_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"

        # Load the VAE model
        self.vae = AutoencoderKL.from_pretrained(
            self.v_autoencoder,
            torch_dtype=torch.float16,
        )

        # Load the main pipeline
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.model_base,
            torch_dtype=torch.float16,
            vae=self.vae,
            add_watermarker=False,
        )
        # NOTE(review): SDXL pipelines may not consult this attribute at
        # all; the assignment is kept as in the original - confirm.
        self.pipe.safety_checker = None
        self.pipe.to("cuda")

        # Load the refiner pipeline
        self.pipe_refiner = DiffusionPipeline.from_pretrained(
            self.model_refiner,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.pipe_refiner.enable_model_cpu_offload()

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request payload.  Either a flat mapping with the keys
                below, or a mapping whose ``"inputs"`` entry is such a
                mapping.  Keys:

                * ``prompt`` (required)
                * ``prompt2``, ``negative_prompt``, ``negative_prompt2``
                  (optional; passed to the pipelines as ``None`` when
                  absent)

        Returns:
            ``{"image": <base64 string of the JPEG bytes>}``.

        Raises:
            KeyError: If ``prompt`` is missing from the payload.
        """
        print("data", data)

        # Accept both {"prompt": ...} and {"inputs": {"prompt": ...}}.
        payload = data
        nested = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(nested, dict):
            payload = nested

        prompt = payload["prompt"]
        # Optional fields: a missing key becomes None instead of KeyError.
        prompt2 = payload.get("prompt2")
        negative_prompt = payload.get("negative_prompt")
        negative_prompt2 = payload.get("negative_prompt2")
        print(prompt, prompt2, negative_prompt, negative_prompt2)

        # Base pass: stop early and hand latents (not a decoded image) over.
        image_base_latent = self.pipe(
            prompt=prompt,
            prompt_2=prompt2,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt2,
            guidance_scale=self.GUIDANCE_SCALE,
            height=self.IMAGE_SIZE,
            width=self.IMAGE_SIZE,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            output_type="latent",
            denoising_end=self.HIGH_NOISE_FRACTION,
        ).images[0]
        print("image base latent")

        # Refine the image
        # NOTE(review): the diffusers docs suggest ``strength`` is ignored
        # once ``denoising_start`` is given - confirm before tuning it.
        image_refiner = self.pipe_refiner(
            prompt=prompt,
            prompt_2=prompt2,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt2,
            image=image_base_latent,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            strength=self.REFINER_STRENGTH,
            denoising_start=self.HIGH_NOISE_FRACTION,
        ).images[0]
        print("image refiner")

        # Serialize to JPEG in memory and base64-encode for the response.
        with BytesIO() as buffer:
            image_refiner.save(buffer, format="JPEG")
            base64_encoded_result = base64.b64encode(
                buffer.getvalue()
            ).decode('utf-8')

        return {"image": base64_encoded_result}