# NOTE: page-header residue from the hosting site, kept as a comment so the file parses:
#   pslime's picture | Update handler.py | 199b2e8 verified
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json
class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for Depth Anything V2 *metric indoor* models.

    Request formats accepted:
      - dict: {"id": 123, "inputs": "<base64 bytes of image>"}
      - dict: {"inputs": ...} where inputs is str/bytes/bytearray/dict
      - bytes/bytearray: raw image bytes OR JSON bytes containing {"inputs": ...}
      - str: base64 string (optionally a data URI, optionally missing its
        trailing "=" padding) OR JSON string containing {"inputs": ...}

    Response (dict):
      - "id": the request id when the request was a dict carrying one, else None
      - "depth_raw_base64_f16": float16 depth in meters, row-major HxW
      - "depth_png_base64": visualization only (8-bit PNG)
      - plus "type", "units", "width", "height", "viz_p1", "viz_p99",
        "raw_dtype" and "raw_shape" metadata
    """

    # Largest finite float16 value. Depth is clipped to this before the
    # float16 cast so the raw payload can never contain +inf.
    MAX_DEPTH_M = 65504.0

    def __init__(self, path: str = ""):
        """Load the image processor and depth model from *path* and pick a device."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path).to(self.device)
        self.model.eval()

        # Optional: minor speedups on GPU
        self.use_amp = self.device == "cuda"
        if self.use_amp:
            try:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            except Exception:
                # Deliberately best-effort: a torch build without these flags
                # must not prevent the endpoint from starting.
                pass

    # -----------------------------
    # Robust request parsing
    # -----------------------------
    def _maybe_parse_json_str(self, s: str):
        """Return the parsed object if *s* looks like a JSON object, else None."""
        candidate = s.strip()
        if not (candidate.startswith("{") and candidate.endswith("}")):
            return None
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; RecursionError covers
            # pathologically nested payloads. Either way: "not usable JSON".
            return None

    def _maybe_parse_json_bytes(self, b: bytes):
        """Return the parsed object if *b* is UTF-8 text holding a JSON object, else None."""
        try:
            text = b.decode("utf-8", errors="strict")
        except UnicodeDecodeError:
            # Binary image data is normally not valid UTF-8.
            return None
        return self._maybe_parse_json_str(text)

    def _strip_data_uri(self, s: str) -> str:
        """Drop a leading ``data:...;base64,`` prefix, if present."""
        # e.g. "data:image/jpeg;base64,AAAA..."
        if s.startswith("data:") and "base64," in s:
            return s.split("base64,", 1)[1]
        return s

    def _coerce_to_image_bytes(self, obj):
        """
        Return raw image bytes extracted from a request payload.

        Accepts a dict (looks under "inputs", then "image"), bytes/bytearray
        (raw image bytes or JSON bytes), or str (JSON, data URI, or base64).

        Raises:
            ValueError: if a dict has neither "inputs" nor "image", or if
                *obj* is of an unsupported type.
        """
        # dict payload (common)
        if isinstance(obj, dict):
            # allow nesting like {"inputs": {"image": "..."}}
            if "inputs" in obj:
                return self._coerce_to_image_bytes(obj["inputs"])
            # sometimes toolkits use "image" directly
            if "image" in obj:
                return self._coerce_to_image_bytes(obj["image"])
            raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')

        # bytes payload
        if isinstance(obj, (bytes, bytearray)):
            raw = bytes(obj)
            # if it's JSON bytes, parse it
            parsed = self._maybe_parse_json_bytes(raw)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            return raw

        # string payload (base64 or json)
        if isinstance(obj, str):
            s = obj.strip()
            # JSON string?
            parsed = self._maybe_parse_json_str(s)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            # data URI?
            s = self._strip_data_uri(s)
            # base64: drop embedded whitespace (line-wrapped encoders) and
            # restore "=" padding that some clients strip; without this,
            # unpadded input fell through to the text fallback below.
            compact = "".join(s.split())
            compact += "=" * (-len(compact) % 4)
            try:
                return base64.b64decode(compact, validate=False)
            except ValueError:  # binascii.Error is a ValueError subclass
                # last resort: treat as raw bytes of text
                return s.encode("utf-8")

        raise ValueError(f"Unsupported request type: {type(obj)}")

    # -----------------------------
    # Post-processing helpers
    # -----------------------------
    @staticmethod
    def _sanitize_depth(depth):
        """Return *depth* as a float32 array with NaN/inf -> 0 and values clipped to [0, float16 max]."""
        arr = np.asarray(depth, dtype=np.float32)
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        # The upper bound is the float16 maximum (not an arbitrary large
        # number) so the later astype(float16) cannot overflow to inf.
        return np.clip(arr, 0.0, EndpointHandler.MAX_DEPTH_M)

    @staticmethod
    def _encode_viz_png(depth_np):
        """
        Render *depth_np* as an 8-bit grayscale PNG, base64-encoded.

        Scaling uses the 1st/99th percentiles so outliers do not dominate.
        NOTE: viz only; do not use for mapping scale.

        Returns:
            (png_base64, p1, p99)
        """
        p1 = float(np.percentile(depth_np, 1.0))
        p99 = float(np.percentile(depth_np, 99.0))
        span = p99 - p1
        if span <= 1e-6:
            span = 1.0  # flat image: avoid dividing by ~0
        viz = np.clip((depth_np - p1) / span, 0.0, 1.0)
        depth_uint8 = (viz * 255.0).astype(np.uint8)
        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8"), p1, p99

    # -----------------------------
    # Main inference
    # -----------------------------
    def __call__(self, data):
        """
        Run depth estimation on the image carried by *data*.

        Args:
            data: request payload in any of the formats listed on the class.

        Returns:
            dict with the raw float16 depth map (base64), a visualization
            PNG (base64) and metadata; see the class docstring.

        Raises:
            ValueError: if no image can be extracted or decoded.
        """
        # Echo request id if present
        rid = data.get("id") if isinstance(data, dict) else None
        image_bytes = self._coerce_to_image_bytes(data)

        # Decode image
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Could not decode image bytes: {e}") from e
        orig_w, orig_h = image.size

        # Preprocess
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        with torch.inference_mode():
            # Forward
            if self.use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = self.model(**inputs_t)
            else:
                outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth  # [B, H, W] (metric model => meters)

            # Upsample to original size. Done in float32 so the bicubic
            # kernel does not run on half-precision autocast output.
            depth = predicted_depth.float().unsqueeze(1)  # [B,1,H,W]
            depth = F.interpolate(
                depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False
            )
            depth = depth.squeeze(1).squeeze(0)  # [H,W]

        # Sanitize: avoid negatives/NaNs and values float16 cannot hold
        depth_np = self._sanitize_depth(depth.detach().cpu().numpy())

        # Visualization PNG (viz only)
        depth_png_base64, p1, p99 = self._encode_viz_png(depth_np)

        # Raw float16 meters (THIS is what your client should use);
        # bytes are in the host's native byte order.
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "id": rid,
            "type": "metric_depth",
            "units": "meters",
            "width": int(orig_w),
            "height": int(orig_h),
            # Visualization
            "depth_png_base64": depth_png_base64,
            "viz_p1": p1,
            "viz_p99": p99,
            # Raw metric depth
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [int(orig_h), int(orig_w)],
        }