File size: 2,153 Bytes
f0e6dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection


class EndpointHandler:
    """Custom inference handler for an object-detection model.

    Loads an image processor and an object-detection model from ``path`` once,
    then on each call decodes a single image, runs the model, and returns the
    detections whose score is above a threshold.
    """

    # Score cutoff used when the request does not supply parameters.threshold.
    DEFAULT_THRESHOLD = 0.5

    def __init__(self, path: str = ""):
        """Load the processor and model from ``path`` and put the model in eval mode."""
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForObjectDetection.from_pretrained(path)
        self.model.eval()

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 image string, tolerating a ``data:<mime>;base64,`` URI prefix."""
        # Browsers and many clients send data URIs; strip everything up to the
        # first comma so only the base64 body reaches b64decode.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        return base64.b64decode(payload)

    @classmethod
    def _resolve_threshold(cls, data: Any) -> float:
        """Return ``data["parameters"]["threshold"]`` as a float if given, else DEFAULT_THRESHOLD."""
        params = data.get("parameters") if isinstance(data, dict) else None
        if isinstance(params, dict) and params.get("threshold") is not None:
            return float(params["threshold"])
        return cls.DEFAULT_THRESHOLD

    @classmethod
    def _load_image(cls, inputs: Any) -> "Image.Image":
        """Convert a supported payload (base64 str, bytes/bytearray, PIL image) to an RGB image.

        Raises:
            ValueError: if ``inputs`` is none of the supported types, or the
                base64 string is malformed (binascii.Error subclasses ValueError).
        """
        if isinstance(inputs, str):
            return Image.open(io.BytesIO(cls._decode_base64(inputs))).convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(bytes(inputs))).convert("RGB")
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        raise ValueError(
            "Unsupported input type "
            f"{type(inputs).__name__!r}. Provide a base64-encoded image string, "
            "raw bytes, or a PIL image."
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run object detection on the image carried by ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) must be a base64 string (optionally a
                ``data:`` URI), raw ``bytes``/``bytearray``, or a PIL image.
                An optional ``data["parameters"]["threshold"]`` overrides the
                default score cutoff of 0.5.

        Returns:
            One dict per detection with ``score`` (rounded to 4 places),
            ``label`` (looked up in ``model.config.id2label``) and ``box``
            (``xmin``/``ymin``/``xmax``/``ymax`` rounded to 2 places, rescaled
            to the input image size via ``target_sizes``).

        Raises:
            ValueError: if the input payload is not a supported type.
        """
        inputs = data.get("inputs", data)
        image = self._load_image(inputs)
        threshold = self._resolve_threshold(data)

        # Run inference without building an autograd graph.
        with torch.no_grad():
            encoded = self.processor(images=image, return_tensors="pt")
            outputs = self.model(**encoded)

        # PIL reports (width, height); post-processing expects (height, width).
        target_size = torch.tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_size
        )[0]

        id2label = self.model.config.id2label
        detections = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            detections.append(
                {
                    "score": round(score.item(), 4),
                    "label": id2label[label.item()],
                    "box": {
                        "xmin": round(xmin, 2),
                        "ymin": round(ymin, 2),
                        "xmax": round(xmax, 2),
                        "ymax": round(ymax, 2),
                    },
                }
            )

        return detections