Zhen Ye commited on
Commit
32a10f3
·
1 Parent(s): 422c79a

modified bbox display

Browse files
Files changed (1) hide show
  1. inference.py +58 -4
inference.py CHANGED
@@ -9,13 +9,61 @@ from models.segmenters.model_loader import load_segmenter
9
  from utils.video import extract_frames, write_video
10
 
11
 
12
- def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  output = frame.copy()
14
  if boxes is None:
15
  return output
16
- for box in boxes:
17
  x1, y1, x2, y2 = [int(coord) for coord in box]
18
- cv2.rectangle(output, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  return output
20
 
21
 
@@ -103,7 +151,13 @@ def infer_frame(
103
  except Exception:
104
  logging.exception("Inference failed for queries %s", text_queries)
105
  raise
106
- return draw_boxes(frame, result.boxes), detections
 
 
 
 
 
 
107
 
108
 
109
  def infer_segmentation_frame(
 
9
  from utils.video import extract_frames, write_video
10
 
11
 
12
+ def _color_for_label(label: str) -> tuple[int, int, int]:
13
+ # Deterministic BGR color from label text.
14
+ value = abs(hash(label)) % 0xFFFFFF
15
+ blue = value & 0xFF
16
+ green = (value >> 8) & 0xFF
17
+ red = (value >> 16) & 0xFF
18
+ return (blue, green, red)
19
+
20
+
21
+ def draw_boxes(
22
+ frame: np.ndarray,
23
+ boxes: np.ndarray,
24
+ labels: Optional[Sequence[int]] = None,
25
+ queries: Optional[Sequence[str]] = None,
26
+ label_names: Optional[Sequence[str]] = None,
27
+ ) -> np.ndarray:
28
  output = frame.copy()
29
  if boxes is None:
30
  return output
31
+ for idx, box in enumerate(boxes):
32
  x1, y1, x2, y2 = [int(coord) for coord in box]
33
+ if label_names is not None and idx < len(label_names):
34
+ label = label_names[idx]
35
+ elif labels is not None and idx < len(labels) and queries is not None:
36
+ label_idx = int(labels[idx])
37
+ if 0 <= label_idx < len(queries):
38
+ label = queries[label_idx]
39
+ else:
40
+ label = f"label_{label_idx}"
41
+ else:
42
+ label = f"label_{idx}"
43
+ color = _color_for_label(label)
44
+ cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=2)
45
+ if label:
46
+ font = cv2.FONT_HERSHEY_SIMPLEX
47
+ font_scale = 0.5
48
+ thickness = 1
49
+ text_size, baseline = cv2.getTextSize(label, font, font_scale, thickness)
50
+ text_w, text_h = text_size
51
+ pad = 4
52
+ text_x = x1
53
+ text_y = max(y1 - 6, text_h + pad)
54
+ box_top_left = (text_x, text_y - text_h - pad)
55
+ box_bottom_right = (text_x + text_w + pad, text_y + baseline)
56
+ cv2.rectangle(output, box_top_left, box_bottom_right, color, thickness=-1)
57
+ cv2.putText(
58
+ output,
59
+ label,
60
+ (text_x + pad // 2, text_y - 2),
61
+ font,
62
+ font_scale,
63
+ (255, 255, 255),
64
+ thickness,
65
+ lineType=cv2.LINE_AA,
66
+ )
67
  return output
68
 
69
 
 
151
  except Exception:
152
  logging.exception("Inference failed for queries %s", text_queries)
153
  raise
154
+ return draw_boxes(
155
+ frame,
156
+ result.boxes,
157
+ labels=result.labels,
158
+ queries=text_queries,
159
+ label_names=result.label_names,
160
+ ), detections
161
 
162
 
163
  def infer_segmentation_frame(