File size: 4,090 Bytes
4d5eab5
 
 
 
 
 
 
83f365f
4d5eab5
 
 
 
 
 
 
 
 
 
83f365f
4d5eab5
83f365f
 
 
 
 
 
4d5eab5
83f365f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c20c47
83f365f
 
 
 
 
 
 
 
5c20c47
83f365f
 
 
 
 
 
 
 
 
 
 
4d5eab5
 
 
5c20c47
 
4d5eab5
 
5c20c47
 
4d5eab5
83f365f
5c20c47
4d5eab5
83f365f
4d5eab5
5c20c47
4d5eab5
 
83f365f
4d5eab5
 
83f365f
4d5eab5
5c20c47
4d5eab5
 
 
 
83f365f
4d5eab5
5c20c47
4d5eab5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
import json


class EndpointHandler:
    """Inference-endpoint handler for monocular depth estimation.

    Loads an image processor and a depth-estimation model from ``path``. Each
    call decodes one image from the request payload, predicts depth, upsamples
    it to the input resolution and returns both a min/max-normalized 8-bit PNG
    visualization and the raw depth map as base64-encoded float16 bytes.
    """

    def __init__(self, path=""):
        """Load processor and model from ``path``; use CUDA when available, else CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()  # disable training-only behavior (e.g. dropout)

    def _coerce_to_image_bytes(self, obj):
        """Extract encoded image bytes from a loosely-typed request payload.

        Accepts:
          - dict: must contain ``"inputs"``, which is coerced recursively.
          - bytes/bytearray: raw image bytes, or UTF-8 JSON text of the form
            ``{"inputs": ...}``.
          - str: JSON text ``{"inputs": ...}``, a base64 string, or a
            ``data:<mime>;base64,<payload>`` URL. Whitespace inside the base64
            text is ignored and missing ``=`` padding is restored.

        Returns:
            The encoded image as ``bytes`` (not yet decoded by PIL).

        Raises:
            ValueError: if ``"inputs"`` is missing, the type is unsupported, or
                a string payload is not decodable base64.
        """
        if isinstance(obj, dict):
            if "inputs" not in obj:
                raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')
            return self._coerce_to_image_bytes(obj["inputs"])

        if isinstance(obj, (bytes, bytearray)):
            b = bytes(obj)
            try:
                txt = b.decode("utf-8")
            except UnicodeDecodeError:
                return b  # binary (non-UTF-8) body: treat as encoded image bytes
            # Sometimes the body is JSON sent as bytes; only then re-parse it.
            if txt.lstrip().startswith("{") and '"inputs"' in txt:
                try:
                    parsed = json.loads(txt)
                except json.JSONDecodeError:
                    return b
                # Outside the try so errors about the *contents* propagate
                # instead of silently returning the JSON text as an image.
                return self._coerce_to_image_bytes(parsed)
            return b

        if isinstance(obj, str):
            s = obj.strip()

            if s.startswith("{") and '"inputs"' in s:
                try:
                    parsed = json.loads(s)
                except json.JSONDecodeError:
                    pass  # not JSON after all; try base64 below
                else:
                    return self._coerce_to_image_bytes(parsed)

            # Strip a data-URL header; otherwise b64decode(validate=False)
            # would keep its letters ("dataimagepngbase64") as payload.
            if s.startswith("data:") and "," in s:
                s = s.partition(",")[2]
            s = "".join(s.split())  # drop embedded newlines/spaces
            s += "=" * (-len(s) % 4)  # restore stripped padding
            try:
                return base64.b64decode(s, validate=False)
            except ValueError as err:  # binascii.Error subclasses ValueError
                raise ValueError(
                    "String input is neither JSON nor valid base64 image data"
                ) from err

        raise ValueError(f"Unsupported request type: {type(obj)}")

    def _load_image(self, data):
        """Decode the request payload into an RGB PIL image.

        An already-decoded PIL image (directly, or under ``"inputs"``) is used
        as-is; everything else goes through ``_coerce_to_image_bytes``.
        """
        payload = data.get("inputs") if isinstance(data, dict) else data
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")

        image_bytes = self._coerce_to_image_bytes(data)
        if not image_bytes:
            raise ValueError("Empty image payload")
        # Context manager closes the source image; convert() returns an
        # independent, fully loaded copy.
        with Image.open(io.BytesIO(image_bytes)) as img:
            return img.convert("RGB")

    def _predict_depth(self, image):
        """Run the model and return a float32 numpy depth map of shape (H, W).

        The prediction is bicubically upsampled to ``image``'s original size.
        """
        orig_w, orig_h = image.size

        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        with torch.no_grad():
            outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth  # assumed [B, H, W] with B == 1

        depth = predicted_depth.unsqueeze(1)  # [B, 1, H, W] for interpolate
        depth = F.interpolate(
            depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False
        )
        depth = depth.squeeze(1).squeeze(0)  # [H, W]
        return depth.detach().float().cpu().numpy()

    @staticmethod
    def _depth_to_png_base64(depth_np):
        """Min/max-normalize ``depth_np`` to uint8 and encode it as a base64 PNG.

        Returns:
            ``(png_base64, dmin, dmax)`` where dmin/dmax are the normalization bounds.
        """
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        span = dmax - dmin
        denom = span if span > 1e-12 else 1.0  # flat map: avoid divide-by-zero
        depth_uint8 = (((depth_np - dmin) / denom) * 255.0).clip(0, 255).astype(np.uint8)

        buf = io.BytesIO()
        # A 2-D uint8 array maps to mode "L"; the explicit mode= argument is
        # deprecated in recent Pillow releases.
        Image.fromarray(depth_uint8).save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8"), dmin, dmax

    def __call__(self, data):
        """Handle one request.

        Args:
            data: request payload; see ``_coerce_to_image_bytes`` for accepted
                forms (a decoded PIL image under ``"inputs"`` is also accepted).

        Returns:
            dict with the visualization PNG (base64), the raw float16 depth bytes
            (base64, row-major, shape ``[height, width]``), the image size, and
            the min/max used to normalize the visualization.
        """
        image = self._load_image(data)
        orig_w, orig_h = image.size

        depth_np = self._predict_depth(image)
        depth_png_base64, dmin, dmax = self._depth_to_png_base64(depth_np)

        # NOTE(review): float16 saturates to inf beyond ±65504 — confirm the
        # model's output range fits before relying on the raw values.
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }