# illuma/handler.py
# Uploaded by fuhaddesmond via huggingface_hub (commit 92ba8a3, verified).
"""
Illuma (BLIP3o-NEXT-GRPO-TexT-3B) - Custom Handler for Hugging Face Inference Endpoints
This handler serves the Illuma image-generation model as a production API
on a Hugging Face Inference Endpoint with a dedicated GPU.
Architecture: Qwen2.5-VL autoregressive (AR) backbone (3B) + SANA 1.5 diffusion decoder
License: Apache 2.0
"""
import os
import base64
import io
import torch
from typing import Any, Dict
from PIL import Image
from dataclasses import dataclass
from transformers import AutoTokenizer
from blip3o.model import *
@dataclass
class T2IConfig:
    """Model-loading settings and default sampling parameters.

    EndpointHandler builds one of these from the model directory and uses
    ``seq_len``, ``top_p`` and ``top_k`` as defaults when a request omits them.
    """

    # Checkpoint location (directory or Hub repo id) for both model and tokenizer.
    model_path: str = ""
    # Preferred device; EndpointHandler falls back to CPU when CUDA is unavailable.
    device: str = "cuda:0"
    # dtype the model weights are loaded in (passed as `torch_dtype`).
    dtype: torch.dtype = torch.bfloat16
    # NOTE(review): not read anywhere in the visible code — presumably meant as a
    # guidance scale; confirm before relying on it.
    scale: int = 0
    # Default `max_new_tokens` for the autoregressive stage. 729 == 27 * 27,
    # presumably the image-token grid size — TODO confirm before changing it.
    seq_len: int = 729
    # Default nucleus (top-p) sampling threshold.
    top_p: float = 0.95
    # Default top-k sampling cutoff.
    top_k: int = 1200
class EndpointHandler:
    """Custom inference handler for Illuma (BLIP3o-NEXT) image generation.

    Constructed once with the model location; each request body (a dict parsed
    from JSON) is passed to ``__call__``, which returns a JSON-serializable dict.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model and tokenizer on startup.

        Args:
            model_dir: Directory or Hub repo id holding the checkpoint; used for
                both the model and the tokenizer.
            **kwargs: Accepted for compatibility with the endpoint toolkit; unused.
        """
        self.config = T2IConfig(model_path=model_dir)
        # Fall back to CPU so the handler can still start on a host without CUDA.
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu"
        )
        print(f"[Illuma] Loading model from: {model_dir}")
        print(f"[Illuma] Device: {self.device}")
        self.model = blip3oQwenForInferenceLM.from_pretrained(
            self.config.model_path,
            torch_dtype=self.config.dtype,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        print("[Illuma] Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate an image from a text prompt.

        Input (JSON):
            {
                "inputs": "A neon sign that says HELLO",
                "parameters": {          # optional; every key inside is optional
                    "seq_len": 729,
                    "top_p": 0.95,
                    "top_k": 1200
                }
            }

        Any other keys in ``parameters`` (e.g. ``guidance_scale``) are ignored:
        they are not forwarded to the model.

        Returns:
            ``{"image": "<base64-encoded PNG>"}`` on success, or
            ``{"error": "<message>"}`` for an invalid request or a generation
            failure. Errors are reported in the payload, never raised.
        """
        # `or ""` folds a missing key and an explicit JSON null into "no prompt".
        prompt = data.get("inputs") or ""
        if not isinstance(prompt, str):
            return {"error": "'inputs' must be a string prompt."}
        if not prompt.strip():
            return {"error": "No prompt provided. Send {'inputs': 'your prompt here'}"}

        # A JSON `"parameters": null` arrives as None; treat it like "absent".
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            return {"error": "'parameters' must be a JSON object."}
        try:
            # JSON clients may send numbers as strings or floats; coerce up front
            # so bad values produce a clear error instead of failing in the sampler.
            seq_len = int(parameters.get("seq_len", self.config.seq_len))
            top_p = float(parameters.get("top_p", self.config.top_p))
            top_k = int(parameters.get("top_k", self.config.top_k))
        except (TypeError, ValueError, OverflowError) as e:
            # OverflowError: int(float("inf")), e.g. from a JSON `1e999`.
            return {"error": f"Invalid generation parameter: {e}"}
        if seq_len <= 0:
            return {"error": "'seq_len' must be a positive integer."}
        if not 0.0 <= top_p <= 1.0:  # also rejects NaN
            return {"error": "'top_p' must be between 0 and 1."}
        if top_k < 0:
            return {"error": "'top_k' must be a non-negative integer."}

        print(f"[Illuma] Generating image for: {prompt[:100]}...")
        try:
            image = self._generate(prompt, seq_len, top_p, top_k)
            return self._encode_image(image)
        except Exception as e:  # request boundary: report the failure, keep serving
            import traceback

            traceback.print_exc()  # keep the stack trace in the endpoint logs
            print(f"[Illuma] Error generating image: {e}")
            return {"error": str(e)}

    def _generate(
        self, prompt: str, seq_len: int, top_p: float, top_k: int
    ) -> "Image.Image":
        """Run the BLIP3o-NEXT text-to-image pipeline for a single prompt.

        Args:
            prompt: Caption describing the image to generate.
            seq_len: Passed to the model as ``max_new_tokens``.
            top_p: Nucleus-sampling threshold.
            top_k: Top-k sampling cutoff.

        Returns:
            The first image returned by ``model.generate_images``.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Please generate image based on the following caption: {prompt}"},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # NOTE(review): the trailing newline presumably matches the prompt format
        # the model was trained with — confirm before removing it.
        input_text += "\n"
        inputs = self.tokenizer(
            [input_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left",
        )
        # Inference only: disable autograd bookkeeping to save memory
        # (harmless if generate_images already does this internally).
        with torch.no_grad():
            _, output_images = self.model.generate_images(
                inputs.input_ids.to(self.device),
                inputs.attention_mask.to(self.device),
                max_new_tokens=seq_len,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
            )
        return output_images[0]

    def _encode_image(self, image: "Image.Image") -> Dict[str, str]:
        """Serialize a PIL image as a base64-encoded PNG for the JSON response."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_b64}
# For local testing: load the public checkpoint and generate one image.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="Salesforce/BLIP3o-NEXT-GRPO-TexT-3B")
    result = handler({"inputs": "A neon sign that says ILLUMA"})
    if "error" in result:
        # The handler reports failures in the payload instead of raising.
        print(f"Generation failed: {result['error']}")
    else:
        print(f"Generated image, base64 length: {len(result['image'])}")