File size: 6,692 Bytes
199b2e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json


class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for Depth Anything V2 *metric indoor* models.

    Request formats accepted:
      - dict: {"id": 123, "inputs": "<base64 bytes of image>"}  (what your client sends)
      - dict: {"inputs": ...} where inputs is str/bytes/bytearray/dict
      - bytes/bytearray: raw image bytes OR JSON bytes containing {"inputs": ...}
      - str: base64 string OR JSON string containing {"inputs": ...}

    Response:
      - echoes "id" (if provided) — including when the envelope arrives as raw
        JSON bytes/text instead of a pre-parsed dict
      - returns:
          depth_raw_base64_f16 : float16 depth in meters, row-major HxW
          depth_png_base64     : visualization only (8-bit PNG)
    """

    def __init__(self, path: str = ""):
        """Load processor + model from `path` and move the model to GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path).to(self.device)
        self.model.eval()

        # Optional: minor speedups on GPU (fp16 autocast + TF32 matmuls).
        self.use_amp = (self.device == "cuda")
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            # TF32 knobs may be absent on CPU-only / older torch builds; best-effort.
            pass

    # -----------------------------
    # Robust request parsing
    # -----------------------------

    def _maybe_parse_json_str(self, s: str):
        """Return the parsed object if `s` looks like a JSON object literal, else None."""
        ss = s.strip()
        if ss.startswith("{") and ss.endswith("}"):
            try:
                return json.loads(ss)
            except Exception:
                return None
        return None

    def _maybe_parse_json_bytes(self, b: bytes):
        """Return the parsed object if `b` is UTF-8 text holding a JSON object, else None."""
        try:
            txt = b.decode("utf-8", errors="strict")
        except Exception:
            # Not valid UTF-8 => raw image bytes, not a JSON envelope.
            return None
        return self._maybe_parse_json_str(txt)

    def _strip_data_uri(self, s: str) -> str:
        """Drop a data-URI prefix, e.g. "data:image/jpeg;base64,AAAA..." -> "AAAA..."."""
        if s.startswith("data:") and "base64," in s:
            return s.split("base64,", 1)[1]
        return s

    def _coerce_to_image_bytes(self, obj):
        """
        Recursively unwrap any accepted request shape down to raw image bytes.

        Raises:
            ValueError: if `obj` is a dict without "inputs"/"image", or an
                unsupported type altogether.
        """
        # dict payload (common)
        if isinstance(obj, dict):
            # allow nesting like {"inputs": {"image": "..."}}
            if "inputs" in obj:
                return self._coerce_to_image_bytes(obj["inputs"])
            # sometimes toolkits use "image" directly
            if "image" in obj:
                return self._coerce_to_image_bytes(obj["image"])
            raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')

        # bytes payload
        if isinstance(obj, (bytes, bytearray)):
            b = bytes(obj)
            # if it's JSON bytes, unwrap the envelope; otherwise treat as raw image
            parsed = self._maybe_parse_json_bytes(b)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)
            return b

        # string payload (base64 or json)
        if isinstance(obj, str):
            s = obj.strip()

            # JSON string?
            parsed = self._maybe_parse_json_str(s)
            if isinstance(parsed, dict) and ("inputs" in parsed or "image" in parsed):
                return self._coerce_to_image_bytes(parsed)

            # data URI?
            s = self._strip_data_uri(s)

            # base64 (validate=False: ignore stray non-alphabet chars like newlines)
            try:
                return base64.b64decode(s, validate=False)
            except Exception:
                # last resort: treat as raw bytes of text
                return s.encode("utf-8")

        raise ValueError(f"Unsupported request type: {type(obj)}")

    # -----------------------------
    # Main inference
    # -----------------------------

    def __call__(self, data):
        """Run depth estimation on one image; return a JSON-serializable response dict."""
        # Parse a JSON envelope up-front so "id" survives even when the transport
        # hands us raw bytes/text instead of a pre-parsed dict.
        # (Bug fix: previously "id" was only read from an already-dict payload,
        # contradicting the class docstring for the JSON-bytes / JSON-string formats.)
        envelope = None
        if isinstance(data, dict):
            envelope = data
        elif isinstance(data, (bytes, bytearray)):
            envelope = self._maybe_parse_json_bytes(bytes(data))
        elif isinstance(data, str):
            envelope = self._maybe_parse_json_str(data)

        rid = None
        if isinstance(envelope, dict):
            rid = envelope.get("id")
            if "inputs" in envelope or "image" in envelope:
                data = envelope  # hand the parsed dict straight to the coercion path

        image_bytes = self._coerce_to_image_bytes(data)

        # Decode image
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            # Chain the original decode error for debuggability (PEP 3134).
            raise ValueError(f"Could not decode image bytes: {e}") from e

        orig_w, orig_h = image.size

        # Preprocess
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Forward (fp16 autocast on GPU only)
        with torch.inference_mode():
            if self.use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = self.model(**inputs_t)
            else:
                outputs = self.model(**inputs_t)

            predicted_depth = outputs.predicted_depth  # [B, H, W] (metric model => meters)

        # Upsample to original size (B == 1: one request handles one image)
        depth = predicted_depth.unsqueeze(1)  # [B,1,H,W]
        depth = F.interpolate(depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False)
        depth = depth.squeeze(1).squeeze(0)  # [H,W]
        depth_np = depth.detach().float().cpu().numpy()

        # Sanitize: bicubic can overshoot below 0; drop NaN/inf as well.
        depth_np = np.nan_to_num(depth_np, nan=0.0, posinf=0.0, neginf=0.0)
        depth_np = np.clip(depth_np, 0.0, 1e6)

        # Visualization PNG (robust scaling using percentiles so it doesn't flicker as much)
        # NOTE: viz only — do not use for mapping scale.
        p1 = float(np.percentile(depth_np, 1.0))
        p99 = float(np.percentile(depth_np, 99.0))
        denom = (p99 - p1) if (p99 - p1) > 1e-6 else 1.0  # guard near-constant depth maps
        viz = ((depth_np - p1) / denom)
        viz = np.clip(viz, 0.0, 1.0)
        depth_uint8 = (viz * 255.0).astype(np.uint8)

        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Raw float16 meters (THIS is what your client should use)
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "id": rid,
            "type": "metric_depth",
            "units": "meters",
            "width": int(orig_w),
            "height": int(orig_h),

            # Visualization
            "depth_png_base64": depth_png_base64,
            "viz_p1": p1,
            "viz_p99": p99,

            # Raw metric depth
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [int(orig_h), int(orig_w)],
        }