CodeJackR committed on
Commit
e0fb0e6
·
1 Parent(s): 064d94e

Input image as image

Browse files
Files changed (1) hide show
  1. handler.py +38 -52
handler.py CHANGED
@@ -24,111 +24,97 @@ class EndpointHandler():
24
  self.processor = SamProcessor.from_pretrained(path)
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
- print(f"Failed to load from local path: {e}")
28
  print("Attempting to load from facebook/sam-vit-base")
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
- def __call__(self, data: Any) -> Any:
33
  """
34
  Called on every HTTP request.
35
- Args:
36
- data (:obj:):
37
- includes the input data and the parameters for the inference.
38
  """
39
- inputs = data.pop("inputs", data)
40
- parameters = data.pop("parameters", {})
 
 
41
 
42
- img = Image.open(io.BytesIO(inputs))
43
-
44
- # img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
45
- # img = raw_images[0]
46
 
47
- # SAM requires input prompts, so we'll generate a center point prompt
48
- height, width = img.size[1], img.size[0] # PIL returns (width, height)
49
 
50
- # Create a center point prompt for automatic segmentation
 
51
  input_points = [[[width // 2, height // 2]]] # Center point
52
  input_labels = [[1]] # Positive prompt
53
 
54
- # Prepare inputs for the model with prompts
55
- inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt")
56
 
57
- # Generate masks using the model
58
  with torch.no_grad():
59
  outputs = self.model(**inputs)
60
 
 
61
  try:
62
- # Get original image size
63
  original_height, original_width = inputs["original_sizes"][0].tolist()
 
 
64
 
65
- # Get predicted masks and scores
66
- pred_masks = outputs.pred_masks.cpu() # (batch, num_masks, H, W)
67
- iou_scores = outputs.iou_scores.cpu()[0] # (num_masks,)
68
-
69
- # The model might return 4D or 5D tensors. Squeeze if 5D.
70
  if pred_masks.ndim == 5:
71
  pred_masks = pred_masks.squeeze(1)
72
 
73
- # Select the best mask
74
  best_mask_idx = torch.argmax(iou_scores)
75
- best_mask_tensor = pred_masks[0, best_mask_idx, :, :] # (H, W)
76
 
77
- # Upscale the mask to original image size
78
- # Add batch and channel dims for interpolate
79
  upscaled_mask = F.interpolate(
80
  best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
81
  size=(original_height, original_width),
82
  mode='bilinear',
83
  align_corners=False
84
- ).squeeze() # remove batch/channel dims
85
 
86
- # Convert to binary mask
87
  mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
88
 
89
  except Exception as e:
90
- print(f"Error processing masks: {e}")
91
- # Fallback
92
- height, width = img.size[1], img.size[0]
93
  mask_binary = np.zeros((height, width), dtype=np.uint8)
94
  center_x, center_y = width // 2, height // 2
95
  size = min(width, height) // 8
96
  mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
97
 
98
- # Convert result to base64
99
  out = io.BytesIO()
100
  Image.fromarray(mask_binary).save(out, format="PNG")
101
  out.seek(0)
102
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
103
-
104
- # Decode the returned mask and save
105
- mask_bytes = base64.b64decode(mask_base64)
106
- mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
107
- # mask_img.save(output_path, format="JPEG")
108
- # print(f"Wrote mask to {output_path}")
109
 
110
- # Return in the expected format
111
- return mask_img
112
 
113
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.jpg"):
    """Exercise the handler locally: read an image, segment it, save the mask.

    Args:
        input_path: path of the image to segment (default kept for
            backward compatibility with the original hardcoded path).
        output_path: where the resulting mask image is written (JPEG).
    """
    # Read and base64-encode the input image as a data URL, mirroring
    # what an HTTP client would send to the endpoint.
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{img_b64}"

    handler = EndpointHandler(path=".")
    result = handler({"inputs": data_url})

    # The handler may return either a list of dicts with a base64 PNG mask
    # or a PIL image directly; handle both shapes defensively.
    if isinstance(result, list):
        mask_bytes = base64.b64decode(result[0]["mask_png_base64"])
        mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
    else:
        mask_img = result.convert("RGB")

    mask_img.save(output_path, format="JPEG")
    print(f"Wrote mask to {output_path}")


if __name__ == "__main__":
    main()
 
24
  self.processor = SamProcessor.from_pretrained(path)
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
+ print("Failed to load from local path: {}".format(e))
28
  print("Attempting to load from facebook/sam-vit-base")
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
+ def __call__(self, data):
33
  """
34
  Called on every HTTP request.
35
+ Expecting base64 encoded image in the 'inputs' field.
 
 
36
  """
37
+ # 1. Parse and decode the input image
38
+ image_data = data.pop("inputs", None)
39
+ if not image_data:
40
+ raise ValueError("Missing 'inputs' key with a base64 image string.")
41
 
42
+ if isinstance(image_data, str) and image_data.startswith("data:"):
43
+ image_data = image_data.split(",", 1)[1]
 
 
44
 
45
+ image_bytes = base64.b64decode(image_data)
46
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
47
 
48
+ # 2. Prepare prompts and process the image
49
+ height, width = img.size[1], img.size[0]
50
  input_points = [[[width // 2, height // 2]]] # Center point
51
  input_labels = [[1]] # Positive prompt
52
 
53
+ inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
 
54
 
55
+ # 3. Generate masks
56
  with torch.no_grad():
57
  outputs = self.model(**inputs)
58
 
59
+ # 4. Process and select the best mask
60
  try:
 
61
  original_height, original_width = inputs["original_sizes"][0].tolist()
62
+ pred_masks = outputs.pred_masks.cpu()
63
+ iou_scores = outputs.iou_scores.cpu()[0]
64
 
 
 
 
 
 
65
  if pred_masks.ndim == 5:
66
  pred_masks = pred_masks.squeeze(1)
67
 
 
68
  best_mask_idx = torch.argmax(iou_scores)
69
+ best_mask_tensor = pred_masks[0, best_mask_idx, :, :]
70
 
 
 
71
  upscaled_mask = F.interpolate(
72
  best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
73
  size=(original_height, original_width),
74
  mode='bilinear',
75
  align_corners=False
76
+ ).squeeze()
77
 
 
78
  mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
79
 
80
  except Exception as e:
81
+ print("Error processing masks: {}".format(e))
 
 
82
  mask_binary = np.zeros((height, width), dtype=np.uint8)
83
  center_x, center_y = width // 2, height // 2
84
  size = min(width, height) // 8
85
  mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
86
 
87
+ # 5. Encode the output image to base64
88
  out = io.BytesIO()
89
  Image.fromarray(mask_binary).save(out, format="PNG")
90
  out.seek(0)
91
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
 
 
 
 
 
 
92
 
93
+ # 6. Return the response
94
+ return [{"mask_png_base64": mask_base64}]
95
 
96
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """Show how a client would call the endpoint, end to end.

    Args:
        input_path: path of the image to segment (default kept for
            backward compatibility with the original hardcoded path).
        output_path: where the resulting mask PNG is written.
    """
    # 1. Prepare the payload exactly as an HTTP client would: a data URL
    # wrapping the base64-encoded image bytes.
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}

    # 2. Instantiate handler and call it
    handler = EndpointHandler(path=".")
    result = handler(payload)

    # 3. Process the response: the handler returns a list of dicts holding
    # a base64 PNG mask; decode and write it out.
    mask_b64 = result[0]["mask_png_base64"]
    mask_bytes = base64.b64decode(mask_b64)

    mask_img = Image.open(io.BytesIO(mask_bytes))
    mask_img.save(output_path)
    print("Wrote mask to {}".format(output_path))


if __name__ == "__main__":
    main()