Forrest Wargo committed on
Commit
92c5d3d
·
0 Parent(s):

Add vLLM Qwen3-VL endpoint

Browse files
Files changed (2) hide show
  1. handler.py +213 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ from PIL import Image
8
+
9
+ from transformers import AutoProcessor
10
+ from vllm import LLM, SamplingParams
11
+
12
+
13
def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a base64 ``data:`` URL into a fully-loaded PIL image.

    Raises ValueError for anything that is not a data URL; base64 or
    image-decoding errors propagate from the underlying libraries.
    """
    if not (isinstance(data_url, str) and data_url.startswith("data:")):
        raise ValueError("Expected a data URL starting with 'data:'")
    # Everything after the first comma is the base64 payload; the header
    # (mime type / encoding) is ignored because PIL sniffs the format.
    _header, payload = data_url.split(",", 1)
    image = Image.open(io.BytesIO(base64.b64decode(payload)))
    image.load()  # force the full decode now so errors surface here, not later
    return image
21
+
22
+
23
class EndpointHandler:
    """HF Inference Endpoint handler for Qwen3-VL chat-to-point.

    Input:
        - { system, user, image(data URL) }
        - or legacy OpenAI-style messages with image_url + text

    Output:
        - { points: [{x,y}], raw: <string> }
          where x,y are normalized [0,1]
    """

    def __init__(self, path: str = "") -> None:
        """Configure the environment and load the processor.

        The vLLM engine itself is built lazily on first request (see
        ``_ensure_llm``) so container startup stays fast.

        Args:
            path: Unused; present for HF endpoint handler compatibility.
        """
        model_id = os.environ.get("MODEL_ID") or "Qwen/Qwen3-VL-8B-Instruct"

        # setdefault keeps any operator-supplied overrides intact.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_QUIC", "1")
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        # Auto TP detection from visible GPUs
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible and visible.strip():
            try:
                # "-1" means "no GPUs" in CUDA_VISIBLE_DEVICES; ignore it.
                candidates = [d for d in visible.split(",") if d.strip() and d.strip() != "-1"]
                tp = max(1, len(candidates))
            except Exception:
                tp = 1
        else:
            try:
                import torch  # local import to avoid global dependency if CPU-only
                tp = max(1, int(torch.cuda.device_count())) if torch.cuda.is_available() else 1
            except Exception:
                tp = 1

        self._model_id = model_id
        self._tp = tp
        self.llm = None  # type: ignore  # built lazily in _ensure_llm()

        # Accept the common token env-var spellings and mirror the winner
        # into HF_TOKEN so downstream hub calls pick it up.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        if hub_token and not os.environ.get("HF_TOKEN"):
            try:
                os.environ["HF_TOKEN"] = hub_token
            except Exception:
                pass

        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, token=hub_token)

    def _ensure_llm(self) -> None:
        """Build the vLLM engine on first use; no-op afterwards.

        NOTE(review): not thread-safe — concurrent first requests could
        build two engines. Assumes the endpoint serializes requests;
        confirm against the serving runtime.
        """
        if self.llm is not None:
            return
        self.llm = LLM(
            model=self._model_id,
            tensor_parallel_size=self._tp,
            pipeline_parallel_size=1,
            gpu_memory_utilization=0.95,
            dtype="auto",
            distributed_executor_backend="mp",
            enforce_eager=True,
            trust_remote_code=True,
        )

    @staticmethod
    def _parse_legacy_messages(messages: List[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Extract (system prompt, first user text, first image data URL)
        from OpenAI-style chat messages.

        Non-dict messages and non-dict content parts are skipped instead of
        raising, since this payload comes from untrusted callers.
        """
        system_prompt: Optional[str] = None
        first_image_data_url: Optional[str] = None
        first_text: Optional[str] = None
        for msg in messages:
            if not isinstance(msg, dict):
                # Malformed entry from the caller; ignore rather than 500.
                continue
            if msg.get("role") == "system" and system_prompt is None:
                content = msg.get("content")
                if isinstance(content, str):
                    system_prompt = content
            if msg.get("role") == "user":
                content = msg.get("content", [])
                if not isinstance(content, list):
                    continue
                for part in content:
                    if not isinstance(part, dict):
                        # Malformed content part; ignore rather than 500.
                        continue
                    if part.get("type") == "image_url" and not first_image_data_url:
                        url = part.get("image_url", {}).get("url")
                        if isinstance(url, str) and url.startswith("data:"):
                            first_image_data_url = url
                    if part.get("type") == "text" and not first_text:
                        t = part.get("text")
                        if isinstance(t, str):
                            first_text = t
        return system_prompt, first_text, first_image_data_url

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one request: decode the image, run the model, and return
        the first (x, y) point parsed from the output, normalized to [0,1].

        Errors are reported as {"error": ...} dicts rather than raised.
        """
        # Normalize HF toolkit payloads: the toolkit may wrap the real
        # payload under "inputs", possibly as a JSON string or bytes.
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                data = inputs_val
            elif isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        data = parsed
                except Exception:
                    pass  # fall through and treat the original dict as-is

        system_prompt: Optional[str] = None
        user_text: Optional[str] = None
        image_data_url: Optional[str] = None

        if isinstance(data, dict) and ("system" in data or "user" in data or "image" in data):
            # Preferred flat payload shape.
            system_prompt = data.get("system")
            user_text = data.get("user")
            image_data_url = data.get("image")
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "image must be a data URL (data:...)"}
        else:
            # Legacy OpenAI-style "messages" payload.
            messages = data.get("messages") if isinstance(data, dict) else None
            if not messages:
                return {"error": "Provide 'system','user','image' or legacy 'messages'"}
            system_prompt, user_text, image_data_url = self._parse_legacy_messages(messages)
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "messages.content image_url.url must be a data URL (data:...)"}

        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image: {e}"}

        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)
        if isinstance(width, int) and isinstance(height, int):
            try:
                print(f"[qwen3-vl-endpoint] Received image size: {width}x{height}")
            except Exception:
                pass  # logging must never fail the request

        if not isinstance(user_text, str):
            return {"error": "user text must be provided"}

        system_message = {"role": "system", "content": system_prompt or ""}
        user_message = {
            "role": "user",
            "content": [
                {"type": "image", "image": pil},
                {"type": "text", "text": user_text},
            ],
        }

        prompt = self.processor.apply_chat_template(
            [system_message, user_message], tokenize=False, add_generation_prompt=True
        )

        # vLLM takes the raw image separately from the templated prompt.
        request: Dict[str, Any] = {"prompt": prompt}
        request["multi_modal_data"] = {"image": [pil]}

        import time
        t0 = time.time()
        self._ensure_llm()
        # Greedy decoding: pointing answers should be deterministic and short.
        sampling_params = SamplingParams(max_tokens=32, temperature=0.0, top_p=1.0)
        outputs = self.llm.generate([request], sampling_params=sampling_params, use_tqdm=False)
        out_text = outputs[0].outputs[0].text
        t1 = time.time()

        try:
            print(f"[qwen3-vl-endpoint] Prompt: {user_text}")
            print(f"[qwen3-vl-endpoint] Raw output: {out_text}")
            print(f"[qwen3-vl-endpoint] Inference time: {t1 - t0:.3f}s")
        except Exception:
            pass  # logging must never fail the request

        try:
            import re
            # First "(x, y)" pair in the output; assumes the model answers
            # in pixel coordinates of the input image — TODO confirm.
            m = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", out_text)
            if not m:
                return {"error": "Failed to parse coordinates from model output."}
            x_str, y_str = m[0]
            px, py = float(x_str), float(y_str)
            if not isinstance(width, int) or not isinstance(height, int):
                return {"error": "Missing image dimensions for normalization."}
            w, h = float(width), float(height)
            # Clamp to the image bounds, then normalize to [0,1].
            px = max(0.0, min(px, w))
            py = max(0.0, min(py, h))
            nx, ny = px / w, py / h
            return {"points": [{"x": nx, "y": ny}], "raw": out_text}
        except Exception as e:
            return {"error": f"Postprocessing failed: {e}"}
212
+
213
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ Pillow
3
+ transformers==4.57.1
4
+ vllm==0.11.0
5
+