File size: 3,789 Bytes
e8208a0
 
 
 
 
 
 
2e97025
e8208a0
 
 
2e97025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8208a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69baf1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8208a0
 
 
 
 
2846bb6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Custom handler, to be bundled with the model weights when uploading to HF Inference Endpoints.

import base64
import io
from typing import Any, Dict, List

import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification


def get_inference_transform(processor: AutoImageProcessor, size: int):
    """
    Build the deterministic preprocessing pipeline used at inference time.

    The returned callable takes a PIL image and:
      1. converts it to RGB,
      2. pads it with black (fill 0) to a square, centring the original and
         putting any odd leftover pixel on the right/bottom edge,
      3. resizes it with ``T.Resize(size)``; because the input is square, the
         result is ``size`` x ``size``,
      4. converts it to a tensor and normalizes it with the processor's
         ``image_mean`` / ``image_std``.

    Args:
        processor: Image processor supplying ``image_mean`` and ``image_std``.
        size: Target side length passed to ``T.Resize``.

    Returns:
        A ``T.Compose`` pipeline mapping a PIL image to a normalized tensor.
    """
    def _square_pad(image):
        # Split the missing width/height between both sides; the right/bottom
        # side absorbs the extra pixel when the difference is odd.
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return T.Pad((left, top, right, bottom), fill=0)(image)

    steps = [
        T.Lambda(lambda image: image.convert('RGB')),
        _square_pad,
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]
    return T.Compose(steps)


class EndpointHandler:
    """
    HF Inference Endpoints entry-point.

    Loads the processor and model once at construction time. Every request is
    then run through the pad-to-square transform built by
    ``get_inference_transform``, and the top-k class probabilities are returned.
    """

    def __init__(self, path: str = "", image_size: int = 224, top_k: int = 5):
        """
        Args:
            path: Directory holding the saved processor and model weights; an
                empty string means the current directory.
            image_size: Side length the padded square image is resized to.
            top_k: Default number of predictions returned per request. A request
                may override it with ``{"parameters": {"top_k": ...}}``.
        """
        model_dir = path or "."

        # Weights + processor --------------------------------------------------------
        self.processor = AutoImageProcessor.from_pretrained(model_dir)
        self.model = Dinov2ForImageClassification.from_pretrained(model_dir).eval()

        # Re-use the exact validation transform ------------------------------------
        self.transform = get_inference_transform(self.processor, image_size)

        self.id2label = self.model.config.id2label
        self.top_k = top_k

    # -------------------------------------------------------------------------------
    @staticmethod
    def _load_image(data: Any) -> Any:
        """
        Extract the image from a request payload.

        Accepted payloads:
          * raw ``bytes`` / ``bytearray`` holding an encoded image;
          * ``{"inputs": <str>}``: a base64 image, optionally a ``data:...,`` URL;
          * ``{"inputs": <bytes>}``: encoded image bytes;
          * ``{"inputs": <obj with .convert>}``: e.g. a PIL image, used as-is.

        Raises:
            ValueError: if the payload matches none of the shapes above, or if
                the base64 text is malformed (``binascii.Error`` is a ValueError).
        """
        if isinstance(data, (bytes, bytearray)):
            # Raw bytes (default HF client / curl -T).
            img_bytes = data
        elif isinstance(data, dict) and "inputs" in data:
            inp = data["inputs"]
            if isinstance(inp, str):
                # Keep only what follows the comma of a "data:...;base64," prefix.
                img_bytes = base64.b64decode(inp.split(",")[-1])
            elif isinstance(inp, (bytes, bytearray)):
                img_bytes = inp
            elif hasattr(inp, "convert"):
                return inp  # already a PIL.Image
            else:
                raise ValueError("Unsupported 'inputs' format")
        else:
            raise ValueError("Unsupported request body type")

        return Image.open(io.BytesIO(img_bytes))

    def _requested_top_k(self, data: Any) -> int:
        """
        Return the request's ``parameters.top_k`` if given, else ``self.top_k``.

        Raises:
            ValueError: if the resulting value is not a positive integer.
        """
        top_k = self.top_k
        if isinstance(data, dict):
            params = data.get("parameters") or {}
            if isinstance(params, dict) and params.get("top_k") is not None:
                top_k = int(params["top_k"])
        if top_k < 1:
            raise ValueError("'top_k' must be a positive integer")
        return top_k

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        Classify one image.

        ``data`` is either raw image bytes, or a dict with an ``"inputs"`` key
        (base64 string, bytes, or PIL image) and an optional
        ``"parameters": {"top_k": int}``.

        Returns:
            A list of ``{"label": ..., "score": float}`` dicts for the top-k
            classes, highest probability first. Labels come from
            ``model.config.id2label``. The endpoint's wire format must be a
            *list* of dicts.

        Raises:
            ValueError: for an unsupported payload or a non-positive ``top_k``.
        """
        image = self._load_image(data)
        top_k = self._requested_top_k(data)

        pixel_values = self.transform(image).unsqueeze(0)   # [1, C, H, W]

        with torch.no_grad():
            logits = self.model(pixel_values).logits[0]     # [num_labels]
            probs = logits.softmax(dim=-1)

        # Never ask topk for more entries than there are classes.
        k = min(top_k, probs.numel())
        top = torch.topk(probs, k)

        return [
            {"label": self.id2label[idx.item()], "score": prob.item()}
            for prob, idx in zip(top.values, top.indices)
        ]