| | |
| |
|
| | import base64 |
| | import io |
| | from typing import Any, Dict |
| |
|
| | import torch |
| | import torchvision.transforms as T |
| | from PIL import Image |
| | from transformers import AutoImageProcessor, Dinov2ForImageClassification |
| |
|
| |
|
def get_inference_transform(processor: AutoImageProcessor, size: int):
    """Build the deterministic eval-time transform for direct inference on PIL images.

    Mirrors validation preprocessing: RGB conversion, symmetric zero-padding
    to a square, resize to *size*, tensor conversion, and channel
    normalization using the processor's statistics.
    """

    def _square_pad(image):
        # Pad the shorter dimension symmetrically so the image becomes square.
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return T.Pad((left, top, right, bottom), fill=0)(image)

    return T.Compose([
        T.Lambda(lambda image: image.convert('RGB')),
        _square_pad,
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
| |
|
| |
|
class EndpointHandler:
    """
    HF Inference Endpoints entry-point.

    Loads the DINOv2 classifier and image processor once at startup, then
    applies the imported validation-time preprocessing on every request.
    """

    def __init__(self, path: str = "", image_size: int = 224):
        """
        Args:
            path: Directory containing the exported model/processor files;
                falls back to the current directory when empty.
            image_size: Edge length the square-padded image is resized to.
        """
        self.processor = AutoImageProcessor.from_pretrained(path or ".")
        # .eval() disables dropout / batch-norm updates for deterministic inference.
        self.model = (
            Dinov2ForImageClassification.from_pretrained(path or ".")
            .eval()
        )

        # Same deterministic preprocessing used at validation time.
        self.transform = get_inference_transform(self.processor, image_size)

        self.id2label = self.model.config.id2label

    def __call__(self, data: Any) -> Any:
        """
        Classify a single image.

        Accepts either a raw image byte payload, or a dict of the form
        ``{"inputs": <base64 string | raw bytes | PIL.Image>}``.

        Returns:
            A list of up to five ``{"label", "score"}`` dicts, sorted by
            descending softmax probability.

        Raises:
            ValueError: If the request body or the "inputs" value has an
                unsupported type.
        """
        # Explicit sentinel instead of the fragile `"image" not in locals()` check.
        image = None

        if isinstance(data, (bytes, bytearray)):
            img_bytes = data
        elif isinstance(data, dict) and "inputs" in data:
            inp = data["inputs"]
            if isinstance(inp, str):
                # Tolerate data-URL prefixes such as "data:image/png;base64,...".
                img_bytes = base64.b64decode(inp.split(",")[-1])
            elif isinstance(inp, (bytes, bytearray)):
                img_bytes = inp
            elif hasattr(inp, "convert"):
                # Already a PIL image (duck-typed via .convert).
                image = inp
            else:
                raise ValueError("Unsupported 'inputs' format")
        else:
            raise ValueError("Unsupported request body type")

        if image is None:
            image = Image.open(io.BytesIO(img_bytes))

        # Transform yields (C, H, W); add the batch dimension expected by the model.
        pixel_values = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            logits = self.model(pixel_values).logits[0]
            probs = logits.softmax(dim=-1)

        # Top-k, capped by the number of classes.
        k = min(5, probs.numel())
        topk = torch.topk(probs, k)

        return [
            {"label": self.id2label[idx.item()], "score": prob.item()}
            for prob, idx in zip(topk.values, topk.indices)
        ]
| |
|
| |
|