CodeJackR commited on
Commit
e52ad65
·
1 Parent(s): 2f4ef92

Fix image upload errors

Browse files
Files changed (1) hide show
  1. handler.py +12 -22
handler.py CHANGED
@@ -29,10 +29,11 @@ class EndpointHandler():
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
  """
34
  Called on every HTTP request.
35
  Handles both base64-encoded images and PIL images.
 
36
  """
37
  # 1. Parse and decode the input image
38
  inputs = data.pop("inputs", None)
@@ -41,12 +42,10 @@ class EndpointHandler():
41
 
42
  # Check the type of inputs to handle both base64 strings and pre-processed PIL Images
43
  if isinstance(inputs, Image.Image):
44
- # Input is already a PIL Image
45
  img = inputs.convert("RGB")
46
  elif isinstance(inputs, str):
47
- # Input is a base64-encoded string
48
  if inputs.startswith("data:"):
49
- inputs = inputs.split(",", 1)[1] # Handle data URL format
50
  image_bytes = base64.b64decode(inputs)
51
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
52
  else:
@@ -91,36 +90,27 @@ class EndpointHandler():
91
  size = min(width, height) // 8
92
  mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
93
 
94
- # 5. Encode the output image to base64
95
- out = io.BytesIO()
96
- Image.fromarray(mask_binary).save(out, format="PNG")
97
- out.seek(0)
98
- mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
99
-
100
- # 6. Return the response
101
- return [{"mask_png_base64": mask_base64}]
102
 
103
  def main():
104
- # This main function shows how a client would call the endpoint.
105
  input_path = "/Users/rp7/Downloads/test.jpeg"
106
  output_path = "output.png"
107
 
108
- # 1. Prepare the payload
109
  with open(input_path, "rb") as f:
110
  img_bytes = f.read()
111
  img_b64 = base64.b64encode(img_bytes).decode("utf-8")
112
  payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}
113
 
114
- # 2. Instantiate handler and call it
115
  handler = EndpointHandler(path=".")
116
- result = handler(payload)
117
 
118
- # 3. Process the response
119
- mask_b64 = result[0]["mask_png_base64"]
120
- mask_bytes = base64.b64decode(mask_b64)
121
-
122
- mask_img = Image.open(io.BytesIO(mask_bytes))
123
- mask_img.save(output_path)
124
  print("Wrote mask to {}".format(output_path))
125
 
126
  if __name__ == "__main__":
 
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
+ def __call__(self, data):
33
  """
34
  Called on every HTTP request.
35
  Handles both base64-encoded images and PIL images.
36
+ Returns a PIL Image object.
37
  """
38
  # 1. Parse and decode the input image
39
  inputs = data.pop("inputs", None)
 
42
 
43
  # Check the type of inputs to handle both base64 strings and pre-processed PIL Images
44
  if isinstance(inputs, Image.Image):
 
45
  img = inputs.convert("RGB")
46
  elif isinstance(inputs, str):
 
47
  if inputs.startswith("data:"):
48
+ inputs = inputs.split(",", 1)[1]
49
  image_bytes = base64.b64decode(inputs)
50
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
51
  else:
 
90
  size = min(width, height) // 8
91
  mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
92
 
93
+ # 5. Create and return the output PIL Image
94
+ output_img = Image.fromarray(mask_binary)
95
+ return output_img
 
 
 
 
 
96
 
97
  def main():
98
+ # This main function shows how a client would call the endpoint locally.
99
  input_path = "/Users/rp7/Downloads/test.jpeg"
100
  output_path = "output.png"
101
 
102
+ # 1. Prepare the payload with a base64-encoded image string
103
  with open(input_path, "rb") as f:
104
  img_bytes = f.read()
105
  img_b64 = base64.b64encode(img_bytes).decode("utf-8")
106
  payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}
107
 
108
+ # 2. Instantiate handler and get the PIL Image result
109
  handler = EndpointHandler(path=".")
110
+ result_img = handler(payload)
111
 
112
+ # 3. Save the resulting image
113
+ result_img.save(output_path)
 
 
 
 
114
  print("Wrote mask to {}".format(output_path))
115
 
116
  if __name__ == "__main__":