File size: 4,865 Bytes
37256d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import base64
import io
import logging
import os
import random
from typing import Any, Dict

import torch
from diffusers import FluxKontextPipeline
from PIL import Image
from PIL import ImageOps

# FLUX.1-Kontext-dev is a 12B rectified-flow transformer for instruction-based
# image editing (and text-to-image when no input image is supplied).
#
# Inclusive upper bound for seeds that are drawn at random when a request does
# not provide one (2**31 - 1).
MAX_SEED = (1 << 31) - 1


def _decode_image(image_data: str) -> Image.Image:
    """Decode a base64 string (raw or a data URI) into an RGB PIL image.

    Args:
        image_data: Base64 payload, optionally wrapped in a data URI such as
            ``data:image/png;base64,<payload>``.

    Returns:
        The decoded image, rotated according to its EXIF orientation tag (if
        any) and converted to RGB.

    Raises:
        ValueError: If ``image_data`` is a data URI without a ``,`` separating
            the header from the payload, or if the payload is not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
        OSError: Propagated from ``Image.open`` when the decoded bytes are not
            a readable image.
    """
    if image_data.startswith("data:"):
        # Strip "data:image/png;base64," style prefixes. partition() lets us
        # detect a missing comma instead of failing with a bare IndexError.
        _header, separator, payload = image_data.partition(",")
        if not separator:
            raise ValueError(
                "Malformed data URI: expected 'data:<mime>;base64,<payload>'."
            )
        image_data = payload

    raw = base64.b64decode(image_data)
    image = Image.open(io.BytesIO(raw))

    # Cameras and phones often store pixels unrotated and record the intended
    # orientation in EXIF. convert("RGB") ignores that tag, so apply it first;
    # otherwise the model would be asked to edit a sideways image.
    image = ImageOps.exif_transpose(image)
    return image.convert("RGB")


def _encode_image(image: Image.Image) -> str:
    """Serialize a PIL image to PNG and return the bytes as base64 text.

    Args:
        image: The image to serialize.

    Returns:
        The PNG file contents, base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")


class EndpointHandler:
    """Request handler that wraps a ``FluxKontextPipeline``.

    The pipeline is loaded once in ``__init__``. Each request is served by
    calling the instance with the decoded request body; problems with the
    request are reported as ``{"error": "<message>"}`` rather than raised.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline and apply the configured device placement.

        Args:
            path: Location passed to ``FluxKontextPipeline.from_pretrained``.
        """
        log = logging.getLogger(__name__)

        # Load the Kontext pipeline from the local model weights.
        self.pipe = FluxKontextPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Placement strategy. The model is ~24GB in bf16, so pick based on the
        # instance VRAM via the FLUX_OFFLOAD env var (set in the endpoint config):
        #   "none"       -> keep everything on GPU (fastest, needs ~40GB e.g. A100)
        #   "model"      -> enable_model_cpu_offload  (works on ~24GB cards)
        #   "sequential" -> enable_sequential_cpu_offload (lowest VRAM, slowest)
        offload = os.environ.get("FLUX_OFFLOAD", "model").strip().lower()
        if offload == "sequential":
            self.pipe.enable_sequential_cpu_offload()
        elif offload == "model":
            self.pipe.enable_model_cpu_offload()
        else:
            if offload != "none":
                # A typo here would otherwise silently select the most
                # VRAM-hungry mode; make it visible in the endpoint logs.
                log.warning(
                    "Unrecognised FLUX_OFFLOAD=%r; treating it as 'none'.", offload
                )
            if torch.cuda.is_available():
                self.pipe.to("cuda")

        # Small memory savings on the VAE; harmless if unsupported. Each
        # switch is attempted on its own so that one being unavailable does
        # not prevent the other from being enabled.
        for vae_switch in ("enable_slicing", "enable_tiling"):
            try:
                getattr(self.pipe.vae, vae_switch)()
            except Exception:  # best-effort: never fail start-up over this
                log.warning(
                    "VAE %s() is unavailable; continuing without it.",
                    vae_switch,
                    exc_info=True,
                )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expected request body (JSON):
          {
            "inputs": "Add a hat to the cat",          # edit / generation prompt
            "image": "<base64 png/jpg>",                # OPTIONAL input image to edit
            "parameters": {                             # all optional
              "guidance_scale": 2.5,
              "num_inference_steps": 30,
              "width": 1024,                            # multiples of 16
              "height": 1024,                           # multiples of 16
              "max_sequence_length": 512,
              "seed": 42
            }
          }

        "steps" is accepted as an alias for "num_inference_steps", and "image"
        may also be placed inside "parameters". Clients may nest the whole
        payload under "inputs" (with the prompt under "prompt").

        Returns:
          {"image": "<base64 png>", "format": "png", "seed": <int used>}
          or {"error": "<message>"} when the request is malformed (missing
          prompt, non-object body/parameters, non-numeric parameter values,
          or an image that cannot be decoded).
        """
        if not isinstance(data, dict):
            return {"error": "Request body must be a JSON object."}

        # Prompt is conventionally under "inputs".
        prompt = data.get("inputs")
        if isinstance(prompt, dict):
            # tolerate clients that nest everything under "inputs"
            data = {**data, **prompt}
            prompt = data.get("prompt")
        if not prompt:
            return {"error": "Missing prompt. Provide the edit instruction under the 'inputs' key."}

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return {"error": "'parameters' must be a JSON object."}

        # Coerce every numeric option up front so a bad value is reported as a
        # request error before any image decoding or GPU work is started.
        try:
            # Kontext-friendly defaults.
            guidance_scale = float(params.get("guidance_scale", 2.5))
            num_inference_steps = int(
                params.get("num_inference_steps", params.get("steps", 30))
            )
            max_sequence_length = int(params.get("max_sequence_length", 512))
            width = params.get("width")
            width = None if width is None else int(width)
            height = params.get("height")
            height = None if height is None else int(height)
            seed = params.get("seed")
            seed = None if seed is None else int(seed)
        except (TypeError, ValueError, OverflowError) as exc:
            return {"error": f"Invalid value in 'parameters': {exc}"}

        # Optional input image -> editing mode. Without it -> text-to-image.
        image_b64 = data.get("image") or params.get("image")
        init_image = None
        if image_b64:
            if not isinstance(image_b64, str):
                return {"error": "'image' must be a base64-encoded string."}
            try:
                init_image = _decode_image(image_b64)
            except (ValueError, IndexError, OSError) as exc:
                # Covers invalid base64 (binascii.Error is a ValueError),
                # malformed data URIs and bytes that are not a readable image.
                return {"error": f"Could not decode 'image': {exc}"}

        # Draw a seed when none was given so it can be echoed back to the caller.
        if seed is None:
            seed = random.randint(0, MAX_SEED)
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

        call_kwargs: Dict[str, Any] = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "max_sequence_length": max_sequence_length,
            "generator": generator,
        }
        # Only forward the optional arguments that were actually supplied so
        # the pipeline's own defaults apply otherwise.
        if init_image is not None:
            call_kwargs["image"] = init_image
        if width is not None:
            call_kwargs["width"] = width
        if height is not None:
            call_kwargs["height"] = height

        with torch.inference_mode():
            result = self.pipe(**call_kwargs)

        return {
            "image": _encode_image(result.images[0]),
            "format": "png",
            "seed": seed,
        }