Zhen Ye commited on
Commit
c7c9a25
·
1 Parent(s): 5658a25

modified segmentation mask display

Browse files
Files changed (1) hide show
  1. inference.py +8 -1
inference.py CHANGED
@@ -70,7 +70,7 @@ def draw_boxes(
70
  def draw_masks(
71
  frame: np.ndarray,
72
  masks: np.ndarray,
73
- alpha: float = 0.45,
74
  labels: Optional[Sequence[str]] = None,
75
  ) -> np.ndarray:
76
  output = frame.copy()
@@ -93,6 +93,11 @@ def draw_masks(
93
  color = _color_for_label(label)
94
  overlay[mask_bool] = color
95
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
 
 
 
 
 
96
  if label:
97
  coords = np.column_stack(np.where(mask_bool))
98
  if coords.size:
@@ -198,6 +203,8 @@ def infer_segmentation_frame(
198
  with lock:
199
  result = segmenter.predict(frame, text_prompts=text_queries)
200
  labels = text_queries or []
 
 
201
  return draw_masks(frame, result.masks, labels=labels), result
202
 
203
 
 
70
  def draw_masks(
71
  frame: np.ndarray,
72
  masks: np.ndarray,
73
+ alpha: float = 0.65,
74
  labels: Optional[Sequence[str]] = None,
75
  ) -> np.ndarray:
76
  output = frame.copy()
 
93
  color = _color_for_label(label)
94
  overlay[mask_bool] = color
95
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
96
+ contours, _ = cv2.findContours(
97
+ mask_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
98
+ )
99
+ if contours:
100
+ cv2.drawContours(output, contours, -1, color, thickness=2)
101
  if label:
102
  coords = np.column_stack(np.where(mask_bool))
103
  if coords.size:
 
203
  with lock:
204
  result = segmenter.predict(frame, text_prompts=text_queries)
205
  labels = text_queries or []
206
+ if len(labels) == 1:
207
+ labels = [labels[0] for _ in range(len(result.masks or []))]
208
  return draw_masks(frame, result.masks, labels=labels), result
209
 
210