import base64
import io
from typing import Any, Dict

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from PIL import Image


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Juggernaut XL (SDXL).

    Accepts a JSON body of the form::

        {"inputs": {"prompt": str, "num_images": int, "image": <base64 str>}}

    Runs text-to-image when no ``image`` is supplied, otherwise
    image-to-image, and returns ``{"images": [<base64 PNG>, ...]}``.
    Errors are reported as ``{"error": "..."}`` dicts rather than raised.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Load the SDXL weights onto the GPU and prepare both pipelines.

        ``model_dir`` is supplied by the endpoint runtime but unused here:
        weights are pulled from the fixed hub repo below.
        """
        print("🔥 Initializing Juggernaut XL Handler (Prompt + Optional Image)...")
        # Load XL model from your big repo
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
        self.pipe.enable_attention_slicing()
        # BUG FIX: StableDiffusionXLPipeline (txt2img) does not accept the
        # `image`/`strength` kwargs, so the original img2img branch raised
        # TypeError on every request that included an image. Build a proper
        # img2img pipeline that shares the already-loaded components (no
        # extra VRAM, no second download).
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(**self.pipe.components)
        self.img2img_pipe.enable_attention_slicing()
        print("✅ Pipeline loaded successfully.")

    def _decode_image(self, image_b64: str) -> Image.Image:
        """Decode a base64 string into an RGB PIL image (raises on bad data)."""
        img_bytes = base64.b64decode(image_b64)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Returns either ``{"images": [b64, ...]}`` on success or
        ``{"error": msg}`` on any validation failure.
        """
        # Must receive `inputs`
        if "inputs" not in data:
            return {"error": "Body must contain 'inputs' object"}

        inputs = data["inputs"]
        prompt = inputs.get("prompt", None)
        if not prompt:
            return {"error": "prompt is required"}

        # Validate num_images up front instead of letting a bad value crash
        # deep inside the pipeline. Default of 10 preserved for callers.
        try:
            num_images = int(inputs.get("num_images", 10))
        except (TypeError, ValueError):
            return {"error": "num_images must be an integer"}
        if num_images < 1:
            return {"error": "num_images must be >= 1"}

        init_image = None
        image_b64 = inputs.get("image", None)
        if image_b64:
            try:
                init_image = self._decode_image(image_b64)
            except Exception as e:
                return {"error": f"Invalid image data: {str(e)}"}

        # Run txt2img OR img2img depending on whether image was sent
        if init_image is None:
            print("🎨 Running TEXT → IMAGE")
            output = self.pipe(
                prompt=prompt,
                num_images_per_prompt=num_images,
            )
        else:
            print("🎨 Running IMAGE → IMAGE")
            # Routed to the dedicated img2img pipeline (see __init__).
            output = self.img2img_pipe(
                prompt=prompt,
                image=init_image,
                strength=0.6,
                num_images_per_prompt=num_images,
            )

        # Convert each generated PIL image to a base64-encoded PNG.
        results = []
        for img in output.images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append(img_b64)

        print(f"✅ Returning {len(results)} images.")
        return {"images": results}