File size: 5,836 Bytes
d6637ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations

import base64
import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

try:
    from torchvision import transforms
except Exception as e:
    transforms = None


class ASLCNN(nn.Module):
    """Small three-conv CNN whose layer names mirror the checkpoint's state_dict.

    Expects [N, 3, 100, 100] input; two pooled conv stages reduce the spatial
    size to 25x25, so the flattened feature vector is 64 * 25 * 25 = 40000.
    """

    def __init__(self, num_classes: int = 29):
        super().__init__()
        # Attribute names (conv1..fc2) are load-bearing: they must match the
        # keys of the serialized state_dict exactly.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(40000, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.conv1(x))
        out = self.pool(out)
        out = F.relu(self.conv2(out))
        out = self.pool(out)
        out = F.relu(self.conv3(out))
        out = out.flatten(1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)


def _load_labels(repo_dir: Path, num_classes: int) -> List[str]:
    labels_path = repo_dir / "labels.json"
    if labels_path.exists():
        with labels_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        labels = data.get("labels")
        if isinstance(labels, list) and len(labels) == num_classes:
            return [str(x) for x in labels]
    return [str(i) for i in range(num_classes)]


def _decode_image(inp: Any) -> Image.Image:
    """Coerce a request payload into a PIL image.

    Accepted forms: an already-open PIL.Image, raw image bytes, a base64
    string (optionally wrapped in a data URL), or a local file path.
    """
    if isinstance(inp, Image.Image):
        return inp

    if isinstance(inp, (bytes, bytearray)):
        return Image.open(io.BytesIO(inp))

    if isinstance(inp, str):
        payload = inp.strip()
        # data URL: data:image/png;base64,...
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        try:
            raw = base64.b64decode(payload, validate=False)
            return Image.open(io.BytesIO(raw))
        except Exception:
            # Not usable as base64 image data; last resort: treat it as a
            # local filesystem path.
            candidate = Path(payload)
            if candidate.exists():
                return Image.open(str(candidate))
            raise

    raise ValueError(f"Unsupported input type for 'inputs': {type(inp)}")


class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    __init__(path): called once at container startup.
    __call__(data): called per request; data always contains 'inputs'.
    """

    def __init__(self, path: str = ""):
        # In the default container, "path" is typically the model repo directory.
        self.repo_dir = Path(path) if path else Path("/repository")
        weights = self.repo_dir / "pytorch_model.bin"
        if not weights.exists():
            # fallback if someone renamed the file
            candidates = list(self.repo_dir.glob("*.bin")) + list(self.repo_dir.glob("*.pt"))
            if candidates:
                weights = candidates[0]
            else:
                raise FileNotFoundError("Could not find weights file (expected pytorch_model.bin) in repo")

        # SECURITY: torch.load unpickles arbitrary objects. Prefer
        # weights_only=True (tensors only); fall back for torch versions
        # that predate the keyword.
        try:
            state_dict = torch.load(str(weights), map_location="cpu", weights_only=True)
        except TypeError:
            state_dict = torch.load(str(weights), map_location="cpu")

        # Infer num_classes from fc2.weight if present
        num_classes = 29
        if isinstance(state_dict, dict) and "fc2.weight" in state_dict:
            num_classes = int(state_dict["fc2.weight"].shape[0])

        self.labels = _load_labels(self.repo_dir, num_classes)
        self.model = ASLCNN(num_classes=num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        if transforms is None:
            self.transform = None
        else:
            # NOTE: This assumes your training used 100x100 RGB inputs and raw [0,1] scaling.
            # If you used mean/std normalization, add transforms.Normalize(...) here.
            self.transform = transforms.Compose(
                [
                    transforms.Resize((100, 100)),
                    transforms.ToTensor(),
                ]
            )

    def _preprocess(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a [1, 3, 100, 100] float tensor in [0, 1]."""
        img = img.convert("RGB")
        if self.transform is None:
            # minimal fallback without torchvision
            img = img.resize((100, 100))
            arr = np.asarray(img).astype("float32") / 255.0
            arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
            x = torch.from_numpy(arr)
        else:
            x = self.transform(img)
        return x.unsqueeze(0)  # [1,3,100,100]

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
        """Run inference on one image or a batch.

        Returns a list of {"label", "score"} dicts for a single input, or a
        list of such lists for a batch. Raises ValueError on empty/unsupported
        inputs.
        """
        inp = data.get("inputs")
        params = data.get("parameters") or {}
        # Clamp so a non-positive top_k still yields at least one prediction.
        top_k = max(1, int(params.get("top_k", 5)))

        # Support batch (list of inputs) or single input
        if isinstance(inp, list):
            if not inp:
                raise ValueError("'inputs' is an empty list; expected at least one image")
            imgs = [_decode_image(x) for x in inp]
            xs = torch.cat([self._preprocess(im) for im in imgs], dim=0)
        else:
            im = _decode_image(inp)
            xs = self._preprocess(im)

        with torch.no_grad():
            logits = self.model(xs)
            probs = torch.softmax(logits, dim=-1)

        # Return per-sample top_k predictions
        results: List[List[Dict[str, Any]]] = []
        k = min(top_k, probs.shape[-1])
        top_probs, top_idx = torch.topk(probs, k=k, dim=-1)
        for i in range(probs.shape[0]):
            sample: List[Dict[str, Any]] = []
            for p, idx in zip(top_probs[i].tolist(), top_idx[i].tolist()):
                # Fall back to the numeric index if labels.json was short/missing.
                label = self.labels[idx] if idx < len(self.labels) else str(idx)
                sample.append({"label": label, "score": float(p)})
            results.append(sample)

        return results[0] if len(results) == 1 else results