| | from typing import Dict |
| | import torch |
| | from diffusers import StableDiffusionXLPipeline |
| | from io import BytesIO |
| | import base64 |
| |
|
class EndpointHandler:
    """Inference Endpoints handler: SDXL base 1.0 with a fused Bh0r LoRA.

    On construction it downloads ``stabilityai/stable-diffusion-xl-base-1.0``,
    loads the ``Texttra/Bh0r`` LoRA at weight 0.9, fuses it into the base
    weights, and moves the pipeline to GPU when available. Each call
    generates one image and returns it as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline, fuse the LoRA, and place it on the device.

        Args:
            path: Model directory supplied by the endpoint runtime
                (unused here; the base model is pulled from the Hub).
        """
        print(f"Initializing SDXL model from: {path}")

        # fp16 is only usable on CUDA; running SDXL in half precision on
        # CPU fails at inference time, so pick the dtype from the device
        # we will actually run on instead of hard-coding float16.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32

        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            variant="fp16",
        )

        print("Loading LoRA weights from: Texttra/Bh0r")
        self.pipe.load_lora_weights(
            "Texttra/Bh0r",
            weight_name="Bh0r-10.safetensors",
            adapter_name="bh0r_lora",
        )
        self.pipe.set_adapters(["bh0r_lora"], adapter_weights=[0.9])
        # Fuse once at load time so the LoRA costs nothing per request.
        self.pipe.fuse_lora()

        self.pipe.to("cuda" if use_cuda else "cpu")
        print("Model ready.")

    def __call__(self, data: Dict) -> Dict:
        """Generate one image for an inference request.

        Accepts either ``{"inputs": {"prompt": ...}}`` or the bare-string
        form ``{"inputs": "<prompt>"}`` used by many HF clients. Optional
        keys inside the ``inputs`` dict: ``negative_prompt``,
        ``num_inference_steps`` (default 35), ``guidance_scale``
        (default 7.0), and ``seed`` for reproducible output.

        Returns:
            ``{"image": <base64 PNG>}`` on success, or
            ``{"error": "No prompt provided."}`` when the prompt is
            missing or empty.
        """
        print("Received data:", data)

        inputs = data.get("inputs", {})
        # Normalize the bare-string payload shape to the dict shape.
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        prompt = inputs.get("prompt", "")
        print("Extracted prompt:", prompt)

        if not prompt:
            return {"error": "No prompt provided."}

        # Optional per-request overrides; defaults match the previously
        # hard-coded values, so existing callers see identical behavior.
        steps = int(inputs.get("num_inference_steps", 35))
        guidance = float(inputs.get("guidance_scale", 7.0))
        negative_prompt = inputs.get("negative_prompt") or None

        generator = None
        seed = inputs.get("seed")
        if seed is not None:
            generator = torch.Generator(device=self.pipe.device).manual_seed(int(seed))

        # inference_mode: no autograd bookkeeping, lower peak memory.
        with torch.inference_mode():
            image = self.pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                generator=generator,
            ).images[0]

        print("Image generated.")

        buffer = BytesIO()
        image.save(buffer, format="PNG")
        base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print("Returning image.")

        return {"image": base64_image}
| |
|