File size: 1,562 Bytes
00bb2a2 bbc474e 00bb2a2 bbc474e 00bb2a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | from typing import Dict
import torch
from diffusers import StableDiffusionXLPipeline
from io import BytesIO
import base64
class EndpointHandler:
    """Text-to-image handler: SDXL base 1.0 with the ``Texttra/Bh0r`` LoRA applied.

    The pipeline is built once in ``__init__``; each call turns one prompt into
    one base64-encoded PNG.
    """

    # Hub ids of the base model and of the LoRA loaded on top of it.
    BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
    LORA_REPO_ID = "Texttra/Bh0r"
    LORA_WEIGHT_NAME = "Bh0r-10.safetensors"
    LORA_ADAPTER_NAME = "bh0r_lora"
    # Adapter weight passed to set_adapters before fusing.
    LORA_WEIGHT = 0.9
    # Sampling defaults; a request may override them via data["parameters"].
    DEFAULT_STEPS = 35
    DEFAULT_GUIDANCE = 7.0

    def __init__(self, path: str = ""):
        """Load SDXL, apply and fuse the LoRA, then move the pipeline to a device.

        Args:
            path: Only used in the startup log message; the weights are fetched
                by Hub id (see the class constants), not from this path.
        """
        print(f"Initializing SDXL model from: {path}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only practical on GPU: many CPU kernels lack fp16
        # support (or are very slow), so fall back to fp32 on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        # Load the base SDXL model.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.BASE_MODEL_ID,
            torch_dtype=dtype,
            variant="fp16",
        )
        print(f"Loading LoRA weights from: {self.LORA_REPO_ID}")
        self.pipe.load_lora_weights(
            self.LORA_REPO_ID,
            weight_name=self.LORA_WEIGHT_NAME,
            adapter_name=self.LORA_ADAPTER_NAME,
        )
        self.pipe.set_adapters(
            [self.LORA_ADAPTER_NAME], adapter_weights=[self.LORA_WEIGHT]
        )
        # Merge the LoRA into the base weights (diffusers fuse_lora) before
        # moving the pipeline to its target device.
        self.pipe.fuse_lora()
        self.pipe.to(device)
        print("Model ready.")

    def __call__(self, data: Dict) -> Dict:
        """Generate one image for the request's prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is either the prompt
                string itself or a dict with a ``"prompt"`` key. An optional
                ``data["parameters"]`` dict may override
                ``num_inference_steps`` and ``guidance_scale``.

        Returns:
            ``{"image": <base64-encoded PNG>}`` on success, or
            ``{"error": <message>}`` when the prompt is missing/blank or the
            parameters are invalid.
        """
        print("Received data:", data)
        prompt = self._extract_prompt(data.get("inputs", {}))
        print("Extracted prompt:", prompt)
        if not prompt:
            return {"error": "No prompt provided."}

        params = self._generation_params(data.get("parameters"))
        if params is None:
            return {"error": "Invalid generation parameters."}

        image = self.pipe(prompt, **params).images[0]
        print("Image generated.")
        encoded = self._encode_png(image)
        print("Returning image.")
        return {"image": encoded}

    @staticmethod
    def _extract_prompt(inputs: object) -> str:
        """Return the prompt from ``inputs``, or "" if absent, blank, or not a string."""
        # Accept a bare prompt string as well as a dict carrying a "prompt" key;
        # anything else (None, lists, numbers) counts as "no prompt".
        prompt = inputs.get("prompt", "") if isinstance(inputs, dict) else inputs
        if not isinstance(prompt, str) or not prompt.strip():
            return ""
        return prompt

    @classmethod
    def _generation_params(cls, overrides: object) -> "Dict | None":
        """Merge per-request overrides onto the default sampling settings.

        Returns None when ``overrides`` is neither None nor a dict, or when it
        holds a step count that is not a positive integer or a guidance scale
        that is not numeric.
        """
        if overrides is None:
            overrides = {}
        if not isinstance(overrides, dict):
            return None
        try:
            steps = int(overrides.get("num_inference_steps", cls.DEFAULT_STEPS))
            guidance = float(overrides.get("guidance_scale", cls.DEFAULT_GUIDANCE))
        except (TypeError, ValueError):
            return None
        if steps < 1:
            return None
        return {"num_inference_steps": steps, "guidance_scale": guidance}

    @staticmethod
    def _encode_png(image) -> str:
        """Save ``image`` as PNG in memory and return it base64-encoded as text.

        ``image`` is whatever the pipeline returns in ``.images`` (presumably a
        PIL image); it only needs a ``save(buffer, format=...)`` method.
        """
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|