# asl-cnn-endpoint / handler.py
# Uploaded by duashmi ("Upload 5 files", commit d6637ea, verified).
from __future__ import annotations
import base64
import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
# torchvision is optional: when it is unavailable, `transforms` is set to
# None and EndpointHandler._preprocess falls back to a NumPy-based resize.
try:
    from torchvision import transforms
except Exception:  # import failure of any kind means "not installed/usable"
    transforms = None
class ASLCNN(nn.Module):
    """Small 3-conv CNN whose layout matches the checkpoint's state_dict.

    Designed for 100x100 RGB inputs: the two pooled conv stages halve the
    spatial size twice (100 -> 50 -> 25), so the flattened feature vector
    fed to fc1 is 64 * 25 * 25 = 40000 elements.
    """

    def __init__(self, num_classes: int = 29):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(40000, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a [N, 3, 100, 100] batch."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = F.relu(self.conv3(out))
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
def _load_labels(repo_dir: Path, num_classes: int) -> List[str]:
labels_path = repo_dir / "labels.json"
if labels_path.exists():
with labels_path.open("r", encoding="utf-8") as f:
data = json.load(f)
labels = data.get("labels")
if isinstance(labels, list) and len(labels) == num_classes:
return [str(x) for x in labels]
return [str(i) for i in range(num_classes)]
def _decode_image(inp: Any) -> Image.Image:
    """Coerce ``inp`` into a PIL image.

    Accepted forms: a PIL.Image (returned unchanged), raw bytes/bytearray,
    a base64 string (optionally a ``data:image/...;base64,`` data URL), or —
    as a last resort — a string path to a local image file.
    """
    if isinstance(inp, Image.Image):
        return inp
    if isinstance(inp, (bytes, bytearray)):
        return Image.open(io.BytesIO(inp))
    if not isinstance(inp, str):
        raise ValueError(f"Unsupported input type for 'inputs': {type(inp)}")

    s = inp.strip()
    if s.startswith("data:") and "," in s:
        # Keep only the payload after the data-URL header.
        _, s = s.split(",", 1)
    try:
        raw = base64.b64decode(s, validate=False)
        return Image.open(io.BytesIO(raw))
    except Exception:
        # Not usable as base64 image data; try interpreting it as a path.
        candidate = Path(s)
        if candidate.exists():
            return Image.open(str(candidate))
        raise
class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    __init__(path): called once at container startup; locates the weights
        file, rebuilds the model, and loads the label list.
    __call__(data): called per request; ``data`` always contains 'inputs'.
    """

    def __init__(self, path: str = ""):
        # In the default container, "path" is typically the model repo directory.
        self.repo_dir = Path(path) if path else Path("/repository")
        weights = self.repo_dir / "pytorch_model.bin"
        if not weights.exists():
            # Fallback if someone renamed the weights file.
            candidates = list(self.repo_dir.glob("*.bin")) + list(self.repo_dir.glob("*.pt"))
            if not candidates:
                raise FileNotFoundError(
                    "Could not find weights file (expected pytorch_model.bin) in repo"
                )
            weights = candidates[0]
        # Prefer weights_only=True so loading a repo-supplied pickle cannot
        # execute arbitrary code; fall back for torch versions that predate
        # the kwarg (they raise TypeError on the unknown argument).
        try:
            state_dict = torch.load(str(weights), map_location="cpu", weights_only=True)
        except TypeError:
            state_dict = torch.load(str(weights), map_location="cpu")
        # Infer num_classes from fc2.weight if present; 29 matches the common
        # ASL alphabet checkpoint (26 letters + del / nothing / space).
        num_classes = 29
        if isinstance(state_dict, dict) and "fc2.weight" in state_dict:
            num_classes = int(state_dict["fc2.weight"].shape[0])
        self.labels = _load_labels(self.repo_dir, num_classes)
        self.model = ASLCNN(num_classes=num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        if transforms is None:
            self.transform = None
        else:
            # NOTE: This assumes your training used 100x100 RGB inputs and raw
            # [0,1] scaling. If you used mean/std normalization, add
            # transforms.Normalize(...) here.
            self.transform = transforms.Compose(
                [
                    transforms.Resize((100, 100)),
                    transforms.ToTensor(),
                ]
            )

    def _preprocess(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a [1, 3, 100, 100] float tensor in [0, 1]."""
        img = img.convert("RGB")
        if self.transform is None:
            # Minimal fallback without torchvision.
            img = img.resize((100, 100))
            arr = np.asarray(img).astype("float32") / 255.0
            arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
            x = torch.from_numpy(arr)
        else:
            x = self.transform(img)
        return x.unsqueeze(0)  # [1,3,100,100]

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
        """Classify 'inputs' (one image or a list of images).

        Returns a list of {"label", "score"} dicts (top-k, highest score
        first) for a single input, or a list of such lists for a batch.
        ``parameters.top_k`` defaults to 5 and is clamped to the valid range.

        Raises:
            ValueError: if 'inputs' is an empty list or an unsupported type.
        """
        inp = data.get("inputs")
        params = data.get("parameters") or {}
        top_k = int(params.get("top_k", 5))
        # Support batch (list of inputs) or single input.
        if isinstance(inp, list):
            if not inp:
                # torch.cat([]) would raise an opaque RuntimeError; fail clearly.
                raise ValueError("'inputs' must not be an empty list")
            imgs = [_decode_image(x) for x in inp]
            xs = torch.cat([self._preprocess(im) for im in imgs], dim=0)
        else:
            xs = self._preprocess(_decode_image(inp))
        with torch.no_grad():
            logits = self.model(xs)
            probs = torch.softmax(logits, dim=-1)
        # Return per-sample top_k predictions. Clamp k so torch.topk cannot
        # fail: a negative top_k previously raised a RuntimeError.
        results: List[List[Dict[str, Any]]] = []
        k = max(0, min(top_k, probs.shape[-1]))
        top_probs, top_idx = torch.topk(probs, k=k, dim=-1)
        for i in range(probs.shape[0]):
            sample: List[Dict[str, Any]] = []
            for p, idx in zip(top_probs[i].tolist(), top_idx[i].tolist()):
                label = self.labels[idx] if idx < len(self.labels) else str(idx)
                sample.append({"label": label, "score": float(p)})
            results.append(sample)
        return results[0] if len(results) == 1 else results