fuhaddesmond commited on
Commit
92ba8a3
·
verified ·
1 Parent(s): 1cfc13c

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +137 -0
handler.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Illuma (BLIP3o-NEXT-GRPO-TexT-3B) - Custom Handler for Hugging Face Inference Endpoints
3
+
4
+ This handler enables running the illuma image generation model as a production API
5
+ on Hugging Face Inference Endpoints with a dedicated GPU.
6
+
7
+ Architecture: Qwen2.5 VL AR (3B) + SANA 1.5 Diffusion Decoder
8
+ License: Apache 2.0
9
+ """
10
+
11
+ import os
12
+ import base64
13
+ import io
14
+ import torch
15
+ from typing import Any, Dict
16
+ from PIL import Image
17
+ from dataclasses import dataclass
18
+
19
+ from transformers import AutoTokenizer
20
+ from blip3o.model import *
21
+
22
+
23
+ @dataclass
24
+ class T2IConfig:
25
+ model_path: str = ""
26
+ device: str = "cuda:0"
27
+ dtype: torch.dtype = torch.bfloat16
28
+ scale: int = 0
29
+ seq_len: int = 729
30
+ top_p: float = 0.95
31
+ top_k: int = 1200
32
+
33
+
34
+ class EndpointHandler:
35
+ """Custom inference handler for Illuma (BLIP3o-NEXT) image generation."""
36
+
37
+ def __init__(self, model_dir: str, **kwargs: Any) -> None:
38
+ """Load the model and tokenizer on startup."""
39
+ self.config = T2IConfig(model_path=model_dir)
40
+ self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")
41
+
42
+ print(f"[Illuma] Loading model from: {model_dir}")
43
+ print(f"[Illuma] Device: {self.device}")
44
+
45
+ self.model = blip3oQwenForInferenceLM.from_pretrained(
46
+ self.config.model_path,
47
+ torch_dtype=self.config.dtype
48
+ ).to(self.device)
49
+
50
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
51
+ print("[Illuma] Model loaded successfully!")
52
+
53
+ def __call__(self, data: Dict[str, Any]) -> Any:
54
+ """
55
+ Generate an image from a text prompt.
56
+
57
+ Input (JSON):
58
+ {
59
+ "inputs": "A neon sign that says HELLO",
60
+ "parameters": {
61
+ "seq_len": 729,
62
+ "top_p": 0.95,
63
+ "top_k": 1200,
64
+ "guidance_scale": 3.0
65
+ }
66
+ }
67
+
68
+ Output:
69
+ - Returns base64-encoded PNG image
70
+ - Or raw PNG bytes if Content-Type is set
71
+ """
72
+ # Extract prompt
73
+ prompt = data.get("inputs", "")
74
+ if not prompt:
75
+ return {"error": "No prompt provided. Send {'inputs': 'your prompt here'}"}
76
+
77
+ # Extract optional parameters
78
+ parameters = data.get("parameters", {})
79
+ seq_len = parameters.get("seq_len", self.config.seq_len)
80
+ top_p = parameters.get("top_p", self.config.top_p)
81
+ top_k = parameters.get("top_k", self.config.top_k)
82
+
83
+ print(f"[Illuma] Generating image for: {prompt[:100]}...")
84
+
85
+ try:
86
+ image = self._generate(prompt, seq_len, top_p, top_k)
87
+ return self._encode_image(image)
88
+ except Exception as e:
89
+ print(f"[Illuma] Error generating image: {e}")
90
+ return {"error": str(e)}
91
+
92
+ def _generate(self, prompt: str, seq_len: int, top_p: float, top_k: float) -> Image.Image:
93
+ """Generate image using the BLIP3o-NEXT inference pipeline."""
94
+ messages = [
95
+ {"role": "system", "content": "You are a helpful assistant."},
96
+ {"role": "user", "content": f"Please generate image based on the following caption: {prompt}"}
97
+ ]
98
+
99
+ input_text = self.tokenizer.apply_chat_template(
100
+ messages,
101
+ tokenize=False,
102
+ add_generation_prompt=True
103
+ )
104
+ input_text += "\n"
105
+
106
+ inputs = self.tokenizer(
107
+ [input_text],
108
+ return_tensors="pt",
109
+ padding=True,
110
+ truncation=True,
111
+ padding_side="left"
112
+ )
113
+
114
+ gen_ids, output_image = self.model.generate_images(
115
+ inputs.input_ids.to(self.device),
116
+ inputs.attention_mask.to(self.device),
117
+ max_new_tokens=seq_len,
118
+ do_sample=True,
119
+ top_p=top_p,
120
+ top_k=top_k
121
+ )
122
+
123
+ return output_image[0]
124
+
125
+ def _encode_image(self, image: Image.Image) -> Dict[str, str]:
126
+ """Encode PIL Image to base64 for API response."""
127
+ buffered = io.BytesIO()
128
+ image.save(buffered, format="PNG")
129
+ img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
130
+ return {"image": img_b64}
131
+
132
+
133
+ # For local testing
134
+ if __name__ == "__main__":
135
+ handler = EndpointHandler(model_dir="Salesforce/BLIP3o-NEXT-GRPO-TexT-3B")
136
+ result = handler({"inputs": "A neon sign that says ILLUMA"})
137
+ print(f"Generated image, base64 length: {len(result.get('image', ''))}")