pslime committed on
Commit
83f365f
·
verified ·
1 Parent(s): 5c20c47

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -32
handler.py CHANGED
@@ -5,76 +5,97 @@ import torch.nn.functional as F
5
  import io
6
  import base64
7
  import numpy as np
 
8
 
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
  self.processor = AutoImageProcessor.from_pretrained(path)
15
  self.model = AutoModelForDepthEstimation.from_pretrained(path)
16
  self.model.to(self.device)
17
  self.model.eval()
18
 
19
- def __call__(self, data):
20
  """
21
- Supports both common endpoint input styles:
22
- 1) JSON: {"inputs": "<base64-encoded image bytes>"} (recommended)
23
- 2) Raw bytes passed through as inputs (fallback)
 
 
 
24
  """
25
- inputs = data.get("inputs", None)
26
- if inputs is None:
27
- raise ValueError('Missing "inputs". Send JSON {"inputs": "<base64>"} or raw bytes.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Decode inputs -> image_bytes
30
- if isinstance(inputs, str):
31
- # JSON base64 string
 
 
 
 
 
32
  try:
33
- image_bytes = base64.b64decode(inputs)
34
- except Exception as e:
35
- raise ValueError(f'Failed to base64-decode "inputs" string: {e}')
36
- elif isinstance(inputs, (bytes, bytearray)):
37
- # raw bytes
38
- image_bytes = bytes(inputs)
39
- else:
40
- raise ValueError(f'Unsupported inputs type: {type(inputs)}')
41
-
42
- # Load image
 
43
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
44
  orig_w, orig_h = image.size
45
 
46
- # Preprocess
47
  inputs_t = self.processor(images=image, return_tensors="pt")
48
  inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}
49
 
50
- # Inference
51
  with torch.no_grad():
52
  outputs = self.model(**inputs_t)
53
  predicted_depth = outputs.predicted_depth # [B, H, W]
54
 
55
- # Upsample to original image size
56
  depth = predicted_depth.unsqueeze(1) # [B,1,H,W]
57
  depth = F.interpolate(
58
- depth,
59
- size=(orig_h, orig_w),
60
- mode="bicubic",
61
- align_corners=False,
62
  )
63
  depth = depth.squeeze(1).squeeze(0) # [H,W]
64
  depth_np = depth.detach().float().cpu().numpy()
65
 
66
- # Visualization (0..255 grayscale)
67
  dmin, dmax = float(depth_np.min()), float(depth_np.max())
68
  denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
69
- depth_norm = (depth_np - dmin) / denom
70
- depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)
71
 
72
  depth_img = Image.fromarray(depth_uint8, mode="L")
73
  buf = io.BytesIO()
74
  depth_img.save(buf, format="PNG")
75
  depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
76
 
77
- # Raw float16 depth (compact) — NOTE: relative depth, not meters
78
  depth_f16 = depth_np.astype(np.float16)
79
  depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")
80
 
 
5
  import io
6
  import base64
7
  import numpy as np
8
+ import json
9
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
14
  self.processor = AutoImageProcessor.from_pretrained(path)
15
  self.model = AutoModelForDepthEstimation.from_pretrained(path)
16
  self.model.to(self.device)
17
  self.model.eval()
18
 
19
+ def _coerce_to_image_bytes(self, obj):
20
  """
21
+ Accepts:
22
+ - bytes/bytearray: raw image bytes
23
+ - str: base64 string OR JSON string containing {"inputs": "..."} OR plain text (fallback)
24
+ - dict: expects dict["inputs"] (which can itself be str/bytes/etc)
25
+ Returns:
26
+ - image_bytes (bytes)
27
  """
28
+ # If toolkit passes dict
29
+ if isinstance(obj, dict):
30
+ if "inputs" not in obj:
31
+ raise ValueError(f'Missing "inputs" key. Keys={list(obj.keys())}')
32
+ return self._coerce_to_image_bytes(obj["inputs"])
33
+
34
+ # If toolkit passes raw bytes
35
+ if isinstance(obj, (bytes, bytearray)):
36
+ b = bytes(obj)
37
+ # Sometimes body is JSON bytes; try parse
38
+ try:
39
+ txt = b.decode("utf-8")
40
+ if txt.lstrip().startswith("{") and '"inputs"' in txt:
41
+ return self._coerce_to_image_bytes(json.loads(txt))
42
+ except Exception:
43
+ pass
44
+ return b
45
+
46
+ # If toolkit passes str
47
+ if isinstance(obj, str):
48
+ s = obj.strip()
49
 
50
+ # Sometimes it's a JSON string
51
+ if s.startswith("{") and '"inputs"' in s:
52
+ try:
53
+ return self._coerce_to_image_bytes(json.loads(s))
54
+ except Exception:
55
+ pass
56
+
57
+ # Most common: base64 string of image bytes
58
  try:
59
+ return base64.b64decode(s, validate=False)
60
+ except Exception:
61
+ # Last resort: treat as utf-8 bytes (won't be a valid image, but avoids str->BytesIO crash)
62
+ return s.encode("utf-8")
63
+
64
+ raise ValueError(f"Unsupported request type: {type(obj)}")
65
+
66
+ def __call__(self, data):
67
+ image_bytes = self._coerce_to_image_bytes(data)
68
+
69
+ # Now guaranteed bytes
70
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
71
  orig_w, orig_h = image.size
72
 
 
73
  inputs_t = self.processor(images=image, return_tensors="pt")
74
  inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}
75
 
 
76
  with torch.no_grad():
77
  outputs = self.model(**inputs_t)
78
  predicted_depth = outputs.predicted_depth # [B, H, W]
79
 
80
+ # Upsample to original size
81
  depth = predicted_depth.unsqueeze(1) # [B,1,H,W]
82
  depth = F.interpolate(
83
+ depth, size=(orig_h, orig_w), mode="bicubic", align_corners=False
 
 
 
84
  )
85
  depth = depth.squeeze(1).squeeze(0) # [H,W]
86
  depth_np = depth.detach().float().cpu().numpy()
87
 
88
+ # viz png
89
  dmin, dmax = float(depth_np.min()), float(depth_np.max())
90
  denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
91
+ depth_uint8 = (((depth_np - dmin) / denom) * 255.0).clip(0, 255).astype(np.uint8)
 
92
 
93
  depth_img = Image.fromarray(depth_uint8, mode="L")
94
  buf = io.BytesIO()
95
  depth_img.save(buf, format="PNG")
96
  depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
97
 
98
+ # raw float16 depth
99
  depth_f16 = depth_np.astype(np.float16)
100
  depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")
101