|
|
import base64
import io
from typing import Any, Dict

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from PIL import Image
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler for the Juggernaut XL (SDXL) model.

    Handles text-to-image generation from a prompt and, when a base64-encoded
    image is supplied in the request, image-to-image generation via a
    shared-weights img2img pipeline.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Load the SDXL pipeline onto the GPU.

        Args:
            model_dir: Model directory supplied by the serving runtime
                (unused here; the checkpoint is pulled from the Hub by
                repo id instead — presumably intentional, verify).
            **kwargs: Extra runtime options, ignored.
        """
        print("Initializing Juggernaut XL Handler (Prompt + Optional Image)...")

        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")

        # BUG FIX: the base StableDiffusionXLPipeline is text-to-image only
        # and rejects `image=`/`strength=` kwargs with a TypeError. Build a
        # proper img2img pipeline that shares the already-loaded components
        # (no additional VRAM cost) for the image-conditioned path.
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(**self.pipe.components)

        # Trade a little speed for lower peak VRAM during attention.
        self.pipe.enable_attention_slicing()
        self.img2img_pipe.enable_attention_slicing()

        print("Pipeline loaded successfully.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Expected request body::

            {"inputs": {"prompt": str,
                        "num_images": int (optional, default 10),
                        "image": base64-encoded image (optional)}}

        Returns:
            ``{"images": [base64-PNG, ...]}`` on success, or
            ``{"error": message}`` on invalid input.
        """
        if "inputs" not in data:
            return {"error": "Body must contain 'inputs' object"}

        inputs = data["inputs"]

        prompt = inputs.get("prompt", None)
        if not prompt:
            return {"error": "prompt is required"}

        num_images = inputs.get("num_images", 10)
        image_b64 = inputs.get("image", None)

        # Decode the optional init image; any failure (bad base64, not an
        # image) is reported to the caller rather than crashing the worker.
        init_image = None
        if image_b64:
            try:
                img_bytes = base64.b64decode(image_b64)
                init_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Invalid image data: {str(e)}"}

        # inference_mode: no autograd bookkeeping needed for generation.
        with torch.inference_mode():
            if init_image is None:
                print("Running TEXT -> IMAGE")
                output = self.pipe(
                    prompt=prompt,
                    num_images_per_prompt=num_images,
                )
            else:
                print("Running IMAGE -> IMAGE")
                # Must use the img2img pipeline: the txt2img pipeline does
                # not accept `image`/`strength` arguments.
                output = self.img2img_pipe(
                    prompt=prompt,
                    image=init_image,
                    strength=0.6,
                    num_images_per_prompt=num_images,
                )

        images = output.images

        # Encode each generated PIL image as a base64 PNG string for the
        # JSON response.
        results = []
        for img in images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append(img_b64)

        print(f"Returning {len(results)} images.")
        return {"images": results}