|
|
from __future__ import annotations |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
|
|
|
try: |
|
|
from torchvision import transforms |
|
|
except Exception as e: |
|
|
transforms = None |
|
|
|
|
|
|
|
|
class ASLCNN(nn.Module):
    """Small 3-conv / 2-fc classifier matching the checkpoint's state_dict layout.

    Expects 3x100x100 RGB input: the two stride-2 max-pools reduce the
    spatial size 100 -> 50 -> 25, so the flattened feature vector is
    64 * 25 * 25 = 40000, matching ``fc1``.
    """

    def __init__(self, num_classes: int = 29):
        super().__init__()
        # NOTE: attribute names must stay conv1/conv2/conv3/fc1/fc2 so that
        # load_state_dict() matches the keys in the shipped checkpoint.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(40000, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (N, 3, 100, 100) batch to (N, num_classes) logits."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = F.relu(self.conv3(out))
        out = out.flatten(1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
|
|
|
|
|
|
|
|
def _load_labels(repo_dir: Path, num_classes: int) -> List[str]: |
|
|
labels_path = repo_dir / "labels.json" |
|
|
if labels_path.exists(): |
|
|
with labels_path.open("r", encoding="utf-8") as f: |
|
|
data = json.load(f) |
|
|
labels = data.get("labels") |
|
|
if isinstance(labels, list) and len(labels) == num_classes: |
|
|
return [str(x) for x in labels] |
|
|
return [str(i) for i in range(num_classes)] |
|
|
|
|
|
|
|
|
def _decode_image(inp: Any) -> Image.Image:
    """Normalize a request payload item into a ``PIL.Image.Image``.

    Supported forms: an already-open PIL image, raw image bytes, or a
    string holding base64 data (optionally a ``data:`` URL) or a local
    file path.

    Raises:
        ValueError: if ``inp`` is none of the supported types.
    """
    if isinstance(inp, Image.Image):
        return inp

    if isinstance(inp, (bytes, bytearray)):
        return Image.open(io.BytesIO(inp))

    if not isinstance(inp, str):
        raise ValueError(f"Unsupported input type for 'inputs': {type(inp)}")

    s = inp.strip()
    # Strip the "data:<mime>;base64," prefix of a data URL, if present.
    if s.startswith("data:") and "," in s:
        s = s.split(",", 1)[1]
    try:
        decoded = base64.b64decode(s, validate=False)
        return Image.open(io.BytesIO(decoded))
    except Exception:
        # Not decodable as a base64 image — maybe it is a filesystem path.
        candidate = Path(s)
        if candidate.exists():
            return Image.open(str(candidate))
        # Re-raise the original base64/PIL error for a meaningful message.
        raise
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    __init__(path): called once at container startup; locates the weights
        file, rebuilds the model, and loads labels.
    __call__(data): called per request; ``data['inputs']`` holds one image
        (PIL image / bytes / base64 / path string) or a list of them.
    """

    def __init__(self, path: str = ""):
        # Inference Endpoints mount the model repo at /repository by default.
        self.repo_dir = Path(path) if path else Path("/repository")

        weights = self.repo_dir / "pytorch_model.bin"
        if not weights.exists():
            # Fall back to any *.bin / *.pt file shipped with the repo.
            candidates = list(self.repo_dir.glob("*.bin")) + list(self.repo_dir.glob("*.pt"))
            if not candidates:
                raise FileNotFoundError(
                    "Could not find weights file (expected pytorch_model.bin) in repo"
                )
            weights = candidates[0]

        # SECURITY: torch.load is pickle-based and can execute arbitrary code.
        # Prefer weights_only=True (tensor-dict only — which is all this code
        # supports anyway); fall back for torch versions without the kwarg.
        try:
            state_dict = torch.load(str(weights), map_location="cpu", weights_only=True)
        except TypeError:
            state_dict = torch.load(str(weights), map_location="cpu")

        # Infer the class count from the final layer instead of hard-coding it.
        num_classes = 29
        if isinstance(state_dict, dict) and "fc2.weight" in state_dict:
            num_classes = int(state_dict["fc2.weight"].shape[0])

        self.labels = _load_labels(self.repo_dir, num_classes)
        self.model = ASLCNN(num_classes=num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        if transforms is None:
            # torchvision unavailable: _preprocess uses a numpy fallback.
            self.transform = None
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((100, 100)),
                    transforms.ToTensor(),
                ]
            )

    def _preprocess(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a (1, 3, 100, 100) float tensor in [0, 1]."""
        img = img.convert("RGB")
        if self.transform is None:
            # Mirror Resize((100, 100)) + ToTensor() without torchvision.
            img = img.resize((100, 100))
            arr = np.asarray(img).astype("float32") / 255.0
            arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
            x = torch.from_numpy(arr)
        else:
            x = self.transform(img)
        return x.unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Classify one image (or a batch) and return top-k predictions.

        Args:
            data: request payload; ``data['inputs']`` is one image or a list,
                ``data['parameters']['top_k']`` (optional, default 5) limits
                the number of returned classes.

        Returns:
            For a single image: a list of {"label", "score"} dicts, best
            first.  For a batch: one such list per input image.

        Raises:
            ValueError: if 'inputs' is missing, an empty list, or of an
                unsupported type.
        """
        inp = data.get("inputs")
        if inp is None:
            raise ValueError("Request payload must contain an 'inputs' field")
        params = data.get("parameters") or {}
        top_k = int(params.get("top_k", 5))

        if isinstance(inp, list):
            if not inp:
                raise ValueError("'inputs' list must not be empty")
            imgs = [_decode_image(x) for x in inp]
            xs = torch.cat([self._preprocess(im) for im in imgs], dim=0)
        else:
            im = _decode_image(inp)
            xs = self._preprocess(im)

        with torch.no_grad():
            logits = self.model(xs)
            probs = torch.softmax(logits, dim=-1)

        results: List[List[Dict[str, Any]]] = []
        # Clamp k to a sane range: at least 1, at most the number of classes
        # (a negative top_k would otherwise crash torch.topk).
        k = max(1, min(top_k, probs.shape[-1]))
        top_probs, top_idx = torch.topk(probs, k=k, dim=-1)
        for i in range(probs.shape[0]):
            sample: List[Dict[str, Any]] = []
            for p, idx in zip(top_probs[i].tolist(), top_idx[i].tolist()):
                # Guard against a labels list shorter than the class count.
                label = self.labels[idx] if idx < len(self.labels) else str(idx)
                sample.append({"label": label, "score": float(p)})
            results.append(sample)

        # Single image -> one list of dicts; batch -> list of lists.
        return results[0] if len(results) == 1 else results
|
|
|