| | from typing import Dict |
| | import torch |
| | from diffusers import FluxKontextPipeline |
| | from io import BytesIO |
| | import base64 |
| | from PIL import Image |
| |
|
class EndpointHandler:
    """Inference-endpoint wrapper around the FLUX.1 Kontext image-editing pipeline.

    Expects requests shaped as::

        {"inputs": {"prompt": <str>, "image": <base64-encoded image>}}

    and returns ``{"image": <base64 PNG>}`` on success or ``{"error": <message>}``
    on any failure — ``__call__`` never raises, so the serving layer always gets
    a JSON-serializable response.
    """

    def __init__(self, path: str = ""):
        # `path` is part of the hosting contract but unused here: the model is
        # always pulled from the Hub by its canonical repo id.
        print("Initializing Flux Kontext pipeline...")

        # bfloat16 is the dtype the diffusers docs use for FLUX models.
        # float16 can overflow/produce black images with FLUX, and is not
        # supported on the CPU fallback path below.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
        )
        self.pipe.to(self.device)
        print("Model ready.")

    def __call__(self, data: Dict) -> Dict:
        """Handle one edit request.

        Validates the payload, decodes the input image, runs the pipeline,
        and returns the edited image as a base64-encoded PNG. All failures
        are reported as ``{"error": ...}`` dicts rather than exceptions.
        """
        print("Received data:", data)

        inputs = data.get("inputs")
        if not inputs or not isinstance(inputs, dict):
            return {"error": "'inputs' must be a JSON object containing 'prompt' and 'image'."}

        prompt = inputs.get("prompt")
        image_input = inputs.get("image")

        if not prompt:
            return {"error": "'prompt' is required in 'inputs'."}
        if not image_input:
            return {"error": "'image' (base64 encoded string) is required in 'inputs'."}

        # Tolerate data-URI payloads ("data:image/png;base64,....") as well as
        # bare base64 strings — browsers/clients commonly send the former.
        if isinstance(image_input, str) and image_input.startswith("data:"):
            image_input = image_input.split(",", 1)[-1]

        try:
            image_bytes = base64.b64decode(image_input)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to decode 'image' input as base64: {str(e)}"}

        try:
            # Defaults match the FLUX.1 Kontext reference settings
            # (28 steps, guidance 3.5).
            output = self.pipe(
                prompt=prompt,
                image=image,
                num_inference_steps=28,
                guidance_scale=3.5,
            ).images[0]
            print("Image generated.")
        except Exception as e:
            return {"error": f"Model inference failed: {str(e)}"}

        try:
            buffer = BytesIO()
            output.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
            print("Returning image.")
            return {"image": base64_image}
        except Exception as e:
            return {"error": f"Failed to encode output image: {str(e)}"}