# Source: dinov2-large / handler.py on the Hugging Face Hub
# Uploaded by kas1293 — commit ef6797e (verified), "Create handler.py", 2.94 kB
# handler.py — facebook/dinov2-large on Hugging Face Inference Endpoints
#
# Buildly's visual place-recognition embedder: one image -> one 1024-d vector.
# Pooling = mean over patch tokens (AnyLoc-style), chosen for viewpoint
# robustness on bare / near-identical rooms. This choice is PERMANENT:
# changing it silently invalidates every stored room-scan vector.
import base64
import io
from typing import Any, Dict
import torch
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModel
# Hub id of the checkpoint; used as the default `path` of EndpointHandler.
MODEL_ID = "facebook/dinov2-large"
class EndpointHandler:
    """Inference Endpoints handler: one image in, one embedding vector out.

    The embedding is the mean over the patch tokens of the model's
    ``last_hidden_state`` (every token after the first). Do not change the
    pooling: vectors already stored by callers were produced with it.
    """

    def __init__(self, path: str = MODEL_ID):
        """Load processor and model, then move the model to GPU if present.

        Args:
            path: Local directory or Hub id to load from. An empty value
                falls back to ``MODEL_ID``.
        """
        # Eager load at startup. During this window (cold start / scale-from-
        # zero) the platform answers 503 on its own; Buildly maps that to
        # EmbeddingUnavailable and skips the match without corrupting state.
        source = path or MODEL_ID
        self.processor = AutoImageProcessor.from_pretrained(source)
        self.model = AutoModel.from_pretrained(source).eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _to_image(self, inp: Any) -> Image.Image:
        """Coerce one request payload into an RGB ``PIL.Image``.

        Tolerant of every shape Buildly's probe or HF may hand us:
        PIL.Image | raw bytes | base64 str (plain or ``data:`` URL), either
        bare or wrapped as {"image": ...} / {"inputs": ...}.
        Anything unparseable raises -> the endpoint returns 4xx (NOT 503),
        so Buildly's format probe falls through to the next wire shape
        instead of treating it as the model being unavailable.

        Args:
            inp: The payload in any of the shapes listed above.

        Returns:
            The decoded image converted to mode ``"RGB"``.

        Raises:
            ValueError: If the string is not base64, the bytes are not a
                decodable image, or the payload type is not supported.
        """
        # Unwrap the dict first so a wrapped PIL image is accepted as well.
        if isinstance(inp, dict):  # {"image": ...} or {"inputs": ...}
            inp = inp.get("image") or inp.get("inputs")

        if isinstance(inp, Image.Image):
            return inp.convert("RGB")

        if isinstance(inp, str):  # base64 jpeg
            text = inp.strip()
            if text.startswith("data:") and "," in text:
                # data URL ("data:image/jpeg;base64,<payload>"): the prefix
                # would otherwise be decoded as if it were image data.
                text = text.split(",", 1)[1]
            try:
                inp = base64.b64decode(text)
            except ValueError as exc:  # binascii.Error is a ValueError
                raise ValueError(
                    f"input string is not valid base64: {exc}"
                ) from exc

        if isinstance(inp, (bytes, bytearray)):
            try:
                return Image.open(io.BytesIO(inp)).convert("RGB")
            except (
                UnidentifiedImageError,
                OSError,
                # Oversized images raise this; it is not an OSError.
                Image.DecompressionBombError,
            ) as exc:
                raise ValueError(
                    f"bytes are not a decodable image: {exc}"
                ) from exc

        raise ValueError(f"unsupported input type: {type(inp).__name__}")

    @torch.no_grad()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed a single image.

        Args:
            data: Request payload. When it is a dict with an ``"inputs"``
                key, that value is used; otherwise ``data`` itself is
                handed to ``_to_image``.

        Returns:
            ``{"embedding": <list of floats>, "dim": <vector length>}``.

        Raises:
            ValueError: If the payload cannot be turned into an image.
        """
        payload = data.get("inputs", data) if isinstance(data, dict) else data
        image = self._to_image(payload)
        batch = self.processor(images=image, return_tensors="pt").to(self.device)
        tokens = self.model(**batch).last_hidden_state  # [1, 257, 1024]
        # Mean-pool the patch tokens (index 1 onward) -> [1024]. Token 0 is
        # CLS; using it instead is simpler but less robust for VPR.
        vec = tokens[:, 1:, :].mean(dim=1)[0]
        return {"embedding": vec.float().cpu().tolist(), "dim": int(vec.shape[-1])}