File size: 10,646 Bytes
98e1622
 
 
 
 
 
 
 
222ced7
98e1622
 
222ced7
98e1622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6edaebd
 
98e1622
 
 
9985c0d
 
 
98e1622
 
 
 
 
 
222ced7
 
cde6e20
222ced7
cde6e20
 
222ced7
 
9985c0d
 
222ced7
 
 
 
 
 
 
 
 
 
 
 
 
 
9985c0d
 
 
 
 
 
 
 
 
 
 
222ced7
9985c0d
 
222ced7
 
 
eccbc24
 
222ced7
 
98e1622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9d0613
98e1622
e9d0613
 
 
 
 
 
 
 
 
 
98e1622
e9d0613
 
 
 
 
98e1622
 
 
e9d0613
 
98e1622
e9d0613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98e1622
 
 
 
 
 
 
 
 
 
 
 
 
e9d0613
98e1622
 
 
 
 
 
 
 
 
 
 
 
 
 
222ced7
 
 
98e1622
e9d0613
 
9985c0d
222ced7
 
 
e9d0613
98e1622
 
 
 
 
 
 
 
 
 
 
 
 
e9d0613
 
 
 
 
 
 
582f8ae
e9d0613
 
 
582f8ae
e9d0613
582f8ae
 
 
 
 
 
 
 
 
e9d0613
582f8ae
 
 
 
 
 
98e1622
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import base64
import json
import io
import os
from typing import Any, Dict, List, Optional

from PIL import Image

from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
import re
from vllm import LLM, SamplingParams


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a base64 ``data:`` URL into a fully-loaded PIL image.

    Raises ValueError when *data_url* is not a string beginning with
    'data:' (or has no comma separating header from payload).
    """
    if not (isinstance(data_url, str) and data_url.startswith("data:")):
        raise ValueError("Expected a data URL starting with 'data:'")
    _header, payload = data_url.split(',', 1)
    decoded = base64.b64decode(payload)
    image = Image.open(io.BytesIO(decoded))
    # Force the lazy PIL loader to read pixel data now, so decode errors
    # surface here rather than at first pixel access.
    image.load()
    return image


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints (Qwen2.5-VL).

    Accepts either the new flat contract with 'system', 'user' and 'image'
    (a base64 data URL), or the legacy OpenAI-style contract:
    { "messages": [ { "role":"user", "content": [ {"type":"image_url","image_url":{"url":"data:..."}}, {"type":"text","text":"..."} ] } ] }

    Output: { raw: string, width?: number, height?: number }
    """

    def __init__(self, path: str = "") -> None:
        """Configure environment, pick the model, and load the HF processor.

        The vLLM engine itself is NOT created here — see _ensure_llm — so
        that container startup cannot fail on GPU/engine initialization.
        `path` is part of the Endpoints handler contract but unused here.
        """
        # Always default to 7B unless MODEL_ID explicitly overrides
        model_id = os.environ.get("MODEL_ID") or "HelloKKMe/GTA1-7B"

        # Keep CPU thread fan-out modest; let the CUDA allocator grow
        # segments instead of fragmenting.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        # Speed up first-time HF downloads and enable optimized transport
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_QUIC", "1")

        # Accept any of the common Hub-token env var spellings.
        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        # Ensure vLLM can pull gated repos if needed
        if hub_token and not os.environ.get("HF_TOKEN"):
            try:
                os.environ["HF_TOKEN"] = hub_token
            except Exception:
                pass

        # Auto-detect tensor parallel size from visible devices
        # Default to 'spawn' which is safest across managed environments
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible and visible.strip():
            try:
                # Count explicitly exposed devices; "-1" conventionally
                # means "no GPU", so it is filtered out.
                candidates = [d for d in visible.split(",") if d.strip() and d.strip() != "-1"]
                tp = max(1, len(candidates))
            except Exception:
                tp = 1
        else:
            try:
                import torch  # local import to avoid global requirement if not using CUDA
                tp = max(1, int(torch.cuda.device_count())) if torch.cuda.is_available() else 1
            except Exception:
                tp = 1

        # Defer vLLM engine init to first request to avoid startup failures
        self._model_id = model_id
        self._tp = tp
        self.llm = None  # type: ignore
        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True, token=hub_token
        )

    def _ensure_llm(self) -> None:
        """Lazily construct the vLLM engine on first use (idempotent)."""
        if self.llm is not None:
            return
        self.llm = LLM(
            model=self._model_id,
            tensor_parallel_size=self._tp,
            pipeline_parallel_size=1,
            gpu_memory_utilization=0.95,
            dtype="auto",
            distributed_executor_backend="mp",
            enforce_eager=True,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request.

        Normalizes the payload, builds chat messages with the image, runs
        vLLM generation, parses an "(x, y)" coordinate from the model
        output and returns it normalized to [0, 1]. All validation
        failures are reported as {"error": ...} dicts rather than raised.
        """
        # Normalize HF Endpoint payloads: the platform may wrap the real
        # body under "inputs", possibly as a JSON string or raw bytes.
        if isinstance(data, dict) and "inputs" in data:
            inputs_val = data.get("inputs")
            if isinstance(inputs_val, dict):
                data = inputs_val
            elif isinstance(inputs_val, (str, bytes, bytearray)):
                try:
                    if isinstance(inputs_val, (bytes, bytearray)):
                        inputs_val = inputs_val.decode("utf-8")
                    parsed = json.loads(inputs_val)
                    if isinstance(parsed, dict):
                        data = parsed
                except Exception:
                    # Best-effort unwrap; fall through with `data` unchanged.
                    pass

        # New input contract: expect 'system', 'user', and 'image' (data URL). Fallback to messages for compatibility.
        img_for_dims: Optional[Image.Image] = None
        system_prompt: Optional[str] = None
        user_text: Optional[str] = None
        image_data_url: Optional[str] = None

        if isinstance(data, dict) and ("system" in data or "user" in data or "image" in data):
            # --- New flat contract path ---
            system_prompt = data.get("system")
            user_text = data.get("user")
            image_data_url = data.get("image")
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "image must be a data URL (data:...)"}
            try:
                img_for_dims = _b64_to_pil(image_data_url)
            except Exception as e:
                return {"error": f"Failed to decode image: {e}"}
            messages = [
                {"role": "system", "content": system_prompt or ""},
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img_for_dims},
                        {"type": "text", "text": user_text or ""},
                    ],
                },
            ]
        else:
            # --- Legacy OpenAI-style 'messages' path ---
            messages = data.get("messages")
            if not messages:
                return {"error": "Provide 'system','user','image' or legacy 'messages'"}
            normalized: List[Dict[str, Any]] = []
            first_img: Optional[Image.Image] = None
            for msg in messages:
                # First system message (string content only) wins.
                if msg.get("role") == "system" and system_prompt is None:
                    system_prompt = msg.get("content") if isinstance(msg.get("content"), str) else None
                if msg.get("role") == "user":
                    content = msg.get("content", [])
                    image_url: Optional[str] = None
                    text_piece: Optional[str] = None
                    for part in content:
                        if part.get("type") == "image_url":
                            image_url = part.get("image_url", {}).get("url")
                        elif part.get("type") == "text":
                            text_piece = part.get("text")
                    if not image_url or not text_piece:
                        return {"error": "Content must include image_url (data URL) and text."}
                    if not isinstance(image_url, str) or not image_url.startswith("data:"):
                        return {"error": "image_url.url must be a data URL (data:...)"}
                    try:
                        img_for_dims = _b64_to_pil(image_url)
                        # Remember the first successfully decoded image; it is
                        # the one whose dimensions are used for rescaling.
                        first_img = first_img or img_for_dims
                    except Exception:
                        img_for_dims = None
                    # Keep the first user text seen across messages.
                    user_text = user_text or text_piece
                    normalized.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image_url},
                                {"type": "text", "text": text_piece},
                            ],
                        }
                    )
            messages = [{"role": "system", "content": system_prompt or ""}] + normalized
            if first_img is not None:
                img_for_dims = first_img

        # Dimensions of the original image, used later for clamping and
        # normalizing the predicted coordinates.
        width = getattr(img_for_dims, "width", None)
        height = getattr(img_for_dims, "height", None)
        if width and height:
            try:
                print(f"[gta1-endpoint] Received image size: {width}x{height}")
            except Exception:
                pass

        if not isinstance(img_for_dims, Image.Image) or not isinstance(user_text, str):
            return {"error": "Failed to prepare image/text for inference."}

        # Build system + user messages with the original image (no pre-resize)
        system_message = {"role": "system", "content": system_prompt or ""}
        user_message = {
            "role": "user",
            "content": [
                {"type": "image", "image": img_for_dims},
                {"type": "text", "text": user_text},
            ],
        }

        # qwen_vl_utils extracts the vision tensors/objects vLLM needs.
        image_inputs, video_inputs = process_vision_info([system_message, user_message])

        # Render the chat template to a plain prompt string for vLLM.
        text = self.processor.apply_chat_template(
            [system_message, user_message], tokenize=False, add_generation_prompt=True
        )

        request: Dict[str, Any] = {"prompt": text}
        if image_inputs:
            request["multi_modal_data"] = {"image": image_inputs}

        import time
        t_start = time.time()
        self._ensure_llm()
        # Greedy decoding; the expected output is a short "(x, y)" tuple.
        sampling_params = SamplingParams(max_tokens=32, temperature=0.0, top_p=1.0)
        outputs = self.llm.generate([request], sampling_params=sampling_params, use_tqdm=False)
        out_text = outputs[0].outputs[0].text
        t_infer = time.time() - t_start

        # Extract coordinates from model output and rescale to original image
        def _extract_xy(s: str):
            """Return the first "(x, y)" pair in *s* as floats, or None."""
            try:
                m = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", s)
                if not m:
                    return None
                x_str, y_str = m[0]
                return float(x_str), float(y_str)
            except Exception:
                return None

        pred = _extract_xy(out_text)
        # Log prompts and timings
        def _se(s: Optional[str], n: int = 120):
            """Return (first n chars, last n chars) of *s* for log snippets."""
            if not s:
                return ("", "")
            return (s[:n], s[-n:] if len(s) > n else s)
        sys_start, sys_end = _se(system_prompt)
        usr_start, usr_end = _se(user_text)
        try:
            print(f"[gta1-endpoint] System prompt (start): {sys_start}")
            print(f"[gta1-endpoint] System prompt (end): {sys_end}")
            print(f"[gta1-endpoint] User prompt (full): {user_text}")
            print(f"[gta1-endpoint] Raw output: {out_text}")
            print(f"[gta1-endpoint] Inference time: {t_infer:.3f}s")
        except Exception:
            pass

        if pred is None or not (width and height):
            return {"error": "Failed to parse coordinates or missing image dimensions."}

        # The model returns pixel coordinates on the input image; we did not pre-resize
        px = max(0.0, min(float(pred[0]), float(width)))
        py = max(0.0, min(float(pred[1]), float(height)))
        # Return normalized [0,1]
        nx = px / float(width)
        ny = py / float(height)
        return {
            "points": [{"x": nx, "y": ny}],
            "raw": out_text,
        }