mastari committed on
Commit
edb5977
·
1 Parent(s): e3b2ec4

Add debug logging for payload tracing

Browse files
Files changed (1) hide show
  1. handler.py +63 -58
handler.py CHANGED
@@ -1,4 +1,5 @@
1
- # Fixed handler for BiRefNet endpoint — now supports base64 + URLs + file paths
 
2
 
3
  from typing import Dict, Any, Tuple
4
  import os
@@ -12,11 +13,11 @@ import torch
12
  from torchvision import transforms
13
  from transformers import AutoModelForImageSegmentation
14
 
15
- torch.set_float32_matmul_precision(["high", "highest"][0])
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # ======================================================
19
- # Utility Functions
20
  # ======================================================
21
  def refine_foreground(image, mask, r=90):
22
  if mask.size != image.size:
@@ -24,16 +25,13 @@ def refine_foreground(image, mask, r=90):
24
  image = np.array(image) / 255.0
25
  mask = np.array(mask) / 255.0
26
  estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
27
- image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
28
- return image_masked
29
-
30
 
31
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation.

    A coarse pass with the large radius ``r`` seeds the foreground/background
    estimates; a second pass then refines them with a small fixed radius (6).

    Args:
        image: H x W x 3 float image, presumably scaled to [0, 1] by the
            caller — TODO confirm against refine_foreground.
        alpha: H x W float alpha matte.
        r: blur radius for the coarse first pass.

    Returns:
        The refined foreground estimate from the second pass.
    """
    # Promote the matte to H x W x 1 so it broadcasts over the color channels.
    alpha = alpha[:, :, None]
    coarse_fg, coarse_bg = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha, r
    )
    refined = FB_blur_fusion_foreground_estimator(image, coarse_fg, coarse_bg, alpha, r=6)
    return refined[0]
35
 
36
-
37
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
38
  if isinstance(image, Image.Image):
39
  image = np.array(image) / 255.0
@@ -43,15 +41,13 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
43
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
44
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
45
  F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
46
- F = np.clip(F, 0, 1)
47
- return F, blurred_B
48
-
49
 
50
  # ======================================================
51
  # Preprocessing
52
  # ======================================================
53
- class ImagePreprocessor():
54
- def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
55
  self.transform_image = transforms.Compose([
56
  transforms.Resize(resolution),
57
  transforms.ToTensor(),
@@ -61,7 +57,6 @@ class ImagePreprocessor():
61
  def proc(self, image: Image.Image) -> torch.Tensor:
62
  return self.transform_image(image)
63
 
64
-
65
  # ======================================================
66
  # Model and Endpoint
67
  # ======================================================
@@ -81,68 +76,79 @@ usage_to_weights_file = {
81
  'General-legacy': 'BiRefNet-legacy'
82
  }
83
 
84
- usage = 'General'
85
- if usage in ['General-Lite-2K']:
86
- resolution = (2560, 1440)
87
- elif usage in ['General-reso_512']:
88
- resolution = (512, 512)
89
- elif usage in ['General-HR', 'Matting-HR']:
90
- resolution = (2048, 2048)
91
- else:
92
- resolution = (1024, 1024)
93
-
94
  half_precision = True
95
 
96
-
97
- class EndpointHandler():
98
- def __init__(self, path=''):
 
 
99
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
100
- '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
101
  trust_remote_code=True
102
  )
103
- self.birefnet.to(device)
104
- self.birefnet.eval()
105
  if half_precision:
106
  self.birefnet.half()
107
  print("✅ BiRefNet model loaded successfully.")
108
 
109
  def __call__(self, data: Dict[str, Any]):
110
- """
111
- Accepts either:
112
- - URL (http:// or https://)
113
- - Base64 (raw or data:image/...;base64,...)
114
- - File path
115
- """
116
  image_src = data.get("inputs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  if image_src is None:
118
  raise ValueError("Missing 'inputs' key in request payload")
119
 
120
- # ✅ Handle base64 or data URI
121
- if isinstance(image_src, str):
122
- if image_src.startswith("data:image"):
123
- header, b64data = image_src.split(",", 1)
124
- image_ori = Image.open(io.BytesIO(base64.b64decode(b64data)))
125
- elif image_src[:4] in ("/9j/", "iVBOR", "R0lG", "UklG"):
126
- image_ori = Image.open(io.BytesIO(base64.b64decode(image_src)))
127
- elif image_src.startswith("http"):
128
- response = requests.get(image_src)
129
- image_ori = Image.open(io.BytesIO(response.content))
130
- elif os.path.isfile(image_src):
131
- image_ori = Image.open(image_src)
 
 
 
 
 
 
 
 
 
132
  else:
133
- raise ValueError("Unsupported input string format.")
134
- else:
135
- # Assume it's an array-like
136
- image_ori = Image.fromarray(image_src)
 
137
 
138
- image = image_ori.convert('RGB')
139
 
140
- # Preprocess
141
- image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
142
- image_proc = image_preprocessor.proc(image)
143
- image_proc = image_proc.unsqueeze(0)
144
 
145
- # Predict
146
  with torch.no_grad():
147
  preds = self.birefnet(
148
  image_proc.to(device).half() if half_precision else image_proc.to(device)
@@ -154,7 +160,6 @@ class EndpointHandler():
154
  image_masked = refine_foreground(image, pred_pil)
155
  image_masked.putalpha(pred_pil.resize(image.size))
156
 
157
- # Return as base64 for easy JSON transport
158
  buffer = io.BytesIO()
159
  image_masked.save(buffer, format="PNG")
160
  encoded_result = base64.b64encode(buffer.getvalue()).decode("utf-8")
 
1
+ # handler.py BiRefNet endpoint handler
2
+ # Fully instrumented for debugging input structure and format.
3
 
4
  from typing import Dict, Any, Tuple
5
  import os
 
13
  from torchvision import transforms
14
  from transformers import AutoModelForImageSegmentation
15
 
16
+ torch.set_float32_matmul_precision("high")
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
  # ======================================================
20
+ # Utility functions
21
  # ======================================================
22
  def refine_foreground(image, mask, r=90):
23
  if mask.size != image.size:
 
25
  image = np.array(image) / 255.0
26
  mask = np.array(mask) / 255.0
27
  estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
28
+ return Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
 
 
29
 
30
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
31
  alpha = alpha[:, :, None]
32
  F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
33
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
34
 
 
35
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
36
  if isinstance(image, Image.Image):
37
  image = np.array(image) / 255.0
 
41
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
42
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
43
  F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
44
+ return np.clip(F, 0, 1), blurred_B
 
 
45
 
46
  # ======================================================
47
  # Preprocessing
48
  # ======================================================
49
+ class ImagePreprocessor:
50
+ def __init__(self, resolution: Tuple[int, int] = (1024, 1024)):
51
  self.transform_image = transforms.Compose([
52
  transforms.Resize(resolution),
53
  transforms.ToTensor(),
 
57
  def proc(self, image: Image.Image) -> torch.Tensor:
58
  return self.transform_image(image)
59
 
 
60
  # ======================================================
61
  # Model and Endpoint
62
  # ======================================================
 
76
  'General-legacy': 'BiRefNet-legacy'
77
  }
78
 
79
+ usage = "General"
80
+ resolution = (1024, 1024)
 
 
 
 
 
 
 
 
81
  half_precision = True
82
 
83
+ # ======================================================
84
+ # Endpoint Handler
85
+ # ======================================================
86
+ class EndpointHandler:
87
+ def __init__(self, path=""):
88
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
89
+ f"zhengpeng7/{usage_to_weights_file[usage]}",
90
  trust_remote_code=True
91
  )
92
+ self.birefnet.to(device).eval()
 
93
  if half_precision:
94
  self.birefnet.half()
95
  print("✅ BiRefNet model loaded successfully.")
96
 
97
  def __call__(self, data: Dict[str, Any]):
 
 
 
 
 
 
98
  image_src = data.get("inputs")
99
+
100
+ # ================= DEBUG LOGS =================
101
+ print("\n==============================")
102
+ print("🧩 DEBUG: Incoming data structure")
103
+ print(f"Type of data: {type(data)}")
104
+ print(f"Keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")
105
+ print(f"Type of inputs: {type(image_src)}")
106
+ if isinstance(image_src, str):
107
+ print(f" Length: {len(image_src)}")
108
+ print(f" Starts with: {repr(image_src[:120])}")
109
+ elif isinstance(image_src, bytes):
110
+ print(f" Bytes length: {len(image_src)}")
111
+ else:
112
+ print(f" Value preview: {repr(image_src)[:200]}")
113
+ print("==============================\n", flush=True)
114
+ # ===============================================
115
+
116
  if image_src is None:
117
  raise ValueError("Missing 'inputs' key in request payload")
118
 
119
+ # ✅ Decode base64 / data URI / URL / file path
120
+ try:
121
+ if isinstance(image_src, (bytes, bytearray)):
122
+ image_ori = Image.open(io.BytesIO(image_src))
123
+ elif isinstance(image_src, str):
124
+ image_src = image_src.strip()
125
+
126
+ if image_src.startswith("data:image"):
127
+ header, b64data = image_src.split(",", 1)
128
+ image_bytes = base64.b64decode(b64data)
129
+ image_ori = Image.open(io.BytesIO(image_bytes))
130
+ elif any(image_src.startswith(pfx) for pfx in ("iVBOR", "/9j/", "R0lG", "UklG")):
131
+ image_bytes = base64.b64decode(image_src)
132
+ image_ori = Image.open(io.BytesIO(image_bytes))
133
+ elif image_src.startswith("http"):
134
+ response = requests.get(image_src)
135
+ image_ori = Image.open(io.BytesIO(response.content))
136
+ elif os.path.isfile(image_src):
137
+ image_ori = Image.open(image_src)
138
+ else:
139
+ raise ValueError(f"Unsupported input string format: {image_src[:40]}...")
140
  else:
141
+ image_ori = Image.fromarray(np.array(image_src))
142
+
143
+ except Exception as e:
144
+ print(f"❌ ERROR decoding input: {e}")
145
+ raise
146
 
147
+ image = image_ori.convert("RGB")
148
 
149
+ image_preprocessor = ImagePreprocessor(resolution=resolution)
150
+ image_proc = image_preprocessor.proc(image).unsqueeze(0)
 
 
151
 
 
152
  with torch.no_grad():
153
  preds = self.birefnet(
154
  image_proc.to(device).half() if half_precision else image_proc.to(device)
 
160
  image_masked = refine_foreground(image, pred_pil)
161
  image_masked.putalpha(pred_pil.resize(image.size))
162
 
 
163
  buffer = io.BytesIO()
164
  image_masked.save(buffer, format="PNG")
165
  encoded_result = base64.b64encode(buffer.getvalue()).decode("utf-8")