# gta1-endpoint / handler.py
# Author: Forrest Wargo — commit 6edaebd ("removing 32b")
import base64
import io
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a ``data:`` URL into a fully-loaded PIL image.

    Raises ValueError when *data_url* is not a string starting with 'data:'.
    """
    is_data_url = isinstance(data_url, str) and data_url.startswith("data:")
    if not is_data_url:
        raise ValueError("Expected a data URL starting with 'data:'")
    _header, payload = data_url.split(",", 1)
    decoded = base64.b64decode(payload)
    image = Image.open(io.BytesIO(decoded))
    # Force decoding now so corrupt payloads fail here rather than later.
    image.load()
    return image
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints (Qwen2.5-VL / GTA1).

    Preferred input contract::

        {"system": "...", "user": "...", "image": "data:image/...;base64,..."}

    Legacy (OpenAI-style) contract::

        {"messages": [{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "data:..."}},
            {"type": "text", "text": "..."}]}]}

    Output: ``{"points": [{"x": nx, "y": ny}], "raw": "<model text>"}`` with
    coordinates normalized to [0, 1], or ``{"error": "..."}`` on bad input.
    """

    def __init__(self, path: str = "") -> None:
        """Configure the environment and load the processor; vLLM init is lazy.

        :param path: model path supplied by the endpoint runtime (unused; the
            model id comes from the MODEL_ID env var or the 7B default).
        """
        # Always default to 7B unless MODEL_ID explicitly overrides.
        model_id = os.environ.get("MODEL_ID") or "HelloKKMe/GTA1-7B"
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        # Speed up first-time HF downloads and enable optimized transport.
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_QUIC", "1")
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        # Ensure vLLM can pull gated repos if needed.
        if hub_token and not os.environ.get("HF_TOKEN"):
            os.environ["HF_TOKEN"] = hub_token
        # 'spawn' is the safest worker start method across managed environments.
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
        self._model_id = model_id
        self._tp = self._detect_tensor_parallel_size()
        # Defer vLLM engine init to the first request to avoid startup failures.
        self.llm = None  # type: ignore
        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True, token=hub_token
        )

    @staticmethod
    def _detect_tensor_parallel_size() -> int:
        """Best-effort GPU count from CUDA_VISIBLE_DEVICES, then torch; default 1."""
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible and visible.strip():
            try:
                devices = [d for d in visible.split(",") if d.strip() and d.strip() != "-1"]
                return max(1, len(devices))
            except Exception:
                return 1
        try:
            import torch  # local import to avoid a hard requirement without CUDA

            return max(1, int(torch.cuda.device_count())) if torch.cuda.is_available() else 1
        except Exception:
            return 1

    def _ensure_llm(self) -> None:
        """Lazily construct the vLLM engine on first use (idempotent)."""
        if self.llm is not None:
            return
        self.llm = LLM(
            model=self._model_id,
            tensor_parallel_size=self._tp,
            pipeline_parallel_size=1,
            gpu_memory_utilization=0.95,
            dtype="auto",
            distributed_executor_backend="mp",
            enforce_eager=True,
            trust_remote_code=True,
        )

    @staticmethod
    def _unwrap_inputs(data: Any) -> Any:
        """Unwrap the HF endpoint ``{"inputs": ...}`` envelope.

        Accepts a nested dict or a JSON-encoded str/bytes payload; returns
        the original object unchanged when unwrapping is not possible.
        """
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                return inputs_val
            if isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        return parsed
                except Exception:
                    # Best effort: fall through and let validation report the problem.
                    pass
        return data

    @staticmethod
    def _extract_xy(s: str) -> Optional[tuple]:
        """Return the first '(x, y)' coordinate pair in *s* as floats, or None."""
        try:
            matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", s)
            if not matches:
                return None
            x_str, y_str = matches[0]
            return float(x_str), float(y_str)
        except Exception:
            return None

    @staticmethod
    def _log_request(
        system_prompt: Optional[str],
        user_text: Optional[str],
        out_text: str,
        t_infer: float,
    ) -> None:
        """Best-effort console logging; never let logging break a request."""

        def truncate(s: Optional[str], n: int = 120):
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)

        sys_start, sys_end = truncate(system_prompt)
        try:
            print(f"[gta1-endpoint] System prompt (start): {sys_start}")
            print(f"[gta1-endpoint] System prompt (end): {sys_end}")
            print(f"[gta1-endpoint] User prompt (full): {user_text}")
            print(f"[gta1-endpoint] Raw output: {out_text}")
            print(f"[gta1-endpoint] Inference time: {t_infer:.3f}s")
        except Exception:
            pass

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one grounding request.

        :param data: request payload in either supported contract (possibly
            wrapped in an ``inputs`` envelope).
        :returns: ``{"points": [...], "raw": ...}`` on success, else
            ``{"error": ...}``.
        """
        data = self._unwrap_inputs(data)
        img_for_dims: Optional[Image.Image] = None
        system_prompt: Optional[str] = None
        user_text: Optional[str] = None
        if isinstance(data, dict) and ("system" in data or "user" in data or "image" in data):
            # New contract: 'system', 'user', and 'image' (data URL).
            system_prompt = data.get("system")
            user_text = data.get("user")
            image_data_url = data.get("image")
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "image must be a data URL (data:...)"}
            try:
                img_for_dims = _b64_to_pil(image_data_url)
            except Exception as e:
                return {"error": f"Failed to decode image: {e}"}
        else:
            # Legacy OpenAI-style 'messages' payload.
            messages = data.get("messages") if isinstance(data, dict) else None
            if not messages:
                return {"error": "Provide 'system','user','image' or legacy 'messages'"}
            first_img: Optional[Image.Image] = None
            for msg in messages:
                if msg.get("role") == "system" and system_prompt is None:
                    content = msg.get("content")
                    system_prompt = content if isinstance(content, str) else None
                if msg.get("role") != "user":
                    continue
                image_url: Optional[str] = None
                text_piece: Optional[str] = None
                for part in msg.get("content", []):
                    if part.get("type") == "image_url":
                        image_url = part.get("image_url", {}).get("url")
                    elif part.get("type") == "text":
                        text_piece = part.get("text")
                if not image_url or not text_piece:
                    return {"error": "Content must include image_url (data URL) and text."}
                if not isinstance(image_url, str) or not image_url.startswith("data:"):
                    return {"error": "image_url.url must be a data URL (data:...)"}
                try:
                    img_for_dims = _b64_to_pil(image_url)
                    first_img = first_img or img_for_dims
                except Exception:
                    # Tolerate an undecodable image here; validated again below.
                    img_for_dims = None
                user_text = user_text or text_piece
            # Ground on the first successfully decoded image.
            if first_img is not None:
                img_for_dims = first_img
        width = getattr(img_for_dims, "width", None)
        height = getattr(img_for_dims, "height", None)
        if width and height:
            try:
                print(f"[gta1-endpoint] Received image size: {width}x{height}")
            except Exception:
                pass
        if not isinstance(img_for_dims, Image.Image) or not isinstance(user_text, str):
            return {"error": "Failed to prepare image/text for inference."}
        # Build system + user messages with the original image (no pre-resize).
        system_message = {"role": "system", "content": system_prompt or ""}
        user_message = {
            "role": "user",
            "content": [
                {"type": "image", "image": img_for_dims},
                {"type": "text", "text": user_text},
            ],
        }
        image_inputs, _video_inputs = process_vision_info([system_message, user_message])
        text = self.processor.apply_chat_template(
            [system_message, user_message], tokenize=False, add_generation_prompt=True
        )
        request: Dict[str, Any] = {"prompt": text}
        if image_inputs:
            request["multi_modal_data"] = {"image": image_inputs}
        t_start = time.time()
        self._ensure_llm()
        # Greedy decoding: coordinate grounding needs determinism, not sampling.
        sampling_params = SamplingParams(max_tokens=32, temperature=0.0, top_p=1.0)
        outputs = self.llm.generate([request], sampling_params=sampling_params, use_tqdm=False)
        out_text = outputs[0].outputs[0].text
        t_infer = time.time() - t_start
        pred = self._extract_xy(out_text)
        self._log_request(system_prompt, user_text, out_text, t_infer)
        if pred is None or not (width and height):
            return {"error": "Failed to parse coordinates or missing image dimensions."}
        # The model returns pixel coordinates on the input image; clamp to the
        # image bounds, then normalize to [0, 1] (we did not pre-resize).
        px = max(0.0, min(float(pred[0]), float(width)))
        py = max(0.0, min(float(pred[1]), float(height)))
        return {
            "points": [{"x": px / float(width), "y": py / float(height)}],
            "raw": out_text,
        }