CodeJackR
commited on
Commit
·
d29fd9e
1
Parent(s):
06bd1fa
Manage image resizing
Browse files- handler.py +20 -19
handler.py
CHANGED
|
@@ -64,24 +64,21 @@ class EndpointHandler():
|
|
| 64 |
|
| 65 |
# 4. Process and select the best mask
|
| 66 |
try:
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
iou_scores = outputs.iou_scores.cpu()[0]
|
| 70 |
-
|
| 71 |
-
if pred_masks.ndim == 5:
|
| 72 |
-
pred_masks = pred_masks.squeeze(1)
|
| 73 |
-
|
| 74 |
best_mask_idx = torch.argmax(iou_scores)
|
| 75 |
-
best_mask_tensor =
|
| 76 |
-
|
| 77 |
-
upscaled_mask = F.interpolate(
|
| 78 |
-
best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
|
| 79 |
-
size=(original_height, original_width),
|
| 80 |
-
mode='bilinear',
|
| 81 |
-
align_corners=False
|
| 82 |
-
).squeeze()
|
| 83 |
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
except Exception as e:
|
| 87 |
print("Error processing masks: {}".format(e))
|
|
@@ -107,11 +104,15 @@ def main():
|
|
| 107 |
|
| 108 |
# 2. Instantiate handler and get the PIL Image result
|
| 109 |
handler = EndpointHandler(path=".")
|
| 110 |
-
|
| 111 |
|
| 112 |
-
# 3.
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
main()
|
|
|
|
| 64 |
|
| 65 |
# 4. Process and select the best mask
|
| 66 |
try:
|
| 67 |
+
# Use the processor's post-processing utility to resize masks and remove padding
|
| 68 |
+
masks = self.processor.post_process_masks(
|
| 69 |
+
outputs.pred_masks.cpu(),
|
| 70 |
+
inputs["original_sizes"].cpu(),
|
| 71 |
+
inputs["reshaped_input_sizes"].cpu()
|
| 72 |
+
)[0]
|
| 73 |
+
|
| 74 |
+
# The output of post_process_masks is a tensor of shape (num_masks, H, W)
|
| 75 |
+
# where H and W are the original image dimensions.
|
| 76 |
iou_scores = outputs.iou_scores.cpu()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
best_mask_idx = torch.argmax(iou_scores)
|
| 78 |
+
best_mask_tensor = masks[best_mask_idx, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
+
# Convert to binary mask (float tensor to uint8 numpy array)
|
| 81 |
+
mask_binary = (best_mask_tensor > 0.0).numpy().astype(np.uint8) * 255
|
| 82 |
|
| 83 |
except Exception as e:
|
| 84 |
print("Error processing masks: {}".format(e))
|
|
|
|
| 104 |
|
| 105 |
# 2. Instantiate handler and get the PIL Image result
|
| 106 |
handler = EndpointHandler(path=".")
|
| 107 |
+
result = handler(payload)
|
| 108 |
|
| 109 |
+
# 3. Extract the image from the result and save it
|
| 110 |
+
if result and isinstance(result, list) and 'mask' in result[0]:
|
| 111 |
+
result_img = result[0]['mask']
|
| 112 |
+
result_img.save(output_path)
|
| 113 |
+
print("Wrote mask to {}".format(output_path))
|
| 114 |
+
else:
|
| 115 |
+
print("Failed to get a valid mask from the handler.")
|
| 116 |
|
| 117 |
if __name__ == "__main__":
|
| 118 |
main()
|