pslime committed on
Commit
4d5eab5
·
verified ·
1 Parent(s): ddcc98b

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +81 -0
handler.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import io
6
+ import base64
7
+ import numpy as np
8
+
9
+
10
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for depth estimation.

    Loads an ``AutoModelForDepthEstimation`` checkpoint (plus its image
    processor) from the endpoint repository and, per request, returns:

    * a grayscale visualization of the depth map as a base64 PNG, and
    * the raw (relative) depth values as base64-encoded float16 bytes,

    both at the original image resolution.
    """

    def __init__(self, path=""):
        """Load processor and model from the endpoint repo at ``path``.

        Args:
            path: Local snapshot path of the endpoint repository, supplied
                by the Inference Endpoints runtime.
        """
        # Prefer GPU when available; all inputs are moved to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load processor + model from the *endpoint repo*
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Run depth estimation on one image.

        Expected request body (any of):
            * ``data["inputs"]`` -> raw image bytes (recommended), or
            * ``data["inputs"]`` -> base64-encoded image string (JSON
              payloads cannot carry raw bytes), or
            * ``data`` itself -> raw image bytes.

        Returns:
            dict with the original image size, a base64 PNG visualization
            (``depth_png_base64``), raw float16 depth bytes
            (``depth_raw_base64_f16``) plus its dtype/shape, and the
            min/max used for visualization normalization.

        Raises:
            ValueError: If no image payload is present.
        """
        # Accept either a dict payload ({"inputs": ...}) or raw bytes as
        # the whole body — HF Endpoints can deliver both shapes.
        image_bytes = data.get("inputs", None) if isinstance(data, dict) else data
        if image_bytes is None:
            raise ValueError('Missing "inputs". Send raw image bytes as the request body.')

        # JSON requests must base64-encode binary data; decode str inputs.
        if isinstance(image_bytes, str):
            image_bytes = base64.b64decode(image_bytes)

        # Load image (force RGB so the processor sees 3 channels).
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size

        # Preprocess and move every tensor to the model's device.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference (no autograd graph needed).
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth  # shape: [B, H, W] (or similar)

        # Upsample depth to the original image size (per the model docs).
        # interpolate() needs a channel dim, so go [B, H, W] -> [B, 1, H, W].
        depth = predicted_depth.unsqueeze(1)
        depth = F.interpolate(
            depth,
            size=(orig_h, orig_w),
            mode="bicubic",
            align_corners=False,
        )
        depth = depth.squeeze(1).squeeze(0)  # -> [H, W] (batch size is 1)
        depth_np = depth.detach().float().cpu().numpy()

        # ---- Make a visualization PNG (normalize to 0..255) ----
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        # Guard against a constant depth map (division by ~zero).
        denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
        depth_norm = (depth_np - dmin) / denom
        depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)

        depth_img = Image.fromarray(depth_uint8, mode="L")  # grayscale
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # ---- Raw depth as float16 bytes (compact transfer format) ----
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }