File size: 11,237 Bytes
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5daee26
7c2acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import base64
import io
import json
import os
from typing import Any, Dict, List, Optional

from PIL import Image

import torch
from transformers import AutoModelForCausalLM


def _b64_to_pil(data_url: str) -> Image.Image:
    """Decode a ``data:`` URL carrying base64 image bytes into a PIL image.

    Args:
        data_url: String of the form ``data:<mime>;base64,<payload>``. The
            header is not inspected beyond the ``data:`` prefix; everything
            after the first comma is base64-decoded.

    Returns:
        The decoded image with its pixel data already loaded, so it no longer
        depends on the temporary in-memory buffer.

    Raises:
        ValueError: If ``data_url`` is not a string starting with ``data:``,
            has no ``,`` separating header from payload, or base64 decoding
            fails (e.g. incorrect padding; ``binascii.Error`` subclasses
            ``ValueError``).
        Pillow errors (e.g. ``UnidentifiedImageError``, an ``OSError``
        subclass) may also propagate if the bytes are not a readable image.
    """
    if not isinstance(data_url, str) or not data_url.startswith("data:"):
        raise ValueError("Expected a data URL starting with 'data:'")
    _header, sep, b64data = data_url.partition(",")
    if not sep:
        # Without this check a missing comma surfaces as an opaque
        # "not enough values to unpack" error that is echoed to API clients.
        raise ValueError("Malformed data URL: missing ',' between header and payload")
    raw = base64.b64decode(b64data)
    img = Image.open(io.BytesIO(raw))
    # Image.open is lazy; force decoding now so corrupt data fails here.
    img.load()
    return img


class EndpointHandler:
    """HF Inference Endpoint handler for Moondream3 Preview.

    Input contract (OpenAI-style):
    {
      "messages": [
        {
          "role": "user",
          "content": [
            { "type": "image_url", "image_url": { "url": "data:<mime>;base64,<...>" } },
            { "type": "text", "text": "<object or question>" }
          ]
        }
      ],
      "task": "point" | "detect" | "query"  // optional, default "point"
      "max_objects": <positive int>            // optional for detect
      "reasoning": <bool>                      // optional for query, default true
      "prioritize_accuracy": <bool>            // optional for point/detect, default true
    }

    The body may also arrive wrapped as {"inputs": {...}} or
    {"inputs": "<JSON string>"}. A message "content" given as a plain string is
    treated as a single text part, and "image_url" may be given either as
    {"url": ...} or directly as the URL string. Boolean flags also accept
    strings such as "false" / "0" / "no".

    When prioritize_accuracy is true, point/detect run the model twice: once on
    the image and once on its horizontal mirror. The mirrored results are
    mapped back and merged with the originals (test-time augmentation).

    Output:
    - task=="point": { points: [{x, y}], width, height, task }
    - task=="detect": { objects: [{x_min, y_min, x_max, y_max}], width, height, task }
    - task=="query":  { answer: "...", width?, height?, task }
    - any failure:    { error: "..." }
    Coordinates are normalized (0-1). width/height echo source image dims for convenience.
    """

    def __init__(self, path: str = "") -> None:
        """Load the Moondream model.

        Args:
            path: Repository path passed by the HF Inference Toolkit. Not used:
                the model is loaded from the ``MODEL_ID`` environment variable
                (default ``"moondream/moondream3-preview"``).
        """
        model_id = os.environ.get("MODEL_ID", "moondream/moondream3-preview")

        # NOTE(review): torch is already imported at module level when these
        # run; confirm each variable is still read late enough to take effect
        # (OMP_NUM_THREADS in particular may be read at import time).
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # MODEL_ID may be a local directory or a Hugging Face Hub repo id;
        # from_pretrained accepts either.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Optional compilation for speed if exposed by the model. Best effort:
        # a failure must not stop the endpoint from serving, but it is logged
        # so a silently slower endpoint can be diagnosed.
        try:
            compile_fn = getattr(self.model, "compile", None)
            if callable(compile_fn):
                compile_fn()
        except Exception as exc:
            print(f"[moondream-endpoint] model.compile() failed; continuing without it: {exc}")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request.

        Args:
            data: The request payload (see the class docstring for the contract).

        Returns:
            A JSON-serialisable dict with the task result, or ``{"error": ...}``.
            Problems are reported in the returned dict rather than raised.
        """
        data = self._unwrap_payload(data)
        if not isinstance(data, dict):
            return {"error": "Request body must be a JSON object."}

        messages = data.get("messages")
        task = str(data.get("task", "point")).lower()
        reasoning = self._as_bool(data.get("reasoning", True))
        max_objects = data.get("max_objects")
        prioritize_accuracy = self._as_bool(data.get("prioritize_accuracy", True))

        if not messages:
            return {"error": "Provide 'messages' with user image and text"}
        if not isinstance(messages, (list, tuple)):
            return {"error": "'messages' must be a list of message objects."}

        image_data_url, text_piece, error = self._extract_user_parts(messages)
        if error is not None:
            return {"error": error}
        if not isinstance(image_data_url, str) or not image_data_url.startswith("data:"):
            return {"error": "image_url.url must be a data URL (data:...)"}
        if not isinstance(text_piece, str) or not text_piece:
            return {"error": "Content must include text."}

        # Decode for dimensions and pass PIL to model
        try:
            pil = _b64_to_pil(image_data_url)
        except Exception as e:
            return {"error": f"Failed to decode image data URL: {e}"}

        width = getattr(pil, "width", None)
        height = getattr(pil, "height", None)
        if width and height:
            try:
                print(f"[moondream-endpoint] Received image size: {width}x{height}")
            except (OSError, ValueError):
                # Logging is best effort: a closed/broken stdout must not fail the request.
                pass

        # Validate max_objects up front so a bad value is reported as a request
        # error instead of being disguised as an inference failure.
        limit: Optional[int] = None
        if task == "detect" and max_objects:
            try:
                limit = int(max_objects)
            except (TypeError, ValueError, OverflowError):
                return {"error": "'max_objects' must be a positive integer."}
            if limit < 1:
                return {"error": "'max_objects' must be a positive integer."}

        out: Dict[str, Any]
        try:
            if task == "point":
                out = {"points": self._run_point(pil, text_piece, prioritize_accuracy)}
            elif task == "detect":
                out = {"objects": self._run_detect(pil, text_piece, limit, prioritize_accuracy)}
            elif task == "query":
                result = self.model.query(pil, question=text_piece, reasoning=reasoning, stream=False)
                out = {"answer": result.get("answer", "")}
            else:
                return {"error": f"Unsupported task '{task}'. Use 'point', 'detect', or 'query'."}
        except Exception as e:
            return {"error": f"Model inference failed: {e}"}

        if width and height:
            out.update({"width": width, "height": height})
        out.update({"task": task})
        return out

    @staticmethod
    def _unwrap_payload(data: Any) -> Any:
        """Unwrap the HF toolkit ``{"inputs": ...}`` envelope if present.

        ``inputs`` may be a dict, or a JSON string/bytes encoding a dict. Any
        other shape (or unparseable JSON) leaves ``data`` unchanged.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return data
        inputs_val = data.get("inputs")
        if isinstance(inputs_val, dict):
            return inputs_val
        if isinstance(inputs_val, (bytes, bytearray)):
            try:
                inputs_val = inputs_val.decode("utf-8")
            except UnicodeDecodeError:
                return data
        if isinstance(inputs_val, str):
            try:
                parsed = json.loads(inputs_val)
            except (ValueError, RecursionError):
                return data
            if isinstance(parsed, dict):
                return parsed
        return data

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a request flag; common false-y strings ("false", "0", "no", "off", "") are False."""
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off")
        return bool(value)

    @staticmethod
    def _extract_user_parts(messages: Any) -> tuple:
        """Return ``(image_url, text, error)`` from the first image and text parts.

        ``error`` is a message string if any scanned message is not a user
        message, else ``None``. Scanning stops after the first message that
        completes both parts.
        """
        image_data_url: Any = None
        text_piece: Any = None
        for msg in messages:
            if not isinstance(msg, dict) or msg.get("role") != "user":
                return None, None, "Only user messages are supported."
            content = msg.get("content") or []
            if isinstance(content, str):
                # OpenAI-style payloads may carry plain-string content.
                content = [{"type": "text", "text": content}]
            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if part_type == "image_url" and image_data_url is None:
                    image_field = part.get("image_url", {})
                    # Accept both {"image_url": {"url": ...}} and {"image_url": "..."}.
                    image_data_url = image_field.get("url") if isinstance(image_field, dict) else image_field
                elif part_type == "text" and text_piece is None:
                    text_piece = part.get("text")
            if image_data_url and text_piece:
                break
        return image_data_url, text_piece, None

    def _run_point(self, pil: Any, text: str, use_tta: bool) -> List[Dict[str, Any]]:
        """Run ``model.point``; with *use_tta*, also on the mirrored image and merge."""
        if not use_tta:
            return self.model.point(pil, text).get("points", [])
        flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
        res_orig = self.model.point(pil, text)
        res_flip = self.model.point(flipped, text)
        return self._tta_points(res_orig.get("points", []), res_flip.get("points", []))

    def _run_detect(
        self, pil: Any, text: str, max_objects: Optional[int], use_tta: bool
    ) -> List[Dict[str, Any]]:
        """Run ``model.detect``; with *use_tta*, also on the mirrored image, merge, and cap."""
        settings = {"max_objects": max_objects} if max_objects else None
        if not use_tta:
            return self.model.detect(pil, text, settings=settings).get("objects", [])
        flipped = pil.transpose(Image.FLIP_LEFT_RIGHT)
        res_orig = self.model.detect(pil, text, settings=settings)
        res_flip = self.model.detect(flipped, text, settings=settings)
        # The union of two predictions can exceed the requested count; cap it.
        return self._tta_boxes(
            res_orig.get("objects", []), res_flip.get("objects", []), max_objects=max_objects
        )

    @staticmethod
    def _flip_point(p: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized point horizontally (x -> 1 - x), clamping both coords to [0, 1]."""
        x = float(p.get("x", 0.0))
        y = float(p.get("y", 0.0))
        x = 1.0 - x
        return {"x": max(0.0, min(1.0, x)), "y": max(0.0, min(1.0, y))}

    @classmethod
    def _deduplicate_and_average_points(cls, points: List[Dict[str, Any]], tol: float = 0.03) -> List[Dict[str, float]]:
        """Greedily cluster points and return each cluster's running mean.

        Each point joins the first existing cluster whose current mean lies
        within Euclidean distance ``tol`` (same units as the coordinates),
        updating that mean; otherwise it starts a new cluster.
        """
        clusters: List[Dict[str, float]] = []
        counts: List[int] = []
        for p in points:
            px = float(p.get("x", 0.0))
            py = float(p.get("y", 0.0))
            for i, c in enumerate(clusters):
                dx = px - c["x"]
                dy = py - c["y"]
                if dx * dx + dy * dy <= tol * tol:
                    n = counts[i]
                    c["x"] = (c["x"] * n + px) / (n + 1)
                    c["y"] = (c["y"] * n + py) / (n + 1)
                    counts[i] = n + 1
                    break
            else:
                clusters.append({"x": px, "y": py})
                counts.append(1)
        return clusters

    @classmethod
    def _tta_points(cls, points_a: List[Dict[str, Any]], points_b_flipped: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        """Merge points from the original image with points from its horizontal mirror."""
        # Convert flipped prediction back to original frame
        unflipped_b = [cls._flip_point(p) for p in points_b_flipped]
        merged = list(points_a) + unflipped_b
        return cls._deduplicate_and_average_points(merged)

    @staticmethod
    def _flip_box(b: Dict[str, Any]) -> Dict[str, float]:
        """Mirror a normalized box horizontally, clamping coords to [0, 1] and keeping x_min <= x_max."""
        xmin = float(b.get("x_min", 0.0))
        xmax = float(b.get("x_max", 0.0))
        ymin = float(b.get("y_min", 0.0))
        ymax = float(b.get("y_max", 0.0))
        nxmin = 1.0 - xmax
        nxmax = 1.0 - xmin
        nxmin, nxmax = max(0.0, min(1.0, nxmin)), max(0.0, min(1.0, nxmax))
        ymin, ymax = max(0.0, min(1.0, ymin)), max(0.0, min(1.0, ymax))
        if nxmin > nxmax:
            nxmin, nxmax = nxmax, nxmin
        return {"x_min": nxmin, "y_min": ymin, "x_max": nxmax, "y_max": ymax}

    @staticmethod
    def _iou(b1: Dict[str, float], b2: Dict[str, float]) -> float:
        """Intersection-over-union of two axis-aligned boxes; 0.0 when the union area is 0."""
        x1 = max(b1["x_min"], b2["x_min"])
        y1 = max(b1["y_min"], b2["y_min"])
        x2 = min(b1["x_max"], b2["x_max"])
        y2 = min(b1["y_max"], b2["y_max"])
        inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        a1 = max(0.0, b1["x_max"] - b1["x_min"]) * max(0.0, b1["y_max"] - b1["y_min"])
        a2 = max(0.0, b2["x_max"] - b2["x_min"]) * max(0.0, b2["y_max"] - b2["y_min"])
        denom = a1 + a2 - inter
        return inter / denom if denom > 0 else 0.0

    @classmethod
    def _merge_boxes_with_nms(cls, boxes: List[Dict[str, float]], iou_threshold: float = 0.5) -> List[Dict[str, float]]:
        """Greedily cluster overlapping boxes and replace each cluster by its mean.

        Despite the name this is not score-based NMS (no scores are available).
        Box ``i`` seeds a cluster; every later, unclaimed box whose IoU with the
        seed is >= ``iou_threshold`` joins it. The cluster's coordinates are then
        averaged and clamped to [0, 1]. Output order follows the seed order.
        """
        merged: List[Dict[str, float]] = []
        used = [False] * len(boxes)
        for i, seed in enumerate(boxes):
            if used[i]:
                continue
            used[i] = True
            cluster = [seed]
            for j in range(i + 1, len(boxes)):
                if not used[j] and cls._iou(seed, boxes[j]) >= iou_threshold:
                    used[j] = True
                    cluster.append(boxes[j])
            n = float(len(cluster))
            merged.append({
                key: max(0.0, min(1.0, sum(b[key] for b in cluster) / n))
                for key in ("x_min", "y_min", "x_max", "y_max")
            })
        return merged

    @classmethod
    def _tta_boxes(
        cls,
        boxes_a: List[Dict[str, Any]],
        boxes_b_flipped: List[Dict[str, Any]],
        max_objects: Optional[int] = None,
    ) -> List[Dict[str, float]]:
        """Merge boxes from the original image with boxes from its horizontal mirror.

        Args:
            boxes_a: Boxes predicted on the original image.
            boxes_b_flipped: Boxes predicted on the mirrored image (mirrored frame).
            max_objects: If a positive int, keep at most this many merged boxes.
                Without a cap, the union of two predictions can return up to
                twice the count requested from the model. No scores are
                available, so truncation keeps clusters in seed order
                (clusters seeded by ``boxes_a`` come first).

        Returns:
            Clustered, averaged boxes with coordinates clamped to [0, 1].
        """
        unflipped_b = [cls._flip_box(b) for b in boxes_b_flipped]
        combined = [
            {key: float(b.get(key, 0.0)) for key in ("x_min", "y_min", "x_max", "y_max")}
            for b in [*boxes_a, *unflipped_b]
        ]
        merged = cls._merge_boxes_with_nms(combined, iou_threshold=0.5)
        if max_objects is not None and max_objects > 0:
            merged = merged[:max_objects]
        return merged