| | from typing import Dict, List, Any |
| | import torch |
| | from diffusers import DPMSolverMultistepScheduler, DiffusionPipeline |
| | from PIL import Image |
| | import base64 |
| | from io import BytesIO |
| |
|
| |
|
| | |
# The handler below loads fp16 weights and compiles the UNets, so it is only
# usable on CUDA. Fail at import time when no GPU is visible to torch.
if not torch.cuda.is_available():
    raise ValueError("need to run on GPU")

# Target device shared by every pipeline created in this module.
device = torch.device("cuda")
| |
|
| |
|
class EndpointHandler:
    """Text-to-image handler built from a base pipeline and an optional refiner.

    The base pipeline is loaded from ``path``; the refiner is always
    ``stabilityai/stable-diffusion-xl-refiner-1.0`` and re-uses the base
    pipeline's second text encoder and VAE. Both pipelines use
    ``DPMSolverMultistepScheduler`` and have their UNet wrapped by
    ``torch.compile``.
    """

    def __init__(self, path=""):
        """Load both pipelines in fp16 and move them to the module-level ``device``.

        :param path: Location passed to ``DiffusionPipeline.from_pretrained``
            for the base model. NOTE(review): the base pipeline must expose
            ``text_encoder_2`` and ``vae`` (presumably an SDXL base) - confirm.
        """
        base = DiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        )
        self.base = self._prepare(base)

        # Share the already-loaded encoder/VAE modules with the refiner rather
        # than loading a second copy of them.
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.base.text_encoder_2,
            vae=self.base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner = self._prepare(refiner)

    @staticmethod
    def _prepare(pipe):
        """Swap in the DPM-Solver scheduler, move ``pipe`` to ``device``, compile its UNet."""
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(device)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        return pipe

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Recognised keys of ``data`` (each one is popped from the mapping):

        - ``inputs``: the prompt (required).
        - ``use_refiner``: truthy to run base + refiner (default ``False``).
        - ``num_inference_steps``: default ``30``.
        - ``guidance_scale``: default ``8``.
        - ``negative_prompt``: default ``None``.
        - ``high_noise_frac``: value used as the base pass ``denoising_end``
          and the refiner pass ``denoising_start`` (default ``0.8``).
        - ``height`` / ``width``: default ``None``.

        :param data: Request mapping as described above.
        :return: ``{"image": <base64 JPEG string>}`` on success, or
            ``{"error": "Please provide a prompt"}`` when ``inputs`` is
            missing or ``None``.
        """
        prompt = data.pop("inputs", None)
        if prompt is None:
            return {"error": "Please provide a prompt"}

        use_refiner = bool(data.pop("use_refiner", False))
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 8)
        negative_prompt = data.pop("negative_prompt", None)
        high_noise_frac = data.pop("high_noise_frac", 0.8)
        height = data.pop("height", None)
        width = data.pop("width", None)

        # Settings that apply to every pipeline call, so the refiner branch
        # honours the same request parameters as the base-only branch.
        shared = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "negative_prompt": negative_prompt,
        }

        if use_refiner:
            # The base pass stops at `high_noise_frac` and hands its latents
            # to the refiner, which resumes from the same fraction.
            latents = self.base(
                **shared,
                num_images_per_prompt=1,
                height=height,
                width=width,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            out = self.refiner(
                **shared,
                denoising_start=high_noise_frac,
                image=latents,
            )
        else:
            out = self.base(
                **shared,
                num_images_per_prompt=1,
                height=height,
                width=width,
            )

        # Serialise the first generated image as JPEG and base64-encode it so
        # the result can be returned as a plain string.
        buffered = BytesIO()
        out.images[0].save(buffered, format="JPEG")
        return {"image": base64.b64encode(buffered.getvalue()).decode()}
| |
|
| | |
| | |
| |
|