hedemil commited on
Commit
9405dd8
·
1 Parent(s): 7356865

Update handler.py with new logic

Browse files
Files changed (1) hide show
  1. handler.py +35 -12
handler.py CHANGED
@@ -99,7 +99,7 @@ class EndpointHandler:
99
  logger.error(f"Error generating embedding: {e}", exc_info=True)
100
  raise ValueError(f"Failed to generate embedding: {str(e)}")
101
 
102
- def _parse_input(self, data: Union[Dict[str, Any], bytes]) -> Image.Image:
103
  """
104
  Parse input data into PIL Image.
105
 
@@ -118,25 +118,48 @@ class EndpointHandler:
118
  ValueError: If image format is invalid
119
  """
120
  try:
121
- # Case 1: Binary bytes directly
122
- if isinstance(data, bytes):
 
 
 
 
123
  return Image.open(io.BytesIO(data)).convert("RGB")
124
 
125
- # Case 2: Dict with "inputs" key
126
  if isinstance(data, dict):
127
- inputs = data.get("inputs")
 
 
 
 
 
128
 
129
- if inputs is None:
130
- raise ValueError("Missing 'inputs' key in request data")
 
131
 
132
- # Case 2a: Base64 string
133
  if isinstance(inputs, str):
134
- image_bytes = base64.b64decode(inputs)
 
 
 
135
  return Image.open(io.BytesIO(image_bytes)).convert("RGB")
136
 
137
- # Case 2b: Binary bytes
138
- if isinstance(inputs, bytes):
139
- return Image.open(io.BytesIO(inputs)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
140
 
141
  raise ValueError(f"Unsupported inputs type: {type(inputs)}")
142
 
 
99
  logger.error(f"Error generating embedding: {e}", exc_info=True)
100
  raise ValueError(f"Failed to generate embedding: {str(e)}")
101
 
102
+ def _parse_input(self, data: Union[Dict[str, Any], bytes, Image.Image]) -> Image.Image:
103
  """
104
  Parse input data into PIL Image.
105
 
 
118
  ValueError: If image format is invalid
119
  """
120
  try:
121
+ # Case 0: Already a PIL Image
122
+ if isinstance(data, Image.Image):
123
+ return data.convert("RGB")
124
+
125
+ # Case 1: Raw binary bytes directly
126
+ if isinstance(data, (bytes, bytearray)):
127
  return Image.open(io.BytesIO(data)).convert("RGB")
128
 
129
+ # Case 2: Dict with possible variants
130
  if isinstance(data, dict):
131
+ # Many endpoints pass {"inputs": <something>}
132
+ inputs = data.get("inputs", data)
133
+
134
+ # 2a: Inputs is already a PIL image
135
+ if isinstance(inputs, Image.Image):
136
+ return inputs.convert("RGB")
137
 
138
+ # 2b: Raw bytes
139
+ if isinstance(inputs, (bytes, bytearray)):
140
+ return Image.open(io.BytesIO(inputs)).convert("RGB")
141
 
142
+ # 2c: Base64 string (plain or data URL)
143
  if isinstance(inputs, str):
144
+ b64_str = inputs
145
+ if inputs.startswith("data:"):
146
+ b64_str = inputs.split(",", 1)[1]
147
+ image_bytes = base64.b64decode(b64_str)
148
  return Image.open(io.BytesIO(image_bytes)).convert("RGB")
149
 
150
+ # 2d: Nested dict like {"image": <...>}
151
+ if isinstance(inputs, dict) and "image" in inputs:
152
+ inner = inputs["image"]
153
+ if isinstance(inner, Image.Image):
154
+ return inner.convert("RGB")
155
+ if isinstance(inner, (bytes, bytearray)):
156
+ return Image.open(io.BytesIO(inner)).convert("RGB")
157
+ if isinstance(inner, str):
158
+ b64_str = inner
159
+ if inner.startswith("data:"):
160
+ b64_str = inner.split(",", 1)[1]
161
+ image_bytes = base64.b64decode(b64_str)
162
+ return Image.open(io.BytesIO(image_bytes)).convert("RGB")
163
 
164
  raise ValueError(f"Unsupported inputs type: {type(inputs)}")
165