from __future__ import annotations

import base64
import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

try:
    from torchvision import transforms
except Exception:
    # torchvision is optional: on any import-time failure the handler falls
    # back to the NumPy preprocessing path in ``EndpointHandler._preprocess``.
    transforms = None

# Spatial size fed to the network. Both sides are equal, so the (width,
# height) order used by PIL and the (height, width) order used by torchvision
# agree.
_INPUT_SIZE = (100, 100)

# Number of predictions returned when the request does not specify ``top_k``.
_DEFAULT_TOP_K = 5


class ASLCNN(nn.Module):
    """Simple CNN architecture inferred from the state_dict keys/shapes.

    Three 3x3 convolutions (the first two followed by 2x2 max-pooling) feed
    two fully connected layers. ``fc1`` expects 40000 input features, i.e.
    64 channels * 25 * 25, which is what a 100x100 input produces after the
    two pooling steps.
    """

    def __init__(self, num_classes: int = 29):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(40000, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))  # no pooling after the third convolution
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def _load_labels(repo_dir: Path, num_classes: int) -> List[str]:
    """Load class labels from ``labels.json`` in *repo_dir*.

    The file may contain either ``{"labels": [...]}`` or a bare JSON list.
    When the file is missing, or the labels are not a list of exactly
    *num_classes* entries, the stringified class indices are returned instead.

    Raises:
        json.JSONDecodeError: If ``labels.json`` exists but is not valid JSON.
    """
    fallback = [str(i) for i in range(num_classes)]

    labels_path = repo_dir / "labels.json"
    if not labels_path.exists():
        return fallback

    with labels_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept a bare list as well as the {"labels": [...]} wrapper; any other
    # JSON shape falls through to the numeric fallback instead of crashing.
    labels = data.get("labels") if isinstance(data, dict) else data
    if isinstance(labels, list) and len(labels) == num_classes:
        return [str(x) for x in labels]
    return fallback


def _decode_image(inp: Any) -> Image.Image:
    """Accepts PIL.Image, raw bytes, or base64 string (optionally data URL).

    A string that cannot be decoded as a base64 image is, as a last resort,
    treated as a local file path.

    Raises:
        ValueError: If *inp* is of an unsupported type (including ``None``).
        Exception: The original decoding error (``binascii.Error`` or a PIL
            ``OSError``) if a string is neither a base64 image nor a path to
            an existing file.
    """
    if isinstance(inp, Image.Image):
        return inp

    if isinstance(inp, (bytes, bytearray)):
        return Image.open(io.BytesIO(inp))

    if isinstance(inp, str):
        payload = inp.strip()
        # data URL: data:image/png;base64,...
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        try:
            raw = base64.b64decode(payload, validate=False)
            return Image.open(io.BytesIO(raw))
        except (ValueError, OSError):
            # Last resort: treat the string as a local path.
            # NOTE(security): this lets a request reference files on the
            # server's filesystem; drop this fallback if inputs are untrusted.
            candidate = Path(payload)
            try:
                is_local_file = candidate.exists()
            except (OSError, ValueError):
                # e.g. "File name too long" for a large, invalid base64
                # payload; do not let it mask the real decoding error.
                is_local_file = False
            if is_local_file:
                return Image.open(str(candidate))
            raise

    raise ValueError(f"Unsupported input type for 'inputs': {type(inp)}")


class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    __init__(path): called once at container startup.
    __call__(data): called per request; data always contains 'inputs'.
    """

    def __init__(self, path: str = ""):
        """Load weights and labels from the model repository at *path*.

        Raises:
            FileNotFoundError: If no ``*.bin`` or ``*.pt`` weights file exists
                in the repository directory.
        """
        # In the default container, "path" is typically the model repo directory.
        self.repo_dir = Path(path) if path else Path("/repository")

        weights = self.repo_dir / "pytorch_model.bin"
        if not weights.exists():
            # Fallback if someone renamed the file. Sorted so that the chosen
            # file is deterministic when several candidates exist.
            candidates = sorted(self.repo_dir.glob("*.bin")) + sorted(
                self.repo_dir.glob("*.pt")
            )
            if not candidates:
                raise FileNotFoundError(
                    "Could not find weights file (expected pytorch_model.bin) in repo"
                )
            weights = candidates[0]

        # NOTE(security): torch.load unpickles the file, so the weights must
        # come from a trusted repository.
        state_dict = torch.load(str(weights), map_location="cpu")

        # Infer num_classes from fc2.weight if present
        num_classes = 29
        if isinstance(state_dict, dict) and "fc2.weight" in state_dict:
            num_classes = int(state_dict["fc2.weight"].shape[0])

        self.labels = _load_labels(self.repo_dir, num_classes)

        self.model = ASLCNN(num_classes=num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        if transforms is None:
            self.transform = None
        else:
            # NOTE: This assumes your training used 100x100 RGB inputs and raw [0,1] scaling.
            # If you used mean/std normalization, add transforms.Normalize(...) here.
            self.transform = transforms.Compose(
                [
                    transforms.Resize(_INPUT_SIZE),
                    transforms.ToTensor(),
                ]
            )

    def _preprocess(self, img: Image.Image) -> torch.Tensor:
        """Convert *img* to a float tensor of shape ``[1, 3, 100, 100]``.

        Pixel values are scaled to ``[0, 1]``. Uses the torchvision pipeline
        when available, otherwise an equivalent PIL/NumPy fallback.
        """
        img = img.convert("RGB")
        if self.transform is not None:
            x = self.transform(img)
        else:
            # Minimal fallback without torchvision. Request bilinear
            # resampling explicitly so this path matches the default of
            # transforms.Resize instead of Pillow's own default filter.
            resampling = getattr(Image, "Resampling", Image)
            img = img.resize(_INPUT_SIZE, resampling.BILINEAR)
            arr = np.asarray(img, dtype=np.float32) / 255.0
            # HWC -> CHW, made contiguous like the output of ToTensor().
            x = torch.from_numpy(np.ascontiguousarray(arr.transpose(2, 0, 1)))
        return x.unsqueeze(0)  # [1,3,100,100]

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Classify one image or a batch of images.

        Args:
            data: Request payload. ``data["inputs"]`` is a single image or a
                list of images in any form accepted by ``_decode_image``.
                ``data["parameters"]["top_k"]`` (optional, default 5) limits
                the number of predictions per image.

        Returns:
            For a single image, or a list containing exactly one image, a
            list of ``{"label", "score"}`` dicts sorted by descending score.
            For a list of several images, one such list per image. For an
            empty list, ``[]``.

        Raises:
            ValueError: If ``top_k`` is negative or an input has an
                unsupported type.
        """
        inp = data.get("inputs")
        params = data.get("parameters") or {}

        top_k = params.get("top_k")
        top_k = _DEFAULT_TOP_K if top_k is None else int(top_k)
        if top_k < 0:
            raise ValueError(f"'top_k' must be non-negative, got {top_k}")

        # Support batch (list of inputs) or single input
        items = inp if isinstance(inp, list) else [inp]
        if not items:
            # torch.cat rejects an empty list; an empty batch has no results.
            return []

        images = [_decode_image(item) for item in items]
        xs = torch.cat([self._preprocess(image) for image in images], dim=0)

        with torch.no_grad():
            logits = self.model(xs)
            probs = torch.softmax(logits, dim=-1)

        # Return per-sample top_k predictions
        k = min(top_k, probs.shape[-1])
        top_probs, top_idx = torch.topk(probs, k=k, dim=-1)

        num_labels = len(self.labels)
        results: List[List[Dict[str, Any]]] = [
            [
                {
                    "label": self.labels[idx] if idx < num_labels else str(idx),
                    "score": float(p),
                }
                for p, idx in zip(sample_probs, sample_idx)
            ]
            for sample_probs, sample_idx in zip(top_probs.tolist(), top_idx.tolist())
        ]

        # NOTE: a one-element batch is returned un-nested, exactly like a
        # single input; kept for backward compatibility with existing clients.
        return results[0] if len(results) == 1 else results