File size: 4,427 Bytes
92ba8a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Illuma (BLIP3o-NEXT-GRPO-TexT-3B) - Custom Handler for Hugging Face Inference Endpoints

This handler enables running the illuma image generation model as a production API
on Hugging Face Inference Endpoints with a dedicated GPU.

Architecture: Qwen2.5 VL AR (3B) + SANA 1.5 Diffusion Decoder
License: Apache 2.0
"""

import os
import base64
import io
import torch
from typing import Any, Dict
from PIL import Image
from dataclasses import dataclass

from transformers import AutoTokenizer
from blip3o.model import *


@dataclass
class T2IConfig:
    """Generation settings for the Illuma text-to-image pipeline.

    Defaults mirror the sampling configuration used by the handler below.
    """

    model_path: str = ""                  # HF repo id or local checkpoint directory
    device: str = "cuda:0"                # preferred device; handler falls back to CPU
    dtype: torch.dtype = torch.bfloat16   # precision used when loading the weights
    scale: int = 0                        # NOTE(review): not read anywhere in this file — confirm before removing
    seq_len: int = 729                    # image-token budget (passed as max_new_tokens)
    top_p: float = 0.95                   # nucleus-sampling threshold
    top_k: int = 1200                     # top-k sampling cutoff


class EndpointHandler:
    """Custom inference handler for Illuma (BLIP3o-NEXT) image generation.

    Loads the autoregressive model and tokenizer once at startup; serves
    text-to-image requests through ``__call__`` (the Hugging Face Inference
    Endpoints entry point).
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model and tokenizer on startup.

        Args:
            model_dir: Checkpoint path (or repo id) to load from.
            **kwargs: Ignored; accepted for Inference Endpoints compatibility.
        """
        self.config = T2IConfig(model_path=model_dir)
        # Fall back to CPU when no GPU is available (e.g. local smoke tests).
        self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")

        print(f"[Illuma] Loading model from: {model_dir}")
        print(f"[Illuma] Device: {self.device}")

        self.model = blip3oQwenForInferenceLM.from_pretrained(
            self.config.model_path,
            torch_dtype=self.config.dtype
        ).to(self.device)
        # Inference-only handler: put modules such as dropout into eval mode.
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        print("[Illuma] Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        Generate an image from a text prompt.

        Input (JSON):
        {
            "inputs": "A neon sign that says HELLO",
            "parameters": {
                "seq_len": 729,
                "top_p": 0.95,
                "top_k": 1200
            }
        }

        Output:
        - {"image": "<base64-encoded PNG>"} on success
        - {"error": "<message>"} on failure or missing prompt
        """
        # Extract prompt; reject empty/missing input up front.
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No prompt provided. Send {'inputs': 'your prompt here'}"}

        # "parameters" may be absent or explicitly null in the request body;
        # `or {}` keeps both cases from raising AttributeError below.
        parameters = data.get("parameters") or {}
        seq_len = parameters.get("seq_len", self.config.seq_len)
        top_p = parameters.get("top_p", self.config.top_p)
        top_k = parameters.get("top_k", self.config.top_k)

        print(f"[Illuma] Generating image for: {prompt[:100]}...")

        try:
            image = self._generate(prompt, seq_len, top_p, top_k)
            return self._encode_image(image)
        except Exception as e:
            # Endpoint boundary: surface the failure as JSON instead of an
            # unhandled traceback / HTTP 500.
            print(f"[Illuma] Error generating image: {e}")
            return {"error": str(e)}

    def _generate(self, prompt: str, seq_len: int, top_p: float, top_k: int) -> "Image.Image":
        """Run the BLIP3o-NEXT pipeline: chat-format the prompt, sample image
        tokens, and return the decoded PIL image.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Please generate image based on the following caption: {prompt}"}
        ]

        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # NOTE(review): trailing newline appended after the generation prompt —
        # presumably required by the model's prompt format; confirm upstream.
        input_text += "\n"

        inputs = self.tokenizer(
            [input_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left"
        )

        # Pure inference: disable autograd tracking to cut memory/overhead.
        with torch.inference_mode():
            _, output_image = self.model.generate_images(
                inputs.input_ids.to(self.device),
                inputs.attention_mask.to(self.device),
                max_new_tokens=seq_len,
                do_sample=True,
                top_p=top_p,
                top_k=top_k
            )

        return output_image[0]

    def _encode_image(self, image: "Image.Image") -> Dict[str, str]:
        """Encode a PIL Image as base64 PNG for the JSON API response."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_b64}


# For local testing
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="Salesforce/BLIP3o-NEXT-GRPO-TexT-3B")
    result = handler({"inputs": "A neon sign that says ILLUMA"})
    # Report failures explicitly instead of printing a misleading
    # "base64 length: 0" when the handler returned an error dict.
    if "error" in result:
        print(f"Generation failed: {result['error']}")
    else:
        print(f"Generated image, base64 length: {len(result['image'])}")