Zhen Ye commited on
Commit
5658a25
·
1 Parent(s): af1a203

modified segmentation bbox display

Browse files
Files changed (1) hide show
  1. inference.py +42 -13
inference.py CHANGED
@@ -44,8 +44,8 @@ def draw_boxes(
44
  cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=2)
45
  if label:
46
  font = cv2.FONT_HERSHEY_SIMPLEX
47
- font_scale = 0.5
48
- thickness = 1
49
  text_size, baseline = cv2.getTextSize(label, font, font_scale, thickness)
50
  text_w, text_h = text_size
51
  pad = 4
@@ -67,18 +67,15 @@ def draw_boxes(
67
  return output
68
 
69
 
70
- def draw_masks(frame: np.ndarray, masks: np.ndarray, alpha: float = 0.45) -> np.ndarray:
 
 
 
 
 
71
  output = frame.copy()
72
  if masks is None or len(masks) == 0:
73
  return output
74
- colors = [
75
- (255, 0, 0),
76
- (0, 255, 0),
77
- (0, 0, 255),
78
- (255, 255, 0),
79
- (0, 255, 255),
80
- (255, 0, 255),
81
- ]
82
  for idx, mask in enumerate(masks):
83
  if mask is None:
84
  continue
@@ -88,8 +85,39 @@ def draw_masks(frame: np.ndarray, masks: np.ndarray, alpha: float = 0.45) -> np.
88
  mask = cv2.resize(mask, (output.shape[1], output.shape[0]), interpolation=cv2.INTER_NEAREST)
89
  mask_bool = mask.astype(bool)
90
  overlay = np.zeros_like(output, dtype=np.uint8)
91
- overlay[mask_bool] = colors[idx % len(colors)]
 
 
 
 
 
 
92
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return output
94
 
95
 
@@ -169,7 +197,8 @@ def infer_segmentation_frame(
169
  lock = _get_model_lock("segmenter", segmenter.name)
170
  with lock:
171
  result = segmenter.predict(frame, text_prompts=text_queries)
172
- return draw_masks(frame, result.masks), result
 
173
 
174
 
175
  def extract_first_frame(video_path: str) -> Tuple[np.ndarray, float, int, int]:
 
44
  cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=2)
45
  if label:
46
  font = cv2.FONT_HERSHEY_SIMPLEX
47
+ font_scale = 1.0
48
+ thickness = 2
49
  text_size, baseline = cv2.getTextSize(label, font, font_scale, thickness)
50
  text_w, text_h = text_size
51
  pad = 4
 
67
  return output
68
 
69
 
70
+ def draw_masks(
71
+ frame: np.ndarray,
72
+ masks: np.ndarray,
73
+ alpha: float = 0.45,
74
+ labels: Optional[Sequence[str]] = None,
75
+ ) -> np.ndarray:
76
  output = frame.copy()
77
  if masks is None or len(masks) == 0:
78
  return output
 
 
 
 
 
 
 
 
79
  for idx, mask in enumerate(masks):
80
  if mask is None:
81
  continue
 
85
  mask = cv2.resize(mask, (output.shape[1], output.shape[0]), interpolation=cv2.INTER_NEAREST)
86
  mask_bool = mask.astype(bool)
87
  overlay = np.zeros_like(output, dtype=np.uint8)
88
+ label = None
89
+ if labels and idx < len(labels):
90
+ label = labels[idx]
91
+ if not label:
92
+ label = f"object_{idx}"
93
+ color = _color_for_label(label)
94
+ overlay[mask_bool] = color
95
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
96
+ if label:
97
+ coords = np.column_stack(np.where(mask_bool))
98
+ if coords.size:
99
+ y, x = coords[0]
100
+ font = cv2.FONT_HERSHEY_SIMPLEX
101
+ font_scale = 1.0
102
+ thickness = 2
103
+ text_size, baseline = cv2.getTextSize(label, font, font_scale, thickness)
104
+ text_w, text_h = text_size
105
+ pad = 4
106
+ text_x = int(x)
107
+ text_y = max(int(y) - 6, text_h + pad)
108
+ box_top_left = (text_x, text_y - text_h - pad)
109
+ box_bottom_right = (text_x + text_w + pad, text_y + baseline)
110
+ cv2.rectangle(output, box_top_left, box_bottom_right, color, thickness=-1)
111
+ cv2.putText(
112
+ output,
113
+ label,
114
+ (text_x + pad // 2, text_y - 2),
115
+ font,
116
+ font_scale,
117
+ (255, 255, 255),
118
+ thickness,
119
+ lineType=cv2.LINE_AA,
120
+ )
121
  return output
122
 
123
 
 
197
  lock = _get_model_lock("segmenter", segmenter.name)
198
  with lock:
199
  result = segmenter.predict(frame, text_prompts=text_queries)
200
+ labels = text_queries or []
201
+ return draw_masks(frame, result.masks, labels=labels), result
202
 
203
 
204
  def extract_first_frame(video_path: str) -> Tuple[np.ndarray, float, int, int]: