import base64
import io
import json
import os
import re
import time
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

_LOG_PREFIX = "[qwen3-vl-endpoint]"

# Matches the first "(x, y)" pair in the model output. Compiled once at import
# time instead of on every request.
_POINT_RE = re.compile(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)")


def _log(message: str) -> None:
    """Print a prefixed diagnostic line without ever failing the request.

    Logging is best-effort: a closed or broken stdout, or an encoding error
    (``UnicodeEncodeError`` is a ``ValueError``), must not abort inference.
    """
    try:
        print(f"{_LOG_PREFIX} {message}")
    except (OSError, ValueError):
        pass


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a ``data:`` URL into a fully loaded PIL image.

    Args:
        data_url: String of the form ``data:<header>,<base64 payload>``.

    Returns:
        The decoded image, with pixel data already loaded so it no longer
        depends on the temporary in-memory buffer.

    Raises:
        ValueError: If ``data_url`` is not a string starting with ``data:``,
            has no ``,`` separator, or carries an invalid base64 payload.
        Exception: Whatever Pillow raises for unreadable image bytes.
    """
    if not isinstance(data_url, str) or not data_url.startswith("data:"):
        raise ValueError("Expected a data URL starting with 'data:'")
    _header, separator, payload = data_url.partition(",")
    if not separator:
        raise ValueError("Data URL is missing the ',' between header and payload")
    raw = base64.b64decode(payload)
    img = Image.open(io.BytesIO(raw))
    img.load()
    return img


class EndpointHandler:
    """HF Inference Endpoint handler for Qwen3-VL chat-to-point.

    Input (either form, optionally wrapped in ``{"inputs": ...}``):
      - ``{system, user, image}`` where ``image`` is a data URL
      - legacy OpenAI-style ``messages`` with ``image_url`` + ``text`` parts

    Output:
      - ``{points: [{x, y}], raw: <first 20 chars of model output>}`` where
        x, y are normalized to [0, 1]
      - ``{error: <message>}`` when the request cannot be served
    """

    def __init__(self, path: str = "") -> None:
        """Configure the environment and load the processor.

        The vLLM engine itself is created lazily on the first request (see
        ``_ensure_llm``).

        Args:
            path: Accepted for handler-signature compatibility; not used. The
                model is taken from the ``MODEL_ID`` environment variable,
                defaulting to ``Qwen/Qwen3-VL-8B-Instruct``.
        """
        model_id = os.environ.get("MODEL_ID") or "Qwen/Qwen3-VL-8B-Instruct"

        # setdefault: never override values the deployment already provides.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_QUIC", "1")
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        self._model_id = model_id
        self._tp = self._detect_tensor_parallel_size()
        self.llm = None  # type: ignore

        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        # Mirror an alternate token variable into HF_TOKEN when it is unset
        # or empty.
        if hub_token and not os.environ.get("HF_TOKEN"):
            os.environ["HF_TOKEN"] = hub_token

        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True, token=hub_token
        )

    @staticmethod
    def _detect_tensor_parallel_size() -> int:
        """Return the tensor-parallel size derived from the visible GPUs.

        Counts the entries of ``CUDA_VISIBLE_DEVICES`` (ignoring blanks and
        ``-1``) when it is set and non-blank; otherwise asks torch. Always
        returns at least 1.
        """
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible and visible.strip():
            devices = [
                entry
                for entry in visible.split(",")
                if entry.strip() and entry.strip() != "-1"
            ]
            return max(1, len(devices))
        try:
            import torch  # local import to avoid global dependency if CPU-only

            if torch.cuda.is_available():
                return max(1, int(torch.cuda.device_count()))
        except Exception:
            # Best-effort probe: a missing or broken torch/CUDA install must
            # not prevent the handler from starting.
            return 1
        return 1

    def _ensure_llm(self) -> None:
        """Create the vLLM engine on first use; later calls are no-ops."""
        if self.llm is not None:
            return
        self.llm = LLM(
            model=self._model_id,
            tensor_parallel_size=self._tp,
            pipeline_parallel_size=1,
            gpu_memory_utilization=0.95,
            max_model_len=8192,
            dtype="auto",
            distributed_executor_backend="mp",
            enforce_eager=True,
            trust_remote_code=True,
        )

    @staticmethod
    def _parse_legacy_messages(
        messages: List[Dict[str, Any]],
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Extract the prompt pieces from OpenAI-style chat messages.

        Malformed entries (non-dict messages or content parts) are skipped
        rather than raising, so bad client input surfaces as a missing field
        instead of an unhandled exception.

        Args:
            messages: Sequence of ``{"role": ..., "content": ...}`` messages.

        Returns:
            ``(system_prompt, first_text, first_image_data_url)``; each item
            is ``None`` when not found. The system prompt is the first
            ``system`` message with string content. Text and image come from
            ``user`` messages: the first non-empty text and the first
            ``image_url`` whose URL is a ``data:`` URL.
        """
        system_prompt: Optional[str] = None
        first_image_data_url: Optional[str] = None
        first_text: Optional[str] = None

        for msg in messages:
            if not isinstance(msg, dict):
                continue
            role = msg.get("role")

            if role == "system" and system_prompt is None:
                content = msg.get("content")
                if isinstance(content, str):
                    system_prompt = content

            if role != "user":
                continue
            content = msg.get("content", [])
            if isinstance(content, str):
                # Plain-string user content is a text-only message.
                if not first_text:
                    first_text = content
                continue
            if not isinstance(content, list):
                continue

            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if part_type == "image_url" and not first_image_data_url:
                    image_url = part.get("image_url", {})
                    # Accept both {"url": "..."} and a bare URL string.
                    url = image_url.get("url") if isinstance(image_url, dict) else image_url
                    if isinstance(url, str) and url.startswith("data:"):
                        first_image_data_url = url
                if part_type == "text" and not first_text:
                    text = part.get("text")
                    if isinstance(text, str):
                        first_text = text

        return system_prompt, first_text, first_image_data_url

    @staticmethod
    def _unwrap_inputs(data: Any) -> Any:
        """Unwrap the HF toolkit ``{"inputs": ...}`` envelope when present.

        ``inputs`` may be a dict, or a JSON document (str/bytes) encoding a
        dict. Anything else, or unparseable JSON, leaves ``data`` unchanged.
        """
        if not (isinstance(data, dict) and "inputs" in data):
            return data
        inner = data.get("inputs")
        if isinstance(inner, dict):
            return inner
        if isinstance(inner, (str, bytes, bytearray)):
            try:
                if isinstance(inner, (bytes, bytearray)):
                    inner = inner.decode("utf-8")
                parsed = json.loads(inner)
            except (ValueError, RecursionError):
                # Not valid UTF-8 / JSON: fall through with the original data.
                return data
            if isinstance(parsed, dict):
                return parsed
        return data

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one chat-to-point request.

        Args:
            data: Request payload; see the class docstring for accepted forms.

        Returns:
            ``{"points": [{"x": nx, "y": ny}], "raw": text}`` on success, with
            coordinates normalized by the image size and clamped to [0, 1];
            otherwise ``{"error": message}``. Exceptions raised by the
            processor or by vLLM itself are not caught here.
        """
        data = self._unwrap_inputs(data)

        system_prompt: Optional[str] = None
        user_text: Optional[str] = None
        image_data_url: Optional[str] = None

        if isinstance(data, dict) and any(key in data for key in ("system", "user", "image")):
            system_prompt = data.get("system")
            user_text = data.get("user")
            image_data_url = data.get("image")
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "image must be a data URL (data:...)"}
        else:
            messages = data.get("messages") if isinstance(data, dict) else None
            if not messages or not isinstance(messages, (list, tuple)):
                return {"error": "Provide 'system','user','image' or legacy 'messages'"}
            system_prompt, user_text, image_data_url = self._parse_legacy_messages(messages)
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "messages.content image_url.url must be a data URL (data:...)"}

        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            # Request boundary: Pillow raises many unrelated exception types
            # for corrupt input; report all of them as a decode failure.
            return {"error": f"Failed to decode image: {e}"}

        width, height = pil.size
        _log(f"Received image size: {width}x{height}")

        if not isinstance(user_text, str):
            return {"error": "user text must be provided"}

        system_message = {"role": "system", "content": system_prompt or ""}
        user_message = {
            "role": "user",
            "content": [
                {"type": "image", "image": pil},
                {"type": "text", "text": user_text},
            ],
        }
        prompt = self.processor.apply_chat_template(
            [system_message, user_message], tokenize=False, add_generation_prompt=True
        )
        request: Dict[str, Any] = {
            "prompt": prompt,
            "multi_modal_data": {"image": [pil]},
        }

        # The timer starts before _ensure_llm(), so the first request's
        # reported time includes engine start-up.
        started = time.perf_counter()
        self._ensure_llm()
        sampling_params = SamplingParams(max_tokens=16, temperature=0.0, top_p=1.0)
        outputs = self.llm.generate([request], sampling_params=sampling_params, use_tqdm=False)
        out_text = outputs[0].outputs[0].text
        out_text_short = out_text[:20]
        elapsed = time.perf_counter() - started

        _log(f"Prompt: {user_text}")
        _log(f"Raw output: {out_text_short}")
        _log(f"Inference time: {elapsed:.3f}s")

        match = _POINT_RE.search(out_text)
        if match is None:
            return {"error": "Failed to parse coordinates from model output."}
        px, py = float(match.group(1)), float(match.group(2))

        # Guard the divisions below explicitly instead of relying on a
        # catch-all around the arithmetic.
        if width <= 0 or height <= 0:
            return {"error": "Missing image dimensions for normalization."}
        w, h = float(width), float(height)

        # NOTE(review): the parsed values are treated as absolute pixel
        # coordinates of the submitted image - confirm this matches the
        # coordinate convention the deployed model/prompt actually emits.
        px = max(0.0, min(px, w))
        py = max(0.0, min(py, h))
        return {"points": [{"x": px / w, "y": py / h}], "raw": out_text_short}