import os
import io
import base64
import logging
import random
from typing import Any, Dict, Optional

import torch
from PIL import Image
from diffusers import FluxKontextPipeline

# FLUX.1-Kontext-dev is a 12B rectified-flow transformer for instruction-based
# image editing (and text-to-image when no input image is supplied).

logger = logging.getLogger(__name__)

MAX_SEED = 2**31 - 1

# Accepted values of the FLUX_OFFLOAD environment variable (see EndpointHandler).
_OFFLOAD_MODES = ("none", "model", "sequential")


def _decode_image(image_data: str) -> Image.Image:
    """Decode a base64 string (raw or a data URI) into an RGB PIL image.

    Raises:
        ValueError: if ``image_data`` is not a string, is not valid base64, or
            does not decode to an image PIL can open.
    """
    if not isinstance(image_data, str):
        raise ValueError(
            f"expected a base64 string, got {type(image_data).__name__}"
        )
    if image_data.startswith("data:"):
        # Strip "data:image/png;base64," style prefixes. partition() yields ""
        # (instead of raising IndexError) when the URI contains no comma.
        image_data = image_data.partition(",")[2]
    try:
        raw = base64.b64decode(image_data)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError) as err:
        # binascii.Error subclasses ValueError and PIL's UnidentifiedImageError
        # subclasses OSError; normalise them all to ValueError for the caller.
        raise ValueError(f"could not decode image: {err}") from err


def _encode_image(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _optional_dimension(value: Any, name: str) -> Optional[int]:
    """Return ``value`` as a positive int, or None when it was not supplied.

    Raises:
        ValueError: if ``value`` converts to an int that is not > 0 (or cannot
            be parsed as an int).
        TypeError: if ``value`` is of a type ``int()`` does not accept.
    """
    if value is None:
        return None
    dimension = int(value)
    if dimension <= 0:
        raise ValueError(f"'{name}' must be positive, got {dimension}")
    return dimension


class EndpointHandler:
    """Request handler wrapping a ``FluxKontextPipeline``.

    The pipeline is loaded once in ``__init__``; each request is served by
    ``__call__``.
    """

    def __init__(self, path: str = ""):
        # Load the Kontext pipeline from the local model weights.
        self.pipe = FluxKontextPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Placement strategy. The model is ~24GB in bf16, so pick based on the
        # instance VRAM via the FLUX_OFFLOAD env var (set in the endpoint config):
        #   "none"       -> keep everything on GPU (fastest, needs ~40GB e.g. A100)
        #   "model"      -> enable_model_cpu_offload (works on ~24GB cards)
        #   "sequential" -> enable_sequential_cpu_offload (lowest VRAM, slowest)
        offload = os.environ.get("FLUX_OFFLOAD", "model").strip().lower()
        if offload not in _OFFLOAD_MODES:
            # Preserve the historical fallback (full GPU placement) but make a
            # misconfiguration visible instead of silently ignoring it.
            logger.warning(
                "Unknown FLUX_OFFLOAD=%r (expected one of: %s); using 'none'.",
                offload,
                ", ".join(_OFFLOAD_MODES),
            )
            offload = "none"

        if not torch.cuda.is_available():
            # The CPU-offload helpers exist to shuttle weights onto an
            # accelerator; with no CUDA device there is nothing to offload to,
            # so leave the pipeline on CPU.
            if offload != "none":
                logger.warning(
                    "CUDA is not available; ignoring FLUX_OFFLOAD=%r and "
                    "running on CPU.",
                    offload,
                )
        elif offload == "sequential":
            self.pipe.enable_sequential_cpu_offload()
        elif offload == "model":
            self.pipe.enable_model_cpu_offload()
        else:
            self.pipe.to("cuda")

        # Small memory savings on the VAE; best-effort, skipped if unsupported.
        try:
            self.pipe.vae.enable_slicing()
            self.pipe.vae.enable_tiling()
        except (AttributeError, NotImplementedError) as err:
            logger.info("VAE slicing/tiling unavailable, skipping: %s", err)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one edit (or text-to-image) request.

        Expected request body (JSON)::

            {
                "inputs": "Add a hat to the cat",    # edit / generation prompt
                "image": "<base64 image or data URI>",  # OPTIONAL input image to edit
                "parameters": {                      # all optional
                    "guidance_scale": 2.5,
                    "num_inference_steps": 30,       # alias: "steps"
                    "width": 1024,                   # multiples of 16
                    "height": 1024,                  # multiples of 16
                    "max_sequence_length": 512,
                    "seed": 42
                }
            }

        ``"inputs"`` may instead be an object whose keys (e.g. ``"prompt"``,
        ``"image"``, ``"parameters"``) are merged over the top level; a
        top-level ``"prompt"`` is used when ``"inputs"`` is absent. ``"image"``
        may also be supplied inside ``"parameters"``.

        Returns:
            ``{"image": "<base64 PNG>", "format": "png", "seed": <int>}`` on
            success, or ``{"error": "<message>"}`` when the request is
            malformed (missing prompt, undecodable image, bad parameter).
        """
        # Prompt is conventionally under "inputs".
        prompt = data.get("inputs")
        if isinstance(prompt, dict):
            # Tolerate clients that nest everything under "inputs".
            data = {**data, **prompt}
            prompt = data.get("prompt")
        elif prompt is None:
            # Also accept a plain top-level "prompt" key.
            prompt = data.get("prompt")
        if not prompt:
            return {"error": "Missing prompt. Provide the edit instruction under the 'inputs' key."}

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return {"error": "'parameters' must be a JSON object."}

        # Optional input image -> editing mode. Without it -> text-to-image.
        image_b64 = data.get("image") or params.get("image")
        try:
            init_image = _decode_image(image_b64) if image_b64 else None
        except ValueError as err:
            return {"error": f"Invalid 'image': {err}"}

        # Kontext-friendly defaults. Malformed values produce an error response
        # rather than an unhandled exception.
        try:
            guidance_scale = float(params.get("guidance_scale", 2.5))
            num_inference_steps = int(
                params.get("num_inference_steps", params.get("steps", 30))
            )
            max_sequence_length = int(params.get("max_sequence_length", 512))
            width = _optional_dimension(params.get("width"), "width")
            height = _optional_dimension(params.get("height"), "height")
            seed = params.get("seed")
            seed = random.randint(0, MAX_SEED) if seed is None else int(seed)
        except (TypeError, ValueError) as err:
            return {"error": f"Invalid parameter value: {err}"}

        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

        call_kwargs: Dict[str, Any] = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "max_sequence_length": max_sequence_length,
            "generator": generator,
        }
        if init_image is not None:
            call_kwargs["image"] = init_image
        # Only forward explicit sizes; otherwise let the pipeline choose.
        if width is not None:
            call_kwargs["width"] = width
        if height is not None:
            call_kwargs["height"] = height

        with torch.inference_mode():
            result = self.pipe(**call_kwargs)

        return {
            "image": _encode_image(result.images[0]),
            "format": "png",
            "seed": seed,
        }