# handler.py — custom inference handler (commit 83f365f, author: pslime)
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for monocular depth estimation.

    Loads an ``AutoModelForDepthEstimation`` checkpoint from ``path`` and, for
    each request, returns a normalized 8-bit PNG visualization of the depth map
    plus the raw float16 depth values, both base64-encoded.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; model and all inputs are moved there.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def _coerce_to_image_bytes(self, obj):
        """Normalize the many request shapes the toolkit may pass into raw image bytes.

        Accepts:
          - bytes/bytearray: raw image bytes (or UTF-8 JSON bytes with an "inputs" key)
          - str: base64 string (optionally a ``data:...;base64,`` data URL),
            a JSON string containing {"inputs": "..."}, or plain text (fallback)
          - dict: expects dict["inputs"] (which can itself be str/bytes/etc)

        Returns:
            bytes: the decoded image payload.

        Raises:
            ValueError: for an unsupported container type or a dict without "inputs".
        """
        # Toolkit may pass an already-parsed JSON dict.
        if isinstance(obj, dict):
            if "inputs" not in obj:
                raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')
            return self._coerce_to_image_bytes(obj["inputs"])
        # Raw bytes: usually the image itself, but sometimes a JSON body.
        if isinstance(obj, (bytes, bytearray)):
            b = bytes(obj)
            # Sometimes body is JSON bytes; try parse before treating as image.
            try:
                txt = b.decode("utf-8")
                if txt.lstrip().startswith("{") and '"inputs"' in txt:
                    return self._coerce_to_image_bytes(json.loads(txt))
            except Exception:
                pass
            return b
        # Strings: JSON, base64 / data URL, or (rarely) plain text.
        if isinstance(obj, str):
            s = obj.strip()
            # Sometimes it's a JSON string.
            if s.startswith("{") and '"inputs"' in s:
                try:
                    return self._coerce_to_image_bytes(json.loads(s))
                except Exception:
                    pass
            # BUGFIX: strip a data-URL prefix ("data:image/png;base64,...").
            # With validate=False below, only non-alphabet chars (':', ';', ',')
            # would be discarded — "data", "image/png" and "base64" are valid
            # base64 alphabet and would be decoded into garbage prefix bytes,
            # corrupting the image payload.
            if s.startswith("data:") and ";base64," in s:
                s = s.split(";base64,", 1)[1]
            # Most common: base64 string of image bytes.
            try:
                return base64.b64decode(s, validate=False)
            except Exception:
                # Last resort: treat as utf-8 bytes (won't be a valid image,
                # but avoids a str->BytesIO crash; Image.open reports the error).
                return s.encode("utf-8")
        raise ValueError(f"Unsupported request type: {type(obj)}")

    def __call__(self, data):
        """Run depth estimation on one image request.

        Args:
            data: bytes / str / dict request payload (see _coerce_to_image_bytes).

        Returns:
            dict with the PNG visualization (base64), the raw float16 depth map
            (base64 of the contiguous buffer, shape [H, W]), and the min/max
            used for visualization normalization.
        """
        image_bytes = self._coerce_to_image_bytes(data)
        # Now guaranteed bytes; force RGB so the processor sees 3 channels.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}
        with torch.no_grad():
            outputs = self.model(**inputs_t)
        # assumes predicted_depth is [B, H, W] — TODO confirm for this checkpoint
        predicted_depth = outputs.predicted_depth
        # Upsample the model-resolution map back to the original image size.
        depth = predicted_depth.unsqueeze(1)  # [B, 1, H, W]
        depth = F.interpolate(
            depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False
        )
        depth = depth.squeeze(1).squeeze(0)  # [H, W]
        depth_np = depth.detach().float().cpu().numpy()
        # Visualization PNG: min-max normalize to 0..255 grayscale.
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        # Guard against a constant depth map (division by ~zero).
        denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
        depth_uint8 = (((depth_np - dmin) / denom) * 255.0).clip(0, 255).astype(np.uint8)
        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        # Raw depth as float16 to halve the payload size.
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")
        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }