File size: 11,003 Bytes
1a91148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77552b
 
 
 
 
 
 
1a91148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e77552b
1a91148
e77552b
 
1a91148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
"""
Inference module for LearningStudio Callout Detection wrapper.

This module:
1. Normalizes input to bytes (handles URLs, data URLs, raw base64)
2. Gets presigned S3 URL from API Gateway
3. Uploads image directly to S3 (bypasses API Gateway for large payloads)
4. Calls API Gateway to start detection job
5. Polls for completion
6. Transforms callouts to EMCO format
"""

import os
import base64
import time
import logging
from typing import Dict, Any, List, Optional, Tuple

import requests

# Module-wide logging: INFO level by default. (logging.basicConfig does nothing
# if the root logger already has handlers configured.)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoint and credentials, read once at import time from the
# environment (configured as HF Inference Endpoint secrets). Empty string
# means "not configured".
API_GATEWAY_URL = os.environ.get("API_GATEWAY_URL", "")
API_KEY = os.environ.get("API_KEY", "")

# Status polling: overall wait budget and the delay between polls.
MAX_WAIT_SECONDS = 15 * 60  # 900 s
POLL_INTERVAL_SECONDS = 5

# Processing parameters sent with every detection job; LearningStudio does
# not supply its own, so these are the effective settings.
DEFAULT_PARAMS = dict(
    tiling=dict(tile=2048, overlap=0.30),
    floodfill=dict(erase_text=False, min_fill_vs_text=0.0),
    preclean=dict(denoise_sw=8),
)


def normalize_to_bytes(image_input: str) -> Tuple[bytes, str]:
    """
    Normalize image input to raw bytes plus a filename.

    Handles:
    - HTTP/HTTPS URLs: downloads the image (60 s timeout). The filename is
      the last segment of the URL path, or "image.png" if that is empty.
    - Data URLs (data:image/png;base64,...): base64-decodes the payload.
      The filename is "image.<subtype>" from the MIME type, or "image.png"
      when the MIME type is absent or has no subtype.
    - Anything else: treated as raw base64 and named "image.png".

    Leading/trailing whitespace is ignored and the "http(s)://" / "data:"
    prefixes are matched case-insensitively (otherwise e.g. "HTTPS://..."
    would fall through to the base64 branch). Base64 is decoded strictly:
    whitespace is removed first, then any character outside the standard
    base64 alphabet is an error instead of being silently dropped (dropping
    characters shifts the bit stream and corrupts the decoded image).

    Args:
        image_input: Image URL, data URL, or base64 string

    Returns:
        Tuple of (image_bytes, filename)

    Raises:
        ValueError: If the input is not a non-empty string, a data URL has no
            "," separator, the base64 payload is invalid, or the input
            yields zero bytes.
        requests.exceptions.RequestException: If downloading a URL fails.
    """
    from urllib.parse import urlparse

    def _strict_b64decode(text: str) -> bytes:
        # Line-wrapped base64 is common, so strip all whitespace, then reject
        # anything outside the base64 alphabet.
        return base64.b64decode("".join(text.split()), validate=True)

    if not isinstance(image_input, str) or not image_input.strip():
        raise ValueError("Image input must be a non-empty string")
    image_input = image_input.strip()
    prefix = image_input[:8].lower()  # long enough for "https://"

    if prefix.startswith(("http://", "https://")):
        logger.info(f"Downloading image from URL: {image_input[:100]}...")
        response = requests.get(image_input, timeout=60)
        response.raise_for_status()
        filename = os.path.basename(urlparse(image_input).path) or "image.png"
        image_bytes = response.content
    elif prefix.startswith("data:"):
        # Layout: data:<mime>[;<param>...],<payload>
        header, sep, encoded = image_input.partition(",")
        if not sep:
            raise ValueError("Invalid data URL format")
        mime_part = header[len("data:"):].split(";")[0]
        subtype = mime_part.split("/")[-1] if "/" in mime_part else ""
        filename = f"image.{subtype or 'png'}"
        try:
            image_bytes = _strict_b64decode(encoded)
        except ValueError as e:
            raise ValueError("Invalid data URL format") from e
    else:
        try:
            image_bytes = _strict_b64decode(image_input)
        except ValueError as e:
            raise ValueError(f"Invalid base64 string: {e}") from e
        filename = "image.png"

    if not image_bytes:
        raise ValueError("Image input decoded to zero bytes")
    return image_bytes, filename


def get_upload_url(filename: str = "image.png") -> Dict[str, str]:
    """
    Request a presigned S3 upload URL from API Gateway (GET /upload-url).

    Args:
        filename: Original filename for the image; sent as the "filename"
            query parameter.

    Returns:
        The decoded JSON response, verified to contain at least the keys
        job_id, upload_url and s3_url.

    Raises:
        ValueError: If API_GATEWAY_URL or API_KEY is not configured, or the
            response body is not a JSON object containing the required keys.
        requests.exceptions.RequestException: On network failure or a
            non-2xx HTTP status.
    """
    if not API_GATEWAY_URL or not API_KEY:
        raise ValueError(
            "API_GATEWAY_URL and API_KEY must be set in environment variables. "
            "Configure these in your HF Inference Endpoint secrets."
        )

    url = f"{API_GATEWAY_URL.rstrip('/')}/upload-url"
    headers = {"x-api-key": API_KEY}
    params = {"filename": filename}

    logger.info(f"Getting upload URL from {url}")
    response = requests.get(url, headers=headers, params=params, timeout=30)
    response.raise_for_status()

    result = response.json()

    # Validate the contract here so a malformed reply produces a clear
    # message instead of a bare KeyError when the fields are indexed later.
    if not isinstance(result, dict):
        raise ValueError(
            "Unexpected upload-url response: expected a JSON object, "
            f"got {type(result).__name__}"
        )
    missing = [key for key in ("job_id", "upload_url", "s3_url") if key not in result]
    if missing:
        raise ValueError(
            f"Upload-url response missing required field(s): {', '.join(missing)}"
        )

    logger.info(f"Got upload URL for job_id={result.get('job_id')}")
    return result


def upload_to_s3(upload_url: str, image_bytes: bytes, content_type: str = "image/png") -> None:
    """
    Upload image bytes directly to S3 with an HTTP PUT to a presigned URL.

    Args:
        upload_url: Presigned PUT URL
        image_bytes: Image data to upload
        content_type: Value sent in the Content-Type header. Defaults to
            "image/png", the value previously hard-coded here.
            NOTE(review): if the presigned URL was signed with a specific
            ContentType, this must match it or S3 will reject the request;
            confirm against the upload-url service before passing another
            value.

    Raises:
        requests.exceptions.RequestException: On network failure or a
            non-2xx HTTP status.
    """
    logger.info(f"Uploading {len(image_bytes)} bytes to S3...")
    response = requests.put(
        upload_url,
        data=image_bytes,
        headers={"Content-Type": content_type},
        timeout=60,
    )
    response.raise_for_status()
    logger.info("Upload complete")


def start_detection_job(job_id: str, s3_url: str, params: Optional[Dict] = None) -> str:
    """
    Kick off detection for an image that has already been uploaded to S3.

    Sends a JSON POST to /detect on API Gateway with the job id and S3 URL,
    plus the processing parameters when a non-empty dict is supplied.

    Args:
        job_id: Job ID returned by get_upload_url
        s3_url: S3 URL returned by get_upload_url
        params: Optional processing parameters (omitted from the request
            body when empty or None)

    Returns:
        The same job_id that was passed in, ready for polling.

    Raises:
        requests.exceptions.RequestException: On network failure or a
            non-2xx HTTP status.
    """
    endpoint = API_GATEWAY_URL.rstrip("/") + "/detect"

    body = {"job_id": job_id, "s3_url": s3_url}
    if params:
        body["params"] = params

    logger.info(f"Starting detection job {job_id}")
    reply = requests.post(
        endpoint,
        headers={"x-api-key": API_KEY, "Content-Type": "application/json"},
        json=body,
        timeout=30,
    )
    reply.raise_for_status()

    # The reply body is only inspected for logging.
    started_status = reply.json().get("status")
    logger.info(f"Detection job started: {started_status}")
    return job_id


def poll_for_completion(
    job_id: str,
    max_wait_seconds: Optional[float] = None,
    poll_interval_seconds: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Poll API Gateway (GET /status/{job_id}) until the job finishes.

    The wait budget is measured in wall-clock time via time.monotonic(), so
    time spent inside each status request (up to its 30 s timeout) counts
    against the budget, not only the sleeps between polls. Sleeps never
    extend past the deadline, and one final poll is made once it is reached,
    so a job finishing during the last interval is still picked up. The
    total may overrun the budget by at most one in-flight status request.

    Args:
        job_id: Job ID to poll
        max_wait_seconds: Total wait budget; defaults to MAX_WAIT_SECONDS.
        poll_interval_seconds: Delay between polls; defaults to
            POLL_INTERVAL_SECONDS.

    Returns:
        The status response when status is "SUCCEEDED". Otherwise a dict with
        "status", "error" and an empty "callouts" list: the service's own
        status for "FAILED" / "TIMED_OUT" / "ABORTED", or "TIMEOUT" when the
        wait budget runs out.

    Raises:
        requests.exceptions.RequestException: On network failure or a
            non-2xx status response.
    """
    if max_wait_seconds is None:
        max_wait_seconds = MAX_WAIT_SECONDS
    if poll_interval_seconds is None:
        poll_interval_seconds = POLL_INTERVAL_SECONDS

    url = f"{API_GATEWAY_URL.rstrip('/')}/status/{job_id}"
    headers = {"x-api-key": API_KEY}

    start = time.monotonic()
    deadline = start + max_wait_seconds

    while True:
        elapsed = time.monotonic() - start
        logger.info(f"Polling job {job_id} (elapsed: {elapsed:.0f}s)")

        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()

        result = response.json()
        status = result.get("status")

        if status == "SUCCEEDED":
            logger.info(f"Job {job_id} completed successfully")
            return result

        if status in ("FAILED", "TIMED_OUT", "ABORTED"):
            error_msg = result.get("error", f"Job {status.lower()}")
            logger.error(f"Job {job_id} failed: {error_msg}")
            return {
                "status": status,
                "error": error_msg,
                "callouts": []
            }

        # Still running: stop if the budget is spent, otherwise sleep, but
        # never past the deadline.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval_seconds, remaining))

    # Timeout
    logger.error(f"Job {job_id} timed out after {max_wait_seconds}s")
    return {
        "status": "TIMEOUT",
        "error": f"Timeout waiting for results after {max_wait_seconds}s",
        "callouts": []
    }


def transform_to_emco_format(
    callouts: List[Dict],
    image_base64: str,
    image_width: int = 0,
    image_height: int = 0
) -> Dict[str, Any]:
    """
    Transform callouts from Lambda format to EMCO format.

    Lambda format:
        {"bbox": [x, y, w, h], "score": 0.95, ...}  # xywh

    EMCO format:
        {"bbox": {"x1": x, "y1": y, "x2": x+w, "y2": y+h}, "confidence": 0.95, ...}  # xyxy

    Coordinates are converted with int() (truncation toward zero). A callout
    with no "bbox" key gets a zero box, as before. Entries that cannot be
    converted (not a dict, or a bbox that is not a list/tuple of at least
    four finite numbers) are skipped with a warning instead of failing the
    whole response. Prediction ids stay contiguous starting at 1.

    Args:
        callouts: List of callouts from Lambda (None is treated as empty)
        image_base64: Original image as base64
        image_width: Image width
        image_height: Image height

    Returns:
        EMCO-compatible response dict
    """
    predictions = []

    for index, callout in enumerate(callouts or []):
        if not isinstance(callout, dict):
            logger.warning(
                f"Skipping callout #{index}: expected a dict, got {type(callout).__name__}"
            )
            continue

        bbox = callout.get("bbox", [0, 0, 0, 0])
        if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
            logger.warning(f"Skipping callout #{index}: malformed bbox {bbox!r}")
            continue

        # Convert from [x, y, w, h] to {x1, y1, x2, y2}. float() rejects
        # non-numeric values; int() rejects NaN (ValueError) and inf
        # (OverflowError).
        try:
            x, y, w, h = (float(v) for v in bbox[:4])
            box = {
                "x1": int(x),
                "y1": int(y),
                "x2": int(x + w),
                "y2": int(y + h)
            }
        except (TypeError, ValueError, OverflowError):
            logger.warning(f"Skipping callout #{index}: malformed bbox {bbox!r}")
            continue

        prediction = {
            "id": len(predictions) + 1,
            "label": "callout",
            "class_id": 0,
            "confidence": callout.get("score", callout.get("confidence", 1.0)),
            "bbox": box
        }

        # Include optional fields if present
        if "text" in callout:
            prediction["text"] = callout["text"]

        predictions.append(prediction)

    return {
        "predictions": predictions,
        "total_detections": len(predictions),
        "image": image_base64,
        "image_width": image_width,
        "image_height": image_height
    }


def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with ``overrides`` layered recursively over ``base``.

    When both sides hold a dict for the same key, the dicts are merged
    key-by-key, so ``{"tiling": {"tile": 1024}}`` keeps the other "tiling"
    defaults. Any other override value replaces the base value outright.
    Neither input is mutated.
    """
    merged = dict(base)
    for key, value in overrides.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = _deep_merge(existing, value)
        else:
            merged[key] = value
    return merged


def _error_response(message: str, image: str = "") -> Dict[str, Any]:
    """Build the prediction-less error payload returned by ``inference``."""
    return {
        "error": message,
        "predictions": [],
        "total_detections": 0,
        "image": image
    }


def inference(image_input: str, parameters: Optional[Dict] = None) -> Dict[str, Any]:
    """
    Run inference on an image.

    This is the main entry point for the HF wrapper.

    Flow:
    1. Normalize input to bytes
    2. Get presigned S3 URL
    3. Upload image directly to S3
    4. Start detection job (small JSON payload)
    5. Poll for completion
    6. Transform results to EMCO format

    Args:
        image_input: Image URL, data URL, or base64 string
        parameters: Optional processing parameters, deep-merged over
            DEFAULT_PARAMS: nested dicts are merged key-by-key, so a partial
            override like {"tiling": {"tile": 1024}} keeps the default
            "overlap". A falsy value means "use the defaults"; any other
            non-dict value is reported as a validation error.

    Returns:
        EMCO-compatible response with predictions. Exceptions derived from
        Exception are caught and reported through an "error" key with empty
        "predictions". "image" holds the input image for job failures and
        timeouts, and is "" when an exception was caught.
    """
    try:
        # Validate up front so a bad value fails before anything is uploaded.
        overrides = parameters or {}
        if not isinstance(overrides, dict):
            raise ValueError(
                f"parameters must be a dict, got {type(parameters).__name__}"
            )

        # 1. Normalize input to bytes
        logger.info("Normalizing input...")
        image_bytes, filename = normalize_to_bytes(image_input)

        # Keep base64 for response
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")

        # 2. Get presigned upload URL
        logger.info("Getting upload URL...")
        upload_info = get_upload_url(filename)
        job_id = upload_info["job_id"]
        upload_url = upload_info["upload_url"]
        s3_url = upload_info["s3_url"]

        # 3. Upload image directly to S3
        logger.info("Uploading to S3...")
        upload_to_s3(upload_url, image_bytes)

        # 4. Start detection job. A deep merge is used so that a partial
        # nested override does not discard sibling defaults.
        logger.info("Starting detection job...")
        merged_params = _deep_merge(DEFAULT_PARAMS, overrides)
        start_detection_job(job_id, s3_url, merged_params)

        # 5. Poll for completion
        logger.info("Polling for completion...")
        result = poll_for_completion(job_id)

        # 6. Check for errors (service failure statuses or local timeout)
        if result.get("status") in ("FAILED", "TIMED_OUT", "ABORTED", "TIMEOUT"):
            return _error_response(result.get("error", "Unknown error"), image_base64)

        # 7. Transform to EMCO format
        logger.info("Transforming results to EMCO format...")
        return transform_to_emco_format(
            result.get("callouts", []),
            image_base64,
            result.get("image_width", 0),
            result.get("image_height", 0)
        )

    except requests.exceptions.RequestException as e:
        logger.error(f"Request error: {e}")
        return _error_response(f"Request error: {str(e)}")
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        return _error_response(str(e))
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        return _error_response(f"Unexpected error: {str(e)}")