# File: hentai/handler.py (3.09 kB)
# Last commit: b74eb89 "Update handler.py" by AIhgenerator (verified)
import base64
import logging
from io import BytesIO
from typing import Any, Dict, List

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, StableDiffusionXLPipeline
from PIL import Image
class EndpointHandler:
    """Hugging Face Inference Endpoints handler: SDXL base + refiner text-to-image.

    The base pipeline runs the first ``high_noise_frac`` of the denoising
    schedule and returns latents. The refiner pipeline runs the remainder of
    the same schedule on those latents and produces the final image, which is
    returned as a base64-encoded JPEG.
    """

    # Generation defaults; each can be overridden per request (see __call__).
    DEFAULT_GUIDANCE_SCALE = 7.0
    DEFAULT_HEIGHT = 1024
    DEFAULT_WIDTH = 1024
    DEFAULT_NUM_INFERENCE_STEPS = 25
    # Fraction of the schedule, strictly inside (0, 1), at which the base
    # pipeline stops (``denoising_end``) and the refiner resumes
    # (``denoising_start``).
    DEFAULT_HIGH_NOISE_FRAC = 0.8

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Load the VAE and both diffusion pipelines.

        Args:
            path: Model directory passed in by the endpoint runtime. Not used;
                the model ids below are fixed.
        """
        self.model_base = "AIhgenerator/nsfwxxl2"
        self.v_autoencoder = "madebyollin/sdxl-vae-fp16-fix"
        self.model_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"

        # Replacement SDXL VAE intended for float16 inference, given to the
        # base pipeline.
        # NOTE(review): __call__ only asks the base pipeline for latents, so
        # this VAE never decodes anything; the final image is decoded by the
        # refiner's own VAE. If the fp16-fix VAE is meant to be used, pass
        # vae=self.vae to the refiner, but first confirm that sharing it works
        # with the refiner's CPU offload.
        self.vae = AutoencoderKL.from_pretrained(self.v_autoencoder, torch_dtype=torch.float16)

        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.model_base,
            torch_dtype=torch.float16,
            vae=self.vae,
            add_watermarker=False,
        )
        # NOTE(review): SDXL pipelines do not register a safety_checker
        # component, so this presumably only sets a plain attribute.
        self.pipe.safety_checker = None
        self.pipe.to("cuda")

        # The refiner is not moved to the GPU wholesale; model CPU offload
        # moves each of its sub-models onto the GPU only while it runs.
        self.pipe_refiner = DiffusionPipeline.from_pretrained(
            self.model_refiner,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.pipe_refiner.enable_model_cpu_offload()

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request body, in either of two shapes:

                * flat (the original contract):
                  ``{"prompt": ..., "prompt2": ..., "negative_prompt": ...,
                  "negative_prompt2": ...}``
                * Inference Endpoints envelope:
                  ``{"inputs": <prompt string or flat dict>, "parameters": {...}}``

                Only ``prompt`` is required. Missing optional prompts are
                passed to the pipelines as ``None``. Optional overrides:
                ``guidance_scale``, ``height``, ``width``,
                ``num_inference_steps`` and ``high_noise_frac``.

        Returns:
            ``{"image": <base64-encoded JPEG>}``.

        Raises:
            TypeError: ``data`` is not a dict.
            KeyError: no ``prompt`` was supplied.
            ValueError: ``high_noise_frac`` is not strictly between 0 and 1.
        """
        request = self._parse_request(data)
        prompts = {
            "prompt": request["prompt"],
            "prompt_2": request.get("prompt2"),
            "negative_prompt": request.get("negative_prompt"),
            "negative_prompt_2": request.get("negative_prompt2"),
        }
        num_steps = int(request.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS))
        high_noise_frac = float(request.get("high_noise_frac", self.DEFAULT_HIGH_NOISE_FRAC))
        if not 0.0 < high_noise_frac < 1.0:
            # Outside (0, 1) diffusers ignores denoising_end / denoising_start,
            # so the two stages would no longer split a single schedule.
            raise ValueError(f"high_noise_frac must be in (0, 1), got {high_noise_frac}")
        self._logger.debug("generating with prompts: %r", prompts)

        # Stage 1: base model, stopped at high_noise_frac, returning latents.
        latents = self.pipe(
            **prompts,
            guidance_scale=float(request.get("guidance_scale", self.DEFAULT_GUIDANCE_SCALE)),
            height=int(request.get("height", self.DEFAULT_HEIGHT)),
            width=int(request.get("width", self.DEFAULT_WIDTH)),
            num_inference_steps=num_steps,
            output_type="latent",
            denoising_end=high_noise_frac,
        ).images[0]
        self._logger.debug("base latents ready")

        # Stage 2: the refiner resumes at the same point of the same schedule.
        # `strength` is not passed: diffusers ignores it when denoising_start
        # is set.
        image = self.pipe_refiner(
            **prompts,
            image=latents,
            num_inference_steps=num_steps,
            denoising_start=high_noise_frac,
        ).images[0]
        self._logger.debug("refined image ready")

        return {"image": self._encode_jpeg(image)}

    @staticmethod
    def _parse_request(data: Any) -> Dict[str, Any]:
        """Return the request as one flat dict, unwrapping an ``inputs`` envelope."""
        if not isinstance(data, dict):
            raise TypeError(f"request body must be a JSON object, got {type(data).__name__}")
        if "prompt" in data or "inputs" not in data:
            # Flat form: keys at the top level. Copy so the caller's dict is
            # never mutated.
            request = dict(data)
        else:
            inputs = data["inputs"]
            # A bare string (or list) under ``inputs`` is the prompt itself.
            request = dict(inputs) if isinstance(inputs, dict) else {"prompt": inputs}
            request.update(data.get("parameters") or {})
        if "prompt" not in request:
            raise KeyError("prompt")
        return request

    @staticmethod
    def _encode_jpeg(image: Any) -> str:
        """Serialise an image object with a PIL-style ``save`` to base64 JPEG."""
        with BytesIO() as buffer:
            image.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")