CodeJackR commited on
Commit
d29fd9e
·
1 Parent(s): 06bd1fa

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +20 -19
handler.py CHANGED
@@ -64,24 +64,21 @@ class EndpointHandler():
64
 
65
  # 4. Process and select the best mask
66
  try:
67
- original_height, original_width = inputs["original_sizes"][0].tolist()
68
- pred_masks = outputs.pred_masks.cpu()
 
 
 
 
 
 
 
69
  iou_scores = outputs.iou_scores.cpu()[0]
70
-
71
- if pred_masks.ndim == 5:
72
- pred_masks = pred_masks.squeeze(1)
73
-
74
  best_mask_idx = torch.argmax(iou_scores)
75
- best_mask_tensor = pred_masks[0, best_mask_idx, :, :]
76
-
77
- upscaled_mask = F.interpolate(
78
- best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
79
- size=(original_height, original_width),
80
- mode='bilinear',
81
- align_corners=False
82
- ).squeeze()
83
 
84
- mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
 
85
 
86
  except Exception as e:
87
  print("Error processing masks: {}".format(e))
@@ -107,11 +104,15 @@ def main():
107
 
108
  # 2. Instantiate handler and get the PIL Image result
109
  handler = EndpointHandler(path=".")
110
- result_img = handler(payload)
111
 
112
- # 3. Save the resulting image
113
- result_img.save(output_path)
114
- print("Wrote mask to {}".format(output_path))
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  main()
 
64
 
65
  # 4. Process and select the best mask
66
  try:
67
+ # Use the processor's post-processing utility to resize masks and remove padding
68
+ masks = self.processor.post_process_masks(
69
+ outputs.pred_masks.cpu(),
70
+ inputs["original_sizes"].cpu(),
71
+ inputs["reshaped_input_sizes"].cpu()
72
+ )[0]
73
+
74
+ # The output of post_process_masks is a tensor of shape (num_masks, H, W)
75
+ # where H and W are the original image dimensions.
76
  iou_scores = outputs.iou_scores.cpu()[0]
 
 
 
 
77
  best_mask_idx = torch.argmax(iou_scores)
78
+ best_mask_tensor = masks[best_mask_idx, :, :]
 
 
 
 
 
 
 
79
 
80
+ # Convert to binary mask (float tensor to uint8 numpy array)
81
+ mask_binary = (best_mask_tensor > 0.0).numpy().astype(np.uint8) * 255
82
 
83
  except Exception as e:
84
  print("Error processing masks: {}".format(e))
 
104
 
105
  # 2. Instantiate handler and get the PIL Image result
106
  handler = EndpointHandler(path=".")
107
+ result = handler(payload)
108
 
109
+ # 3. Extract the image from the result and save it
110
+ if result and isinstance(result, list) and 'mask' in result[0]:
111
+ result_img = result[0]['mask']
112
+ result_img.save(output_path)
113
+ print("Wrote mask to {}".format(output_path))
114
+ else:
115
+ print("Failed to get a valid mask from the handler.")
116
 
117
  if __name__ == "__main__":
118
  main()