Spaces:
Paused
Paused
Zhen Ye committed on
Commit ·
32a10f3
1
Parent(s): 422c79a
modified bbox display
Browse files- inference.py +58 -4
inference.py
CHANGED
|
@@ -9,13 +9,61 @@ from models.segmenters.model_loader import load_segmenter
|
|
| 9 |
from utils.video import extract_frames, write_video
|
| 10 |
|
| 11 |
|
| 12 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
output = frame.copy()
|
| 14 |
if boxes is None:
|
| 15 |
return output
|
| 16 |
-
for box in boxes:
|
| 17 |
x1, y1, x2, y2 = [int(coord) for coord in box]
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
return output
|
| 20 |
|
| 21 |
|
|
@@ -103,7 +151,13 @@ def infer_frame(
|
|
| 103 |
except Exception:
|
| 104 |
logging.exception("Inference failed for queries %s", text_queries)
|
| 105 |
raise
|
| 106 |
-
return draw_boxes(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def infer_segmentation_frame(
|
|
|
|
| 9 |
from utils.video import extract_frames, write_video
|
| 10 |
|
| 11 |
|
| 12 |
+
def _color_for_label(label: str) -> tuple[int, int, int]:
|
| 13 |
+
# Deterministic BGR color from label text.
|
| 14 |
+
value = abs(hash(label)) % 0xFFFFFF
|
| 15 |
+
blue = value & 0xFF
|
| 16 |
+
green = (value >> 8) & 0xFF
|
| 17 |
+
red = (value >> 16) & 0xFF
|
| 18 |
+
return (blue, green, red)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def draw_boxes(
    frame: np.ndarray,
    boxes: np.ndarray,
    labels: Optional[Sequence[int]] = None,
    queries: Optional[Sequence[str]] = None,
    label_names: Optional[Sequence[str]] = None,
) -> np.ndarray:
    """Draw labelled bounding boxes on a copy of ``frame``.

    The caption for the box at position ``i`` is chosen in this order:

    1. ``label_names[i]`` when ``label_names`` is given and long enough;
    2. ``queries[labels[i]]`` when ``labels`` and ``queries`` are given and
       the index is in range (``"label_<labels[i]>"`` when it is not);
    3. ``"label_<i>"`` otherwise.

    Args:
        frame: Image to annotate; it is copied, not modified.
        boxes: Iterable of 4-element boxes unpacked as ``x1, y1, x2, y2``,
            or ``None`` to draw nothing.
        labels: Optional per-box indices into ``queries``.
        queries: Optional texts that ``labels`` index into.
        label_names: Optional per-box captions; take priority over ``labels``.

    Returns:
        The annotated copy of ``frame``.
    """
    canvas = frame.copy()
    if boxes is None:
        return canvas

    for position, coords in enumerate(boxes):
        left, top, right, bottom = (int(value) for value in coords)

        # Resolve the caption; conditions are kept lazy and in the same
        # order as the precedence described in the docstring.
        if label_names is not None and position < len(label_names):
            caption = label_names[position]
        elif labels is None or position >= len(labels) or queries is None:
            caption = f"label_{position}"
        else:
            class_id = int(labels[position])
            if 0 <= class_id < len(queries):
                caption = queries[class_id]
            else:
                caption = f"label_{class_id}"

        box_color = _color_for_label(caption)
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, thickness=2)

        if not caption:
            # Nothing to write for an empty caption; the outline is enough.
            continue

        typeface = cv2.FONT_HERSHEY_SIMPLEX
        scale = 0.5
        stroke = 1
        padding = 4
        (caption_w, caption_h), descent = cv2.getTextSize(
            caption, typeface, scale, stroke
        )

        # Place the caption just above the box, clamped so that it is not
        # pushed above the top edge of the image.
        anchor_x = left
        anchor_y = max(top - 6, caption_h + padding)

        # Filled background in the box color, then white text on top of it.
        background_start = (anchor_x, anchor_y - caption_h - padding)
        background_end = (anchor_x + caption_w + padding, anchor_y + descent)
        cv2.rectangle(canvas, background_start, background_end, box_color, thickness=-1)
        cv2.putText(
            canvas,
            caption,
            (anchor_x + padding // 2, anchor_y - 2),
            typeface,
            scale,
            (255, 255, 255),
            stroke,
            lineType=cv2.LINE_AA,
        )

    return canvas
|
| 68 |
|
| 69 |
|
|
|
|
| 151 |
except Exception:
|
| 152 |
logging.exception("Inference failed for queries %s", text_queries)
|
| 153 |
raise
|
| 154 |
+
return draw_boxes(
|
| 155 |
+
frame,
|
| 156 |
+
result.boxes,
|
| 157 |
+
labels=result.labels,
|
| 158 |
+
queries=text_queries,
|
| 159 |
+
label_names=result.label_names,
|
| 160 |
+
), detections
|
| 161 |
|
| 162 |
|
| 163 |
def infer_segmentation_frame(
|