CodeJackR commited on
Commit
316f7a6
·
1 Parent(s): 38a30a4

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +15 -44
handler.py CHANGED
@@ -64,54 +64,25 @@ class EndpointHandler():
64
 
65
  # 4. Process and select the best mask
66
  try:
67
- # Get the original and reshaped sizes
68
- original_sizes = inputs["original_sizes"][0].tolist() # [H, W]
69
- reshaped_input_sizes = inputs["reshaped_input_sizes"][0].tolist() # [H, W]
 
 
 
 
70
 
71
- # Get predicted masks and scores
72
- pred_masks = outputs.pred_masks.cpu() # Shape: (batch, num_masks, H, W)
73
- iou_scores = outputs.iou_scores.cpu()[0] # Shape: (num_masks,)
74
 
75
- # Handle different tensor dimensions
76
- if pred_masks.ndim == 5:
77
- pred_masks = pred_masks.squeeze(1) # Remove extra dimension if present
78
-
79
- # Select the best mask
80
  best_mask_idx = torch.argmax(iou_scores)
81
- best_mask = pred_masks[0, best_mask_idx, :, :] # Shape: (H, W)
82
-
83
- # The mask is currently at the model's internal resolution
84
- # We need to resize it to the reshaped input size first, then crop/pad to original size
85
-
86
- # Step 1: Resize to reshaped input size
87
- resized_mask = F.interpolate(
88
- best_mask.unsqueeze(0).unsqueeze(0).float(),
89
- size=reshaped_input_sizes,
90
- mode='bilinear',
91
- align_corners=False
92
- ).squeeze()
93
-
94
- # Step 2: Handle padding/cropping to get back to original size
95
- original_h, original_w = original_sizes
96
- reshaped_h, reshaped_w = reshaped_input_sizes
97
-
98
- # Calculate padding that was added during preprocessing
99
- if reshaped_h > original_h or reshaped_w > original_w:
100
- # There was padding, we need to crop
101
- start_h = (reshaped_h - original_h) // 2
102
- start_w = (reshaped_w - original_w) // 2
103
- final_mask = resized_mask[start_h:start_h + original_h, start_w:start_w + original_w]
104
- else:
105
- # No padding or different scaling, just resize directly
106
- final_mask = F.interpolate(
107
- resized_mask.unsqueeze(0).unsqueeze(0),
108
- size=original_sizes,
109
- mode='bilinear',
110
- align_corners=False
111
- ).squeeze()
112
 
113
- # Convert to binary mask
114
- mask_binary = (final_mask > 0.0).numpy().astype(np.uint8) * 255
115
 
116
  except Exception as e:
117
  print("Error processing masks: {}".format(e))
 
64
 
65
  # 4. Process and select the best mask
66
  try:
67
+ # Use the processor's post_process_masks method correctly
68
+ # This method expects the raw model outputs and the input metadata
69
+ post_processed_masks = self.processor.post_process_masks(
70
+ outputs.pred_masks,
71
+ inputs["original_sizes"],
72
+ inputs["reshaped_input_sizes"]
73
+ )
74
 
75
+ # post_processed_masks is a list with one element (for batch size 1)
76
+ # Each element has shape (num_masks, original_height, original_width)
77
+ masks = post_processed_masks[0] # Shape: (num_masks, H, W)
78
 
79
+ # Get IoU scores and select the best mask
80
+ iou_scores = outputs.iou_scores[0] # Shape: (num_masks,)
 
 
 
81
  best_mask_idx = torch.argmax(iou_scores)
82
+ best_mask = masks[best_mask_idx] # Shape: (H, W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # Convert to numpy and create binary mask
85
+ mask_binary = (best_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
86
 
87
  except Exception as e:
88
  print("Error processing masks: {}".format(e))