File size: 8,414 Bytes
92c5d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e69d45
92c5d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e69d45
92c5d3d
 
1e69d45
92c5d3d
 
 
 
1e69d45
92c5d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e69d45
92c5d3d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import base64
import io
import json
import os
import re
import time
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from transformers import AutoProcessor
from vllm import LLM, SamplingParams


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a base64 ``data:`` URL into a fully loaded PIL image.

    Args:
        data_url: A URL of the form ``data:[<mediatype>][;base64],<payload>``.
            Whatever the header says, the payload is decoded as base64.

    Returns:
        The decoded image, with its pixel data already read.

    Raises:
        ValueError: if ``data_url`` is not a string starting with ``data:``,
            has no ``,`` separating header and payload, decodes to zero
            bytes, or is not valid base64 (``binascii.Error`` subclasses
            ``ValueError``).
        OSError: if the bytes are not an image PIL can read
            (``PIL.UnidentifiedImageError`` subclasses ``OSError``).
    """
    if not isinstance(data_url, str) or not data_url.startswith("data:"):
        raise ValueError("Expected a data URL starting with 'data:'")
    # The header (media type / ";base64" marker) is not needed for decoding.
    _header, sep, b64data = data_url.partition(",")
    if not sep:
        raise ValueError("Malformed data URL: missing ',' before the payload")
    raw = base64.b64decode(b64data)
    if not raw:
        raise ValueError("Data URL payload is empty")
    img = Image.open(io.BytesIO(raw))
    # Image.open is lazy; force decoding now so corrupt or truncated data
    # fails here instead of later, when the pixels are first used.
    img.load()
    return img


class EndpointHandler:
    """HF Inference Endpoint handler for Qwen3-VL chat-to-point.

    Input (optionally wrapped by the HF toolkit as ``{"inputs": {...}}`` or
    ``{"inputs": "<JSON object string>"}``):
      - { system, user, image(data URL) }
      - or legacy OpenAI-style messages with image_url + text

    Output:
      - { points: [{x,y}], raw: <string> }
        where x,y are normalized [0,1] and raw is the first 20 characters
        of the model output
      - or { error: <string> } when the request cannot be served
    """

    # First "(x, y)" pair in the model output. Each number may be signed and
    # written as "5", "5.", "5.25" or ".25"; whitespace is allowed inside
    # the parentheses.
    _POINT_RE = re.compile(
        r"\(\s*(-?(?:\d+(?:\.\d*)?|\.\d+))\s*,\s*(-?(?:\d+(?:\.\d*)?|\.\d+))\s*\)"
    )

    def __init__(self, path: str = "") -> None:
        """Set runtime environment defaults and load the chat processor.

        The vLLM engine is not built here; ``_ensure_llm`` creates it on the
        first request.

        Args:
            path: Supplied by the HF toolkit; not used. The model is chosen by
                the ``MODEL_ID`` environment variable (default
                ``Qwen/Qwen3-VL-8B-Instruct``).
        """
        model_id = os.environ.get("MODEL_ID") or "Qwen/Qwen3-VL-8B-Instruct"

        # setdefault: values already configured on the deployment win.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_QUIC", "1")
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        self._model_id = model_id
        self._tp = self._detect_tensor_parallel_size()
        self.llm = None  # type: ignore

        hub_token = (
            os.environ.get("HUGGINGFACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HF_TOKEN")
        )
        if hub_token and not os.environ.get("HF_TOKEN"):
            # Mirror the token into HF_TOKEN; presumably so the lazily built
            # vLLM engine (which is not given ``token``) can also authenticate.
            try:
                os.environ["HF_TOKEN"] = hub_token
            except Exception:
                pass  # best effort: the processor below still receives the token

        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, token=hub_token)

    @staticmethod
    def _detect_tensor_parallel_size() -> int:
        """Return the tensor-parallel size: one shard per visible GPU, at least 1.

        A non-blank ``CUDA_VISIBLE_DEVICES`` takes precedence; empty entries
        and ``-1`` in it are not counted. Otherwise torch is asked for the
        CUDA device count. Any failure along that path falls back to 1.
        """
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible and visible.strip():
            candidates = [d for d in visible.split(",") if d.strip() and d.strip() != "-1"]
            return max(1, len(candidates))
        try:
            import torch  # local import to avoid global dependency if CPU-only

            return max(1, int(torch.cuda.device_count())) if torch.cuda.is_available() else 1
        except Exception:
            return 1

    def _ensure_llm(self) -> None:
        """Build the vLLM engine on first use; later calls return immediately."""
        if self.llm is not None:
            return
        self.llm = LLM(
            model=self._model_id,
            tensor_parallel_size=self._tp,
            pipeline_parallel_size=1,
            gpu_memory_utilization=0.95,
            max_model_len=8192,
            dtype="auto",
            distributed_executor_backend="mp",
            enforce_eager=True,  # skip CUDA-graph capture
            trust_remote_code=True,
        )

    @staticmethod
    def _parse_legacy_messages(messages: List[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Extract the system prompt, user text and image URL from chat messages.

        Picks the first string ``system`` content, and from ``user`` messages
        the first text and the first ``image_url`` whose URL starts with
        ``data:``. User ``content`` may be a list of parts or a plain string
        (treated as text). ``image_url`` may be ``{"url": ...}`` or the URL
        string itself. Entries that are not dicts are skipped rather than
        raising.

        Returns:
            ``(system_prompt, first_text, first_image_data_url)``; each item is
            ``None`` when not found.
        """
        system_prompt: Optional[str] = None
        first_image_data_url: Optional[str] = None
        first_text: Optional[str] = None
        if not isinstance(messages, (list, tuple)):
            return system_prompt, first_text, first_image_data_url

        for msg in messages:
            if not isinstance(msg, dict):
                continue
            role = msg.get("role")
            content = msg.get("content")
            if role == "system":
                if system_prompt is None and isinstance(content, str):
                    system_prompt = content
                continue
            if role != "user":
                continue
            if isinstance(content, str):
                if not first_text:
                    first_text = content
                continue
            if not isinstance(content, list):
                continue
            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if part_type == "image_url" and not first_image_data_url:
                    image_url = part.get("image_url")
                    url = image_url.get("url") if isinstance(image_url, dict) else image_url
                    if isinstance(url, str) and url.startswith("data:"):
                        first_image_data_url = url
                if part_type == "text" and not first_text:
                    text = part.get("text")
                    if isinstance(text, str):
                        first_text = text
        return system_prompt, first_text, first_image_data_url

    @staticmethod
    def _unwrap_inputs(data: Any) -> Any:
        """Unwrap the HF toolkit ``{"inputs": ...}`` envelope when present.

        ``inputs`` may be a dict, or a JSON object as str/bytes (bytes are
        decoded as UTF-8). In every other case, including invalid JSON or
        JSON that is not an object, ``data`` is returned unchanged.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return data
        inner = data.get("inputs")
        if isinstance(inner, dict):
            return inner
        if isinstance(inner, (str, bytes, bytearray)):
            try:
                text = inner.decode("utf-8") if isinstance(inner, (bytes, bytearray)) else inner
                parsed = json.loads(text)
            except (ValueError, RecursionError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
                return data
            if isinstance(parsed, dict):
                return parsed
        return data

    @classmethod
    def _parse_point(cls, text: str, width: Any, height: Any) -> Dict[str, Any]:
        """Parse the first "(x, y)" in ``text`` and normalize it by the image size.

        x is clamped to [0, width] and y to [0, height] before dividing, so
        both results lie in [0, 1].

        NOTE(review): this assumes the model emits pixel coordinates in the
        original image's frame (as the system prompt presumably requests).
        If it emits coordinates for the processor-resized image or on a
        0-1000 scale, this normalization is wrong. TODO confirm.

        Returns:
            ``{"points": [{"x": nx, "y": ny}]}`` on success, otherwise
            ``{"error": <message>}``.
        """
        match = cls._POINT_RE.search(text)
        if match is None:
            return {"error": "Failed to parse coordinates from model output."}
        px, py = float(match.group(1)), float(match.group(2))
        if not isinstance(width, int) or not isinstance(height, int):
            return {"error": "Missing image dimensions for normalization."}
        w, h = float(width), float(height)
        px = max(0.0, min(px, w))
        py = max(0.0, min(py, h))
        return {"points": [{"x": px / w, "y": py / h}]}

    @staticmethod
    def _log(message: str) -> None:
        """Print ``message`` with the endpoint prefix; logging must never fail a request."""
        try:
            print(f"[qwen3-vl-endpoint] {message}")
        except Exception:
            pass

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Answer one chat-to-point request.

        Args:
            data: Request payload; see the class docstring for accepted shapes.

        Returns:
            ``{"points": [...], "raw": ...}`` on success or ``{"error": ...}``
            for invalid input or unparseable output. Exceptions raised while
            building the prompt or by the vLLM engine are not caught here.
        """
        data = self._unwrap_inputs(data)

        system_prompt: Optional[str] = None
        user_text: Optional[str] = None
        image_data_url: Optional[str] = None

        if isinstance(data, dict) and ("system" in data or "user" in data or "image" in data):
            system_prompt = data.get("system")
            user_text = data.get("user")
            image_data_url = data.get("image")
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "image must be a data URL (data:...)"}
        else:
            messages = data.get("messages") if isinstance(data, dict) else None
            if not messages:
                return {"error": "Provide 'system','user','image' or legacy 'messages'"}
            system_prompt, user_text, image_data_url = self._parse_legacy_messages(messages)
            if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
                return {"error": "messages.content image_url.url must be a data URL (data:...)"}

        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image: {e}"}

        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)
        if isinstance(width, int) and isinstance(height, int):
            self._log(f"Received image size: {width}x{height}")

        if not isinstance(user_text, str):
            return {"error": "user text must be provided"}

        chat = [
            {"role": "system", "content": system_prompt or ""},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil},
                    {"type": "text", "text": user_text},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        request: Dict[str, Any] = {"prompt": prompt, "multi_modal_data": {"image": [pil]}}

        self._ensure_llm()
        # Timer starts after engine construction (first request only) so the
        # logged figure reflects generation, not model loading.
        t0 = time.perf_counter()
        # Greedy decoding. 16 tokens is only enough for a short "(x, y)" reply.
        sampling_params = SamplingParams(max_tokens=16, temperature=0.0, top_p=1.0)
        outputs = self.llm.generate([request], sampling_params=sampling_params, use_tqdm=False)
        out_text = outputs[0].outputs[0].text
        out_text_short = out_text[:20]
        elapsed = time.perf_counter() - t0

        self._log(f"Prompt: {user_text}")
        self._log(f"Raw output: {out_text_short}")
        self._log(f"Inference time: {elapsed:.3f}s")

        try:
            result = self._parse_point(out_text, width, height)
        except Exception as e:
            return {"error": f"Postprocessing failed: {e}"}
        if "points" in result:
            result["raw"] = out_text_short
        return result