MogensR committed on
Commit
19a2b07
·
1 Parent(s): 2586f05

Update utils/segmentation.py

Browse files
Files changed (1) hide show
  1. utils/segmentation.py +60 -5
utils/segmentation.py CHANGED
@@ -51,6 +51,58 @@ class SegmentationError(Exception):
51
  "SegmentationError",
52
  ]
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # ============================================================================
55
  # MAIN API
56
  # ============================================================================
@@ -124,10 +176,13 @@ def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabl
124
  point_labels=labels,
125
  multimask_output=True,
126
  )
 
 
127
  if masks is not None and len(masks):
128
- mask = _process_mask(masks[int(np.argmax(scores))])
129
  if _validate_mask_quality(mask, image.shape[:2]):
130
  return mask
 
131
  if fallback_enabled:
132
  return _classical_segmentation_cascade(image)
133
  raise RuntimeError("SAM2 failed and fallback disabled")
@@ -164,10 +219,9 @@ def _sam2_predict(image: np.ndarray, predictor: Any,
164
  point_labels=labels,
165
  multimask_output=True,
166
  )
167
- if masks is None or len(masks) == 0:
168
- raise RuntimeError("SAM2 produced no masks")
169
- best = masks[int(np.argmax(scores))] if scores is not None else masks[0]
170
- return _process_mask(best)
171
 
172
 
173
  def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -275,6 +329,7 @@ def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int,int]) -> bool:
275
  return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
276
 
277
  def _process_mask(mask: np.ndarray) -> np.ndarray:
 
278
  if mask.dtype in (np.float32, np.float64):
279
  if mask.max() <= 1.0:
280
  mask = (mask*255).astype(np.uint8)
 
51
  "SegmentationError",
52
  ]
53
 
54
+ # ============================================================================
55
+ # SAM2 TO MATANYONE MASK BRIDGE
56
+ # ============================================================================
57
+ def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
58
+ """
59
+ Convert SAM2 multi-mask output to single best mask for MatAnyone.
60
+ SAM2 returns (N, H, W) where N is typically 3 masks.
61
+ We need to return a single (H, W) mask.
62
+ """
63
+ if masks is None or len(masks) == 0:
64
+ raise SegmentationError("No masks returned from SAM2")
65
+
66
+ # Handle torch tensors
67
+ if isinstance(masks, torch.Tensor):
68
+ masks = masks.cpu().numpy()
69
+ if scores is not None and isinstance(scores, torch.Tensor):
70
+ scores = scores.cpu().numpy()
71
+
72
+ # Ensure we have the right shape
73
+ if masks.ndim == 4: # (B, N, H, W)
74
+ masks = masks[0] # Take first batch
75
+ if masks.ndim != 3: # Should be (N, H, W)
76
+ raise SegmentationError(f"Unexpected mask shape: {masks.shape}")
77
+
78
+ # Select best mask
79
+ if scores is not None and len(scores) > 0:
80
+ best_idx = int(np.argmax(scores))
81
+ else:
82
+ # Fallback: pick mask with largest area
83
+ areas = [np.sum(m > 0.5) for m in masks]
84
+ best_idx = int(np.argmax(areas))
85
+
86
+ mask = masks[best_idx]
87
+
88
+ # Convert to uint8 binary mask
89
+ if mask.dtype in (np.float32, np.float64):
90
+ mask = (mask > 0.5).astype(np.uint8) * 255
91
+ elif mask.dtype != np.uint8:
92
+ mask = mask.astype(np.uint8)
93
+
94
+ # Ensure single channel
95
+ if mask.ndim == 3:
96
+ mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()
97
+
98
+ # Binary threshold
99
+ _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
100
+
101
+ # Verify output shape
102
+ assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
103
+
104
+ return mask
105
+
106
  # ============================================================================
107
  # MAIN API
108
  # ============================================================================
 
176
  point_labels=labels,
177
  multimask_output=True,
178
  )
179
+
180
+ # Use the bridge function to get single best mask
181
  if masks is not None and len(masks):
182
+ mask = _sam2_to_matanyone_mask(masks, scores)
183
  if _validate_mask_quality(mask, image.shape[:2]):
184
  return mask
185
+
186
  if fallback_enabled:
187
  return _classical_segmentation_cascade(image)
188
  raise RuntimeError("SAM2 failed and fallback disabled")
 
219
  point_labels=labels,
220
  multimask_output=True,
221
  )
222
+
223
+ # Use the bridge function to convert multi-mask to single mask
224
+ return _sam2_to_matanyone_mask(masks, scores)
 
225
 
226
 
227
  def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
 
329
  return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
330
 
331
  def _process_mask(mask: np.ndarray) -> np.ndarray:
332
+ """Legacy mask processor - kept for compatibility but mostly replaced by _sam2_to_matanyone_mask"""
333
  if mask.dtype in (np.float32, np.float64):
334
  if mask.max() <= 1.0:
335
  mask = (mask*255).astype(np.uint8)