File size: 2,308 Bytes
a91abf3 e134161 aadf329 e134161 93b4067 e134161 aadf329 e134161 aadf329 93b4067 e134161 aadf329 e134161 aadf329 e134161 aadf329 e134161 aadf329 e134161 aadf329 e134161 aadf329 e134161 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import base64
import io
from PIL import Image
import torch
from diffusers import StableDiffusionXLPipeline
from typing import Any, Dict
class EndpointHandler:
    """Serve Stable Diffusion XL text-to-image and image-to-image requests.

    A request is a dict of the form::

        {"inputs": {"prompt": str, "num_images": int, "image": <base64 str>}}

    ``prompt`` is required. ``num_images`` is optional and defaults to
    ``DEFAULT_NUM_IMAGES``. ``image`` is optional; when present it is used
    as the starting image for image-to-image generation.

    Every call returns a dict: ``{"images": [<base64 PNG>, ...]}`` on
    success or ``{"error": <message>}`` when the request is rejected.
    """

    # Hub repository the weights are loaded from.
    MODEL_ID = "Gjm1234/juggernaut-sfw"
    # Number of images generated when the request does not specify one.
    DEFAULT_NUM_IMAGES = 10
    # `strength` passed to the image-to-image pipeline.
    IMG2IMG_STRENGTH = 0.6

    def __init__(self, model_dir: str, **kwargs):
        """Load the SDXL pipelines onto the GPU.

        Args:
            model_dir: Accepted for interface compatibility; it is not used,
                the weights are always loaded from ``MODEL_ID``.
            **kwargs: Accepted and ignored.
        """
        # Imported here so the class stays self-contained; the file's
        # top-level import only brings in the text-to-image pipeline.
        from diffusers import StableDiffusionXLImg2ImgPipeline

        print("🔥 Initializing Juggernaut XL Handler (Prompt + Optional Image)...")
        # Load XL model from your big repo
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
        self.pipe.enable_attention_slicing()
        # The text-to-image pipeline does not take `image`/`strength`, so
        # image-to-image needs its own pipeline class. Building it from the
        # existing components reuses the already-loaded modules.
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(**self.pipe.components)
        print("✅ Pipeline loaded successfully.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the request, generate images, and return them as base64.

        Args:
            data: Request body; must contain an ``inputs`` dict (see the
                class docstring for its keys).

        Returns:
            ``{"images": [...]}`` with one base64-encoded PNG per generated
            image, or ``{"error": "..."}`` if validation fails.
        """
        # Must receive `inputs`, and it must be a mapping we can read keys from.
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if not isinstance(inputs, dict):
            return {"error": "Body must contain 'inputs' object"}

        prompt = inputs.get("prompt")
        if not prompt:
            return {"error": "prompt is required"}

        num_images = inputs.get("num_images", self.DEFAULT_NUM_IMAGES)
        # bool is a subclass of int, so reject it explicitly.
        if (
            isinstance(num_images, bool)
            or not isinstance(num_images, int)
            or num_images < 1
        ):
            return {"error": "num_images must be a positive integer"}

        init_image = None
        image_b64 = inputs.get("image")
        if image_b64:
            try:
                init_image = self._decode_image(image_b64)
            except Exception as e:
                # Request boundary: any decoding failure is reported to the
                # caller instead of crashing the handler.
                return {"error": f"Invalid image data: {str(e)}"}

        # Run txt2img OR img2img depending on whether image was sent
        if init_image is None:
            print("🎨 Running TEXT → IMAGE")
            output = self.pipe(
                prompt=prompt,
                num_images_per_prompt=num_images,
            )
        else:
            print("🎨 Running IMAGE → IMAGE")
            output = self.img2img_pipe(
                prompt=prompt,
                image=init_image,
                strength=self.IMG2IMG_STRENGTH,
                num_images_per_prompt=num_images,
            )

        # convert to base64 array
        results = [self._encode_png(img) for img in output.images]
        print(f"✅ Returning {len(results)} images.")
        return {"images": results}

    @staticmethod
    def _decode_image(image_b64):
        """Decode a base64 string into an RGB PIL image."""
        img_bytes = base64.b64decode(image_b64)
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")

    @staticmethod
    def _encode_png(img) -> str:
        """Serialize a PIL image to PNG and return it as a base64 string."""
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")