# File size: 6,692 Bytes | 199b2e8
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Depth Anything V2 *metric indoor* models.

    Request formats accepted:
      - dict: {"id": 123, "inputs": "<base64 bytes of image>"}
      - dict: {"inputs": ...} where inputs is str/bytes/bytearray/dict
      - bytes/bytearray: raw image bytes OR JSON bytes containing {"inputs": ...}
      - str: base64 string OR JSON string containing {"inputs": ...}

    Response (dict):
      - "id": echoed request id, or None when the request carried none
      - "depth_raw_base64_f16": base64 of float16 depth in meters, row-major HxW
      - "depth_png_base64": base64 of an 8-bit PNG, for visualization only
      - "viz_p1" / "viz_p99": percentiles used to scale the visualization
      - "type", "units", "width", "height", "raw_dtype", "raw_shape": metadata
    """

    def __init__(self, path: str = ""):
        """Load the image processor and depth model found at *path*.

        The model is moved to CUDA when available, otherwise it stays on CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path).to(self.device)
        self.model.eval()

        # Mixed precision is only enabled for the CUDA forward pass.
        self.use_amp = self.device == "cuda"

        # Optional: minor speedups on GPU. Deliberately best-effort: a failure
        # to set these flags must never prevent the endpoint from starting.
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass

    # -----------------------------
    # Robust request parsing
    # -----------------------------
    def _maybe_parse_json_str(self, s: str):
        """Return the decoded JSON value if *s* looks like a JSON object, else None.

        Only strings wrapped in ``{...}`` are attempted, so plain base64 text
        is rejected without invoking the JSON parser.
        """
        text = s.strip()
        if not (text.startswith("{") and text.endswith("}")):
            return None
        try:
            return json.loads(text)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError covers hostile,
            # deeply nested payloads. Either way: "not JSON".
            return None

    def _maybe_parse_json_bytes(self, b: bytes):
        """Return the decoded JSON value if *b* is UTF-8 JSON object text, else None."""
        try:
            text = b.decode("utf-8", errors="strict")
        except UnicodeDecodeError:
            # Raw image bytes are generally not valid UTF-8.
            return None
        return self._maybe_parse_json_str(text)

    def _strip_data_uri(self, s: str) -> str:
        """Drop a ``data:<mime>;base64,`` prefix from *s* when one is present."""
        # e.g. "data:image/jpeg;base64,AAAA..."
        if s.startswith("data:") and "base64," in s:
            return s.split("base64,", 1)[1]
        return s

    def _unwrap_json_payload(self, data):
        """Return the parsed dict when *data* is JSON text/bytes carrying an image.

        Only a JSON object that has an "inputs" or "image" key is unwrapped;
        anything else (raw image bytes, base64 text, dicts, other types) is
        returned unchanged. Unwrapping up front lets the caller read sibling
        keys such as "id" from JSON requests that arrive as str or bytes.
        """
        if isinstance(data, (bytes, bytearray)):
            parsed = self._maybe_parse_json_bytes(bytes(data))
        elif isinstance(data, str):
            parsed = self._maybe_parse_json_str(data)
        else:
            return data
        if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
            return parsed
        return data

    def _coerce_to_image_bytes(self, obj):
        """Return raw image bytes extracted from any accepted request shape.

        Raises:
            ValueError: if a dict has neither "inputs" nor "image", or *obj*
                is of an unsupported type.
        """
        # dict payload (common)
        if isinstance(obj, dict):
            # allow nesting like {"inputs": {"image": "..."}}
            if "inputs" in obj:
                return self._coerce_to_image_bytes(obj["inputs"])
            # sometimes toolkits use "image" directly
            if "image" in obj:
                return self._coerce_to_image_bytes(obj["image"])
            raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')

        # bytes payload: JSON bytes are unwrapped, anything else is the image
        if isinstance(obj, (bytes, bytearray)):
            raw = bytes(obj)
            parsed = self._maybe_parse_json_bytes(raw)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            return raw

        # string payload (JSON, data URI, or bare base64)
        if isinstance(obj, str):
            text = obj.strip()
            parsed = self._maybe_parse_json_str(text)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            text = self._strip_data_uri(text)
            try:
                return base64.b64decode(text, validate=False)
            except ValueError:
                # binascii.Error (bad padding) and the non-ASCII error are both
                # ValueErrors. Last resort: treat as raw bytes of text.
                return text.encode("utf-8")

        raise ValueError(f"Unsupported request type: {type(obj)}")

    # -----------------------------
    # Response encoding
    # -----------------------------
    @staticmethod
    def _render_viz_png(depth_np):
        """Return ``(png_base64, p1, p99)`` for an 8-bit grayscale preview.

        The depth map is scaled between its 1st and 99th percentiles so that
        a few outliers do not dominate the contrast.
        NOTE: viz only — do not use for mapping scale.
        """
        p1 = float(np.percentile(depth_np, 1.0))
        p99 = float(np.percentile(depth_np, 99.0))
        span = p99 - p1
        # Guard against division by ~0 on a (nearly) constant depth map.
        denom = span if span > 1e-6 else 1.0
        viz = np.clip((depth_np - p1) / denom, 0.0, 1.0)
        depth_uint8 = (viz * 255.0).astype(np.uint8)
        # Mode "L" is inferred from the 2-D uint8 array.
        depth_img = Image.fromarray(depth_uint8)
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8"), p1, p99

    @staticmethod
    def _encode_raw_f16(depth_np) -> str:
        """Return base64 of *depth_np* as row-major float16 bytes.

        Values are clamped to the largest finite float16 before the cast;
        without this, anything above 65504 would turn into ``inf``.
        """
        f16_max = float(np.finfo(np.float16).max)
        depth_f16 = np.clip(depth_np, 0.0, f16_max).astype(np.float16)
        return base64.b64encode(depth_f16.tobytes()).decode("utf-8")

    # -----------------------------
    # Main inference
    # -----------------------------
    def __call__(self, data):
        """Run depth estimation on one request and return the response dict.

        Args:
            data: request payload in any of the formats listed on the class.

        Returns:
            dict with the echoed "id", image metadata, the visualization PNG
            and the raw float16 depth map (both base64-encoded).

        Raises:
            ValueError: if no image can be extracted or decoded.
        """
        # Unwrap JSON text/bytes first so "id" is echoed for those too.
        payload = self._unwrap_json_payload(data)
        rid = payload.get("id") if isinstance(payload, dict) else None
        image_bytes = self._coerce_to_image_bytes(payload)

        # Decode image
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Could not decode image bytes: {e}") from e
        orig_w, orig_h = image.size

        # Preprocess
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Forward
        with torch.inference_mode():
            if self.use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = self.model(**inputs_t)
            else:
                outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth  # [B, H, W] (metric model => meters)

            # Upsample to original size
            depth = predicted_depth.unsqueeze(1)  # [B,1,H,W]
            depth = F.interpolate(
                depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False
            )
            depth = depth.squeeze(1).squeeze(0)  # [H,W]
            depth_np = depth.detach().float().cpu().numpy()

        # Sanitize: avoid negatives/NaNs/infs
        depth_np = np.nan_to_num(depth_np, nan=0.0, posinf=0.0, neginf=0.0)
        depth_np = np.clip(depth_np, 0.0, 1e6)

        # Visualization PNG (percentile scaling so it doesn't flicker as much)
        depth_png_base64, p1, p99 = self._render_viz_png(depth_np)

        # Raw float16 meters (THIS is what the client should use)
        depth_raw_base64_f16 = self._encode_raw_f16(depth_np)

        return {
            "id": rid,
            "type": "metric_depth",
            "units": "meters",
            "width": int(orig_w),
            "height": int(orig_h),
            # Visualization
            "depth_png_base64": depth_png_base64,
            "viz_p1": p1,
            "viz_p99": p99,
            # Raw metric depth
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [int(orig_h), int(orig_w)],
        }