CodeJackR committed on
Commit
40b9d26
·
1 Parent(s): 316f7a6

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +15 -8
handler.py CHANGED
@@ -53,8 +53,9 @@ class EndpointHandler():
53
 
54
  # 2. Prepare prompts and process the image
55
  height, width = img.size[1], img.size[0]
56
- input_points = [[[width // 2, height // 2]]]
57
- input_labels = [[1]]
 
58
 
59
  inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
60
 
@@ -65,7 +66,6 @@ class EndpointHandler():
65
  # 4. Process and select the best mask
66
  try:
67
  # Use the processor's post_process_masks method correctly
68
- # This method expects the raw model outputs and the input metadata
69
  post_processed_masks = self.processor.post_process_masks(
70
  outputs.pred_masks,
71
  inputs["original_sizes"],
@@ -73,13 +73,20 @@ class EndpointHandler():
73
  )
74
 
75
  # post_processed_masks is a list with one element (for batch size 1)
76
- # Each element has shape (num_masks, original_height, original_width)
77
  masks = post_processed_masks[0] # Shape: (num_masks, H, W)
78
-
79
- # Get IoU scores and select the best mask
80
  iou_scores = outputs.iou_scores[0] # Shape: (num_masks,)
81
- best_mask_idx = torch.argmax(iou_scores)
82
- best_mask = masks[best_mask_idx] # Shape: (H, W)
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Convert to numpy and create binary mask
85
  mask_binary = (best_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
 
53
 
54
  # 2. Prepare prompts and process the image
55
  height, width = img.size[1], img.size[0]
56
+ # Use multiple points for better segmentation
57
+ input_points = [[[width // 4, height // 4], [width // 2, height // 2], [3 * width // 4, 3 * height // 4]]]
58
+ input_labels = [[1, 1, 1]] # All positive points
59
 
60
  inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
61
 
 
66
  # 4. Process and select the best mask
67
  try:
68
  # Use the processor's post_process_masks method correctly
 
69
  post_processed_masks = self.processor.post_process_masks(
70
  outputs.pred_masks,
71
  inputs["original_sizes"],
 
73
  )
74
 
75
  # post_processed_masks is a list with one element (for batch size 1)
 
76
  masks = post_processed_masks[0] # Shape: (num_masks, H, W)
 
 
77
  iou_scores = outputs.iou_scores[0] # Shape: (num_masks,)
78
+
79
+ print("Number of masks generated: {}".format(masks.shape[0]))
80
+ print("IoU scores: {}".format(iou_scores.tolist()))
81
+
82
+ # Ensure we have masks and select the best one safely
83
+ if masks.shape[0] > 0:
84
+ best_mask_idx = torch.argmax(iou_scores)
85
+ # Ensure the index is within bounds
86
+ best_mask_idx = min(best_mask_idx.item(), masks.shape[0] - 1)
87
+ best_mask = masks[best_mask_idx] # Shape: (H, W)
88
+ else:
89
+ raise ValueError("No masks were generated")
90
 
91
  # Convert to numpy and create binary mask
92
  mask_binary = (best_mask > 0.0).cpu().numpy().astype(np.uint8) * 255