worksimpli's picture
Add files using upload-large-folder tool
37256d6 verified
import base64
import io
import logging
import os
import random
from typing import Any, Dict

import torch
from diffusers import FluxKontextPipeline
from PIL import Image
# FLUX.1-Kontext-dev is a 12B rectified-flow transformer for instruction-based
# image editing (and text-to-image when no input image is supplied).

# Inclusive upper bound for the seed drawn with random.randint when a request
# does not supply one (the largest signed 32-bit integer).
MAX_SEED = (1 << 31) - 1
def _decode_image(image_data: str) -> Image.Image:
    """Decode a base64 string (raw or a data URI) into an RGB PIL image.

    Args:
        image_data: Base64 text, optionally prefixed with a data-URI header
            such as ``data:image/png;base64,``.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        ValueError: If ``image_data`` starts with ``data:`` but contains no
            comma separating the header from the payload, or if the payload
            is incorrectly padded base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: If Pillow cannot identify the decoded bytes as an image.
    """
    if image_data.startswith("data:"):
        # Strip "data:image/png;base64," style prefixes. partition() never
        # raises, so a header with no payload is reported explicitly rather
        # than surfacing as an opaque IndexError.
        _header, comma, image_data = image_data.partition(",")
        if not comma:
            raise ValueError("Malformed data URI: missing ',' before the base64 payload.")
    raw = base64.b64decode(image_data)
    return Image.open(io.BytesIO(raw)).convert("RGB")
def _encode_image(image: Image.Image) -> str:
    """Serialize a PIL image as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
class EndpointHandler:
    """Request handler that serves a ``FluxKontextPipeline``.

    The pipeline is loaded once at construction time. Calling the instance
    with a request dict runs one generation and returns the first image as a
    base64-encoded PNG, or an ``{"error": ...}`` dict when the request itself
    is invalid.
    """

    def __init__(self, path: str = ""):
        """Load the Kontext pipeline and choose a device-placement strategy.

        Args:
            path: Location of the model weights, passed straight to
                ``FluxKontextPipeline.from_pretrained``.
        """
        log = logging.getLogger(__name__)

        # Load the Kontext pipeline from the local model weights.
        self.pipe = FluxKontextPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Placement strategy. The model is ~24GB in bf16, so pick based on the
        # instance VRAM via the FLUX_OFFLOAD env var (set in the endpoint config):
        #   "none"       -> keep everything on GPU (fastest, needs ~40GB e.g. A100)
        #   "model"      -> enable_model_cpu_offload (works on ~24GB cards)
        #   "sequential" -> enable_sequential_cpu_offload (lowest VRAM, slowest)
        offload = os.environ.get("FLUX_OFFLOAD", "model").strip().lower()
        if offload == "sequential":
            self.pipe.enable_sequential_cpu_offload()
        elif offload == "model":
            self.pipe.enable_model_cpu_offload()
        else:
            if offload != "none":
                # Unknown values keep behaving like "none", but a typo here
                # decides whether the model fits in VRAM, so make it visible.
                log.warning(
                    "Unrecognised FLUX_OFFLOAD=%r; treating it as 'none'.", offload
                )
            if torch.cuda.is_available():
                self.pipe.to("cuda")

        # Small memory savings on the VAE. Still best-effort, but each toggle
        # is attempted on its own and a failure is logged instead of hidden.
        for toggle in ("enable_slicing", "enable_tiling"):
            try:
                getattr(self.pipe.vae, toggle)()
            except Exception:
                log.warning(
                    "VAE %s() failed; continuing without it.", toggle, exc_info=True
                )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one edit / generation request.

        Expected request body (JSON):
            {
                "inputs": "Add a hat to the cat",   # edit / generation prompt
                "image": "<base64 png/jpg>",        # OPTIONAL input image to edit
                "parameters": {                     # all optional
                    "guidance_scale": 2.5,
                    "num_inference_steps": 30,
                    "width": 1024,                  # multiples of 16
                    "height": 1024,                 # multiples of 16
                    "max_sequence_length": 512,
                    "seed": 42
                }
            }

        Returns:
            {"image": "<base64 png>", "format": "png", "seed": <int used>}
            on success, or {"error": "<message>"} when the request is invalid
            (missing prompt, malformed parameters, undecodable image).
            Failures inside the pipeline itself are not caught.
        """
        # Prompt is conventionally under "inputs".
        prompt = data.get("inputs")
        if isinstance(prompt, dict):
            # Tolerate clients that nest everything under "inputs".
            data = {**data, **prompt}
            prompt = data.get("prompt")
        if not prompt:
            return {"error": "Missing prompt. Provide the edit instruction under the 'inputs' key."}

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return {"error": "Invalid 'parameters': expected a JSON object."}

        # Kontext-friendly defaults. Conversion failures are client errors and
        # are reported the same way as a missing prompt.
        try:
            guidance_scale = float(params.get("guidance_scale", 2.5))
            num_inference_steps = int(
                params.get("num_inference_steps", params.get("steps", 30))
            )
            max_sequence_length = int(params.get("max_sequence_length", 512))
            # Only forward width/height when the client supplied them.
            size_overrides = {
                key: int(params[key])
                for key in ("width", "height")
                if params.get(key) is not None
            }
            seed = params.get("seed")
            if seed is not None:
                seed = int(seed)
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid parameter value: {exc}"}

        if num_inference_steps < 1:
            return {"error": "Invalid 'num_inference_steps': must be at least 1."}
        for key, value in size_overrides.items():
            if value < 1:
                return {"error": f"Invalid '{key}': must be a positive integer."}

        # Optional input image -> editing mode. Without it -> text-to-image.
        image_b64 = data.get("image") or params.get("image")
        init_image = None
        if image_b64:
            if not isinstance(image_b64, str):
                return {"error": "Invalid 'image': expected a base64-encoded string."}
            try:
                init_image = _decode_image(image_b64)
            except (ValueError, IndexError, OSError) as exc:
                # Bad base64, a malformed data URI, or bytes that are not an image.
                return {"error": f"Could not decode 'image': {exc}"}

        if seed is None:
            seed = random.randint(0, MAX_SEED)
        # Seed a generator on the GPU when one is visible, otherwise on the CPU.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

        call_kwargs: Dict[str, Any] = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "max_sequence_length": max_sequence_length,
            "generator": generator,
            **size_overrides,
        }
        if init_image is not None:
            call_kwargs["image"] = init_image

        with torch.inference_mode():
            result = self.pipe(**call_kwargs)

        return {
            "image": _encode_image(result.images[0]),
            "format": "png",
            "seed": seed,
        }