# handler.py — custom inference handler for Gjm1234/juggernaut-sfw (commit e134161)
import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
class EndpointHandler:
    """Serve Juggernaut XL (SDXL) behind a Hugging Face inference endpoint.

    Expected request body::

        {"inputs": {"prompt": "...", "num_images": 4, "image": "<base64>"}}

    When ``image`` is supplied the request runs through an img2img pipeline;
    otherwise plain text-to-image generation is used.  The response is
    ``{"images": [<base64 PNG>, ...]}`` on success, or ``{"error": msg}``.
    """

    def __init__(self, model_dir: str, **kwargs):
        print("🔥 Initializing Juggernaut XL Handler (Prompt + Optional Image)...")
        # Load the SDXL text-to-image pipeline in fp16 on the GPU.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
        self.pipe.enable_attention_slicing()
        # BUG FIX: StableDiffusionXLPipeline is text-to-image only — its
        # __call__ does not accept `image`/`strength`, so the img2img branch
        # below used to raise a TypeError on every request that sent an image.
        # Build a proper img2img pipeline that reuses the already-loaded
        # components (no extra VRAM for the weights).
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(**self.pipe.components)
        print("✅ Pipeline loaded successfully.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request; return {"images": [...]} or {"error": msg}."""
        # The endpoint contract requires an `inputs` object in the body.
        if "inputs" not in data:
            return {"error": "Body must contain 'inputs' object"}
        inputs = data["inputs"]

        prompt = inputs.get("prompt", None)
        if not prompt:
            return {"error": "prompt is required"}

        num_images = inputs.get("num_images", 10)
        image_b64 = inputs.get("image", None)

        # Decode the optional init image (base64 -> RGB PIL image).
        init_image = None
        if image_b64:
            try:
                img_bytes = base64.b64decode(image_b64)
                init_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Invalid image data: {str(e)}"}

        # Route to txt2img or img2img depending on whether an image was sent.
        if init_image is None:
            print("🎨 Running TEXT → IMAGE")
            output = self.pipe(
                prompt=prompt,
                num_images_per_prompt=num_images,
            )
        else:
            print("🎨 Running IMAGE → IMAGE")
            output = self.img2img_pipe(
                prompt=prompt,
                image=init_image,
                strength=0.6,
                num_images_per_prompt=num_images,
            )

        # Encode each generated PIL image as a base64 PNG string.
        results = []
        for img in output.images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            results.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

        print(f"✅ Returning {len(results)} images.")
        return {"images": results}