from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json
class EndpointHandler:
    """Custom inference-endpoint handler for monocular depth estimation.

    Loads a Hugging Face depth-estimation model once at startup, then maps an
    incoming request payload (raw bytes, base64 string, or JSON dict) to a
    relative depth map returned both as a PNG visualization and as raw
    float16 samples, each base64-encoded.
    """

    def __init__(self, path=""):
        # Prefer GPU when present; model weights are moved there once.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def _coerce_to_image_bytes(self, obj):
        """Normalize the various payload shapes the toolkit may deliver.

        Accepts:
            - dict: expects obj["inputs"] (itself any of the forms below)
            - bytes/bytearray: raw image bytes, or UTF-8 JSON bytes carrying "inputs"
            - str: JSON string carrying "inputs", base64 image data, or plain text

        Returns:
            bytes suitable for ``Image.open(io.BytesIO(...))``.
        """
        # Dict payload: unwrap the "inputs" field and recurse.
        if isinstance(obj, dict):
            if "inputs" not in obj:
                raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')
            return self._coerce_to_image_bytes(obj["inputs"])

        # Bytes payload: might actually be a UTF-8 JSON body; sniff for it.
        if isinstance(obj, (bytes, bytearray)):
            raw = bytes(obj)
            try:
                text = raw.decode("utf-8")
            except Exception:
                return raw  # not text at all -> treat as raw image bytes
            if text.lstrip().startswith("{") and '"inputs"' in text:
                try:
                    return self._coerce_to_image_bytes(json.loads(text))
                except Exception:
                    pass  # best-effort sniff failed; fall through to raw bytes
            return raw

        # String payload: JSON first, then the common base64 case.
        if isinstance(obj, str):
            text = obj.strip()
            if text.startswith("{") and '"inputs"' in text:
                try:
                    return self._coerce_to_image_bytes(json.loads(text))
                except Exception:
                    pass
            try:
                # Most common: base64-encoded image bytes.
                return base64.b64decode(text, validate=False)
            except Exception:
                # Last resort: treat as utf-8 bytes (won't be a valid image,
                # but avoids a str->BytesIO crash downstream).
                return text.encode("utf-8")

        raise ValueError(f"Unsupported request type: {type(obj)}")

    def __call__(self, data):
        """Run depth estimation on the request payload.

        Returns a dict with the depth map as a base64 PNG visualization plus
        raw float16 values, together with size/normalization metadata.
        """
        payload = self._coerce_to_image_bytes(data)
        # payload is guaranteed bytes at this point.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        width, height = image.size

        batch = self.processor(images=image, return_tensors="pt")
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            predicted = self.model(**batch).predicted_depth  # [B, H, W]

        # Resize model-resolution depth back to the input image resolution.
        resized = F.interpolate(
            predicted.unsqueeze(1),  # -> [B, 1, H, W] for interpolate
            size=(height, width),
            mode="bicubic",
            align_corners=False,
        )
        depth_np = resized.squeeze(1).squeeze(0).detach().float().cpu().numpy()  # [H, W]

        # 8-bit visualization PNG (min-max normalized; guard flat depth maps).
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        spread = dmax - dmin
        denom = spread if spread > 1e-12 else 1.0
        viz = (((depth_np - dmin) / denom) * 255.0).clip(0, 255).astype(np.uint8)
        buf = io.BytesIO()
        Image.fromarray(viz, mode="L").save(buf, format="PNG")
        png_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Raw float16 depth buffer for clients that want actual values.
        raw_b64 = base64.b64encode(depth_np.astype(np.float16).tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": width,
            "height": height,
            "depth_png_base64": png_b64,
            "depth_raw_base64_f16": raw_b64,
            "raw_dtype": "float16",
            "raw_shape": [height, width],
            "viz_min": dmin,
            "viz_max": dmax,
        }