from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json


class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for Depth Anything V2 *metric indoor* models.

    Request formats accepted:
      - dict: {"id": 123, "inputs": ""}    (what your client sends)
      - dict: {"inputs": ...} where inputs is str/bytes/bytearray/dict
      - bytes/bytearray: raw image bytes OR JSON bytes containing {"inputs": ...}
      - str: base64 string OR JSON string containing {"inputs": ...}

    Response:
      - echoes "id" (if provided)
      - returns:
          depth_raw_base64_f16 : float16 depth in meters, row-major HxW
          depth_png_base64     : visualization only (8-bit PNG)
    """

    def __init__(self, path: str = ""):
        """Load the processor/model from `path` and move the model to GPU if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path).to(self.device)
        self.model.eval()

        # Optional: minor speedups on GPU (mixed precision + TF32 matmuls).
        self.use_amp = (self.device == "cuda")
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            # Best-effort: older torch builds may not expose these knobs.
            pass

    # -----------------------------
    # Robust request parsing
    # -----------------------------
    def _maybe_parse_json_str(self, s: str):
        """Return the parsed dict if `s` looks like a JSON object, else None."""
        ss = s.strip()
        if ss.startswith("{") and ss.endswith("}"):
            try:
                return json.loads(ss)
            except Exception:
                return None
        return None

    def _maybe_parse_json_bytes(self, b: bytes):
        """Return the parsed dict if `b` is UTF-8 JSON object bytes, else None."""
        try:
            txt = b.decode("utf-8", errors="strict")
        except Exception:
            # Not UTF-8 text => certainly not JSON; caller treats it as raw bytes.
            return None
        return self._maybe_parse_json_str(txt)

    def _strip_data_uri(self, s: str) -> str:
        """Strip a data-URI prefix, e.g. "data:image/jpeg;base64,AAAA..." -> "AAAA..."."""
        if s.startswith("data:") and "base64," in s:
            return s.split("base64,", 1)[1]
        return s

    def _coerce_to_image_bytes(self, obj):
        """
        Normalize any accepted request shape (dict / bytes / str, possibly nested
        JSON or data-URI) down to raw image bytes.

        Raises:
            ValueError: if `obj` is a dict without "inputs"/"image", or an
                unsupported type altogether.
        """
        # dict payload (common)
        if isinstance(obj, dict):
            # allow nesting like {"inputs": {"image": "..."}}
            if "inputs" in obj:
                return self._coerce_to_image_bytes(obj["inputs"])
            # sometimes toolkits use "image" directly
            if "image" in obj:
                return self._coerce_to_image_bytes(obj["image"])
            raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')

        # bytes payload
        if isinstance(obj, (bytes, bytearray)):
            b = bytes(obj)
            # if it's JSON bytes, parse it and recurse into the wrapped payload
            parsed = self._maybe_parse_json_bytes(b)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            return b

        # string payload (base64 or json)
        if isinstance(obj, str):
            s = obj.strip()
            # JSON string?
            parsed = self._maybe_parse_json_str(s)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            # data URI?
            s = self._strip_data_uri(s)
            # base64 (validate=False silently skips non-alphabet chars)
            try:
                return base64.b64decode(s, validate=False)
            except Exception:
                # last resort: treat as raw bytes of text
                return s.encode("utf-8")

        raise ValueError(f"Unsupported request type: {type(obj)}")

    # -----------------------------
    # Main inference
    # -----------------------------
    def __call__(self, data):
        """
        Run metric depth estimation on the request payload.

        Returns a dict with an echoed "id", image dimensions, an 8-bit PNG
        visualization, and the raw float16 depth map (meters) as base64.

        Raises:
            ValueError: if the payload cannot be coerced/decoded to an image.
        """
        # Echo request id if present
        rid = data.get("id") if isinstance(data, dict) else None

        image_bytes = self._coerce_to_image_bytes(data)

        # Decode image
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Could not decode image bytes: {e}")

        orig_w, orig_h = image.size

        # Preprocess
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Forward (fp16 autocast on GPU for speed; numerics recovered below in fp32)
        with torch.inference_mode():
            if self.use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = self.model(**inputs_t)
            else:
                outputs = self.model(**inputs_t)

        predicted_depth = outputs.predicted_depth  # [B, H, W] (metric model => meters)

        # Upsample to original size
        depth = predicted_depth.unsqueeze(1)  # [B,1,H,W]
        depth = F.interpolate(depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False)
        depth = depth.squeeze(1).squeeze(0)  # [H,W]

        depth_np = depth.detach().float().cpu().numpy()

        # Sanitize: avoid negatives/NaNs
        depth_np = np.nan_to_num(depth_np, nan=0.0, posinf=0.0, neginf=0.0)
        depth_np = np.clip(depth_np, 0.0, 1e6)

        # Visualization PNG (robust scaling using percentiles so it doesn't flicker as much)
        # NOTE: viz only — do not use for mapping scale.
        p1, p99 = (float(v) for v in np.percentile(depth_np, (1.0, 99.0)))
        spread = p99 - p1
        denom = spread if spread > 1e-6 else 1.0
        viz = (depth_np - p1) / denom
        viz = np.clip(viz, 0.0, 1.0)
        depth_uint8 = (viz * 255.0).astype(np.uint8)

        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Raw float16 meters (THIS is what your client should use)
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "id": rid,
            "type": "metric_depth",
            "units": "meters",
            "width": int(orig_w),
            "height": int(orig_h),
            # Visualization
            "depth_png_base64": depth_png_base64,
            "viz_p1": p1,
            "viz_p99": p99,
            # Raw metric depth
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [int(orig_h), int(orig_w)],
        }