|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import torch |
|
|
import random |
|
|
import math |
|
|
from matplotlib.patches import Rectangle |
|
|
import itertools |
|
|
from typing import Any, Dict, List, Tuple, Optional, Union |
|
|
|
|
|
from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_label(label):
    """Normalize a label by substituting spaces for underscores and slashes."""
    for separator in ("_", "/"):
        label = label.replace(separator, " ")
    return label
|
|
|
|
|
|
|
|
def format_cate_preds(cate_preds):
    """Group category predictions per object id, highest probability first.

    Input keys are (object_id, label) pairs mapping to probabilities; labels
    are cleaned in place (underscores and slashes become spaces, matching
    clean_label, inlined here).
    """
    grouped = {}
    for (obj_id, raw_label), probability in cate_preds.items():
        cleaned = raw_label.replace("_", " ").replace("/", " ")
        grouped.setdefault(obj_id, []).append((cleaned, probability))
    for candidates in grouped.values():
        candidates.sort(key=lambda entry: entry[1], reverse=True)
    return grouped
|
|
|
|
|
def format_binary_cate_preds(binary_preds):
    """Flatten binary relation predictions into score-sorted tuples.

    Args:
        binary_preds: mapping of (frame_id, (subject_id, object_id), predicate)
            keys to confidence scores. Malformed keys are skipped with a warning.

    Returns:
        List of (frame_id, subject_id, object_id, predicate, score) tuples,
        sorted by descending score.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        try:
            f_id, (subj, obj), pred_rel = key
            frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
        except Exception:
            print("Skipping key with unexpected format:", key)
            continue
    # Bug fix: previously sorted by x[3] (the predicate string), which ordered
    # results alphabetically by relation name; sort by the score at index 4.
    frame_binary_preds.sort(key=lambda x: x[4], reverse=True)
    return frame_binary_preds
|
|
|
|
|
_FONT = cv2.FONT_HERSHEY_SIMPLEX |
|
|
|
|
|
|
|
|
def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]: |
|
|
if mask is None: |
|
|
return None |
|
|
if isinstance(mask, torch.Tensor): |
|
|
mask_np = mask.detach().cpu().numpy() |
|
|
else: |
|
|
mask_np = np.asarray(mask) |
|
|
if mask_np.ndim == 0: |
|
|
return None |
|
|
if mask_np.ndim == 3: |
|
|
mask_np = np.squeeze(mask_np) |
|
|
if mask_np.ndim != 2: |
|
|
return None |
|
|
if mask_np.dtype == bool: |
|
|
return mask_np |
|
|
return mask_np > 0 |
|
|
|
|
|
|
|
|
def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]: |
|
|
if bbox is None: |
|
|
return None |
|
|
if isinstance(bbox, (list, tuple)) and len(bbox) >= 4: |
|
|
x1, y1, x2, y2 = [float(b) for b in bbox[:4]] |
|
|
elif isinstance(bbox, np.ndarray) and bbox.size >= 4: |
|
|
x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]] |
|
|
else: |
|
|
return None |
|
|
x1 = int(np.clip(round(x1), 0, width - 1)) |
|
|
y1 = int(np.clip(round(y1), 0, height - 1)) |
|
|
x2 = int(np.clip(round(x2), 0, width - 1)) |
|
|
y2 = int(np.clip(round(y2), 0, height - 1)) |
|
|
if x2 <= x1 or y2 <= y1: |
|
|
return None |
|
|
return (x1, y1, x2, y2) |
|
|
|
|
|
|
|
|
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Look up the per-object matplotlib color and convert it to 0-255 BGR."""
    r, g, b = (int(np.clip(channel, 0.0, 1.0) * 255) for channel in get_color(obj_id)[:3])
    return (b, g, r)
|
|
|
|
|
|
|
|
def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]: |
|
|
return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color) |
|
|
|
|
|
|
|
|
def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    """Draw a vertical stack of filled text boxes on `image`, in place.

    Each entry in `lines` gets its own background rectangle (a lightened
    version of `color`) with black text on top. The stack grows downward from
    `anchor` when direction == "down", otherwise upward. All coordinates are
    clipped to the image so partially off-screen anchors degrade gracefully.
    """
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x, y = anchor
    # Clamp the anchor inside the image before any measurement.
    x = int(np.clip(x, 0, img_w - 1))
    y_cursor = int(np.clip(y, 0, img_h - 1))
    bg_color = _background_color(color)

    if direction == "down":
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            left_x = x
            # +8 / +6 are horizontal and vertical padding around the text.
            right_x = min(left_x + tw + 8, img_w - 1)
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
            if bottom_y <= top_y:
                # Ran off the bottom of the image; stop emitting further lines.
                break
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = bottom_y
    else:
        # "up" (default): each box is placed above the previous one.
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            top_y = max(y_cursor - th - baseline - 6, 0)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            bottom_y = min(top_y + th + baseline + 6, img_h - 1)
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = top_y
|
|
|
|
|
|
|
|
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw one filled text box roughly centered on `center`, in place.

    Used to annotate relation lines at their midpoint. The background is the
    lightened `color`; the text is black. All geometry is clipped to the image.
    """
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    # Offset by half the text extent (plus padding) so the box centers on (cx, cy).
    left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
|
|
|
|
|
|
|
def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]: |
|
|
if isinstance(store, dict): |
|
|
frame_entry = store.get(frame_idx, {}) |
|
|
elif isinstance(store, list) and 0 <= frame_idx < len(store): |
|
|
frame_entry = store[frame_idx] |
|
|
else: |
|
|
frame_entry = {} |
|
|
if isinstance(frame_entry, dict): |
|
|
return frame_entry |
|
|
if isinstance(frame_entry, list): |
|
|
return {i: value for i, value in enumerate(frame_entry)} |
|
|
return {} |
|
|
|
|
|
|
|
|
def _label_anchor_and_direction( |
|
|
bbox: Tuple[int, int, int, int], |
|
|
position: str, |
|
|
) -> Tuple[Tuple[int, int], str]: |
|
|
x1, y1, x2, y2 = bbox |
|
|
if position == "bottom": |
|
|
return (x1, y2), "down" |
|
|
return (x1, y1), "up" |
|
|
|
|
|
|
|
|
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw an object's rectangle plus its '#id [title]' label block, in place."""
    box_color = _object_color_bgr(obj_id)
    x1, y1, x2, y2 = bbox
    cv2.rectangle(image, (x1, y1), (x2, y2), box_color, 2)

    heading = title if title else f"#{obj_id}"
    # Prefix the object id unless the title already carries one.
    if not heading.startswith("#"):
        heading = f"#{obj_id} {heading}"

    label_lines = [heading]
    if sub_lines:
        label_lines += list(sub_lines)
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, label_lines, anchor, box_color, direction=direction)
|
|
|
|
|
|
|
|
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Overlay SAM masks (and their derived bboxes/labels) onto each frame.

    Args:
        frames: RGB frames as a list or stacked array; None entries are skipped.
        sam_masks: per-frame {obj_id: mask} store — dict keyed by frame index
            or list indexed by frame position.
        dino_labels: optional {obj_id: label} used as bbox titles.

    Returns:
        A list of annotated RGB frames.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        overlay = frame_bgr.astype(np.float32)

        # Convert and validate every mask exactly once; the original converted
        # each mask twice (once for blending, again for bbox extraction).
        prepared_masks: Dict[int, np.ndarray] = {}
        for obj_id, mask in _extract_frame_entities(sam_masks, frame_idx).items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is not None and np.any(mask_np):
                prepared_masks[obj_id] = mask_np

        # Pass 1: alpha-blend each mask onto the frame.
        alpha = 0.45
        for obj_id, mask_np in prepared_masks.items():
            color = _object_color_bgr(obj_id)
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]

        # Pass 2: draw bboxes and labels on top of the blended image.
        for obj_id, mask_np in prepared_masks.items():
            bbox = _sanitize_bbox(mask_to_bbox(mask_np), frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
|
|
|
|
|
|
|
|
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Draw per-object DINO bounding boxes (with optional labels) on each frame.

    None frames are skipped; returns annotated RGB frames.
    """
    annotated_frames: List[np.ndarray] = []
    frame_list = frames if isinstance(frames, list) else list(frames)
    label_lookup = dino_labels or {}

    for idx, raw_frame in enumerate(frame_list):
        if raw_frame is None:
            continue
        canvas = cv2.cvtColor(np.asarray(raw_frame), cv2.COLOR_RGB2BGR)
        height, width = canvas.shape[:2]

        for obj_id, raw_bbox in _extract_frame_entities(bboxes, idx).items():
            clean_bbox = _sanitize_bbox(raw_bbox, width, height)
            if not clean_bbox:
                continue
            obj_label = label_lookup.get(obj_id)
            _draw_bbox_with_label(
                canvas,
                clean_bbox,
                obj_id,
                title=f"{obj_label}" if obj_label else None,
            )

        annotated_frames.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    return annotated_frames
|
|
|
|
|
|
|
|
def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> Dict[str, List[np.ndarray]]:
    """Render four annotation layers per frame: objects, unary, binary, and all.

    Returns a dict of RGB frame lists keyed by layer name:
      - "object": bboxes + category titles only
      - "unary": additionally mask overlays and per-object unary predicate lines
      - "binary": additionally relation lines between object pairs
      - "all": everything combined

    unary_lookup/binary_lookup are keyed by frame index; cat_label_lookup by
    object id. Only the top-scoring relation per pair is drawn.
    """
    frame_groups: Dict[str, List[np.ndarray]] = {
        "object": [],
        "unary": [],
        "binary": [],
        "all": [],
    }
    frames_iterable = frames if isinstance(frames, list) else list(frames)

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}

        # One independent canvas per output layer.
        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()
        binary_bgr = base_bgr.copy()
        all_bgr = base_bgr.copy()

        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}

        # Pass 1: sanitize bboxes and precompute titles / unary text lines.
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox
            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            title_parts = []
            if cat_label:
                if cat_prob is not None:
                    title_parts.append(f"{cat_label} {cat_prob:.2f}")
                else:
                    title_parts.append(cat_label)
            titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
            unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
            unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
            unary_lines_lookup[obj_id] = unary_lines

        # Pass 2: blend mask overlays onto the unary/all canvases, but only
        # for objects that actually have unary predictions to show.
        for obj_id, bbox in bbox_lookup.items():
            unary_lines = unary_lines_lookup.get(obj_id, [])
            if not unary_lines:
                continue
            mask_raw = frame_masks.get(obj_id)
            mask_np = _to_numpy_mask(mask_raw)
            if mask_np is None or not np.any(mask_np):
                continue
            color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            alpha = 0.45
            for target in (unary_bgr, all_bgr):
                target_vals = target[mask_np].astype(np.float32)
                blended = (1.0 - alpha) * target_vals + alpha * color
                target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)

        # Pass 3: draw bboxes/titles on all canvases (after blending so the
        # labels sit on top), plus unary text stacks below each box.
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            _draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
            _draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)

        # Pass 4: relation lines (binary/all canvases), labeled at the midpoint
        # with the top prediction; the line color averages the two object colors.
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            if not subj_bbox or not obj_bbox:
                continue
            start, end = relation_line(subj_bbox, obj_bbox)
            color = tuple(int(c) for c in np.clip(
                (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
                 np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
                0, 255
            ))
            prob, relation = relation_preds[0]
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
            cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
            _draw_centered_label(binary_bgr, label_text, mid_point, color)
            _draw_centered_label(all_bgr, label_text, mid_point, color)

        frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))

    return frame_groups
|
|
|
|
|
|
|
|
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> List[np.ndarray]:
    """Convenience wrapper returning only the fully-annotated ('all') frame set."""
    groups = render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks,
    )
    return groups.get("all", [])
|
|
|
|
|
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Assign a box color and label text per object based on prediction quality.

    Colors (BGR): green when the top-1 prediction matches the ground truth,
    orange when the ground truth appears within the top-k, red otherwise
    (including objects with no predictions at all).
    """
    all_colors = []
    all_texts = []
    for obj_id, _bbox, gt_label in gt_labels:
        predictions = obj_pred_dict.get(obj_id, [])
        if predictions:
            top1 = predictions[0][0]
            gt_lower = gt_label.lower()
            if top1.lower() == gt_lower:
                box_color = (0, 255, 0)
            elif gt_lower in (p[0].lower() for p in predictions[:topk_object]):
                box_color = (0, 165, 255)
            else:
                box_color = (0, 0, 255)
        else:
            top1 = "N/A"
            box_color = (0, 0, 255)
        all_colors.append(box_color)
        all_texts.append(f"ID:{obj_id}/P:{top1}/GT:{gt_label}")
    return all_colors, all_texts
|
|
|
|
|
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw each object's colored bbox and its label strip onto the frame, in place."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        (text_w, text_h), baseline = cv2.getTextSize(label_text, font, 0.5, 1)
        # Filled strip directly above the box, then black text on top of it.
        cv2.rectangle(frame_img, (x1, y1 - text_h - baseline - 4), (x1 + text_w, y1), box_color, -1)
        cv2.putText(frame_img, label_text, (x1, y1 - 2), font,
                    0.5, (0, 0, 0), 1, cv2.LINE_AA)

    return frame_img
|
|
|
|
|
def get_white_pane(pane_height,
                   pane_width=600,
                   header_height=50,
                   header_font=cv2.FONT_HERSHEY_SIMPLEX,
                   header_font_scale=0.7,
                   header_thickness=2,
                   header_color=(0, 0, 0)):
    """Create a white side pane with 'Binary Predictions' / 'Ground Truth' headers.

    The pane is split 60/40: the left header labels the predictions column,
    the right header labels the ground-truth column.

    Bug fix: the headers were previously drawn on throwaway `.copy()` slices of
    the pane, so the returned pane was completely blank. Draw directly on the
    pane (the right header offset by the left column width) instead.
    """
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)

    left_width = int(pane_width * 0.6)

    cv2.putText(white_pane, "Binary Predictions", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
    cv2.putText(white_pane, "Ground Truth", (left_width + 10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)

    return white_pane
|
|
|
|
|
|
|
|
def plot_binary_sg(frame_img,
                   white_pane,
                   bin_preds,
                   gt_relations,
                   topk_binary,
                   header_height=50,
                   indicator_size=20,
                   pane_width=600):
    """Render top-k binary predictions beside their ground-truth relations.

    Args:
        frame_img: annotated frame (BGR) placed on the left of the output.
        white_pane: pre-built side pane (see get_white_pane).
        bin_preds: (frame_id, subj, obj, predicate, score) tuples as produced
            by format_binary_cate_preds, expected sorted by descending score.
        gt_relations: (subject, object, predicate) triples.
        topk_binary: number of predictions to display.

    Returns:
        The frame and the two-column pane stacked horizontally (BGR).
    """
    line_height = 30
    x_text = 10
    y_text_left = header_height + 10
    y_text_right = header_height + 10

    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    # Bug fix: format_binary_cate_preds emits 5-tuples (f_id, subj, obj, rel,
    # score); the previous 4-way unpack raised ValueError on every entry.
    for (_f_id, subj, obj, pred_rel, score) in bin_preds[:topk_binary]:
        # Green indicator when the (subj, obj, predicate) triple matches any GT.
        correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
                      for gt in gt_relations)
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
                      (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_left += line_height

    # Ground-truth column: plain text, no correctness indicator.
    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(right_pane, text, (x_text, y_text_right + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_right += line_height

    combined_pane = np.hstack((left_pane, right_pane))
    combined_image = np.hstack((frame_img, combined_pane))
    return combined_image
|
|
|
|
|
def visualized_frame(frame_img,
                     bboxes,
                     object_ids,
                     gt_labels,
                     cate_preds,
                     binary_preds,
                     gt_relations,
                     topk_object,
                     topk_binary,
                     phase="unary"):
    """Return the combined annotated frame for frame index i as an image (in BGR).

    In the "unary" phase, objects are drawn with correctness-colored bboxes;
    otherwise a side pane comparing binary predictions to ground truth is
    appended to the frame.
    """
    if phase == "unary":
        # Pair each object id with its bbox and cleaned ground-truth label.
        objs = []
        for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
            objs.append((obj_id, bbox, clean_label(gt_label)))

        formatted_cate_preds = format_cate_preds(cate_preds)
        # Bug fix: the (obj_id, bbox, gt_label) triples were built but the raw
        # gt_labels list was passed downstream, which expects the triples.
        all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
        updated_frame_img = plot_unary(frame_img, objs, all_colors, all_texts)
        return updated_frame_img

    else:
        formatted_binary_preds = format_binary_cate_preds(binary_preds)

        # Clean every GT triple so string comparisons match the cleaned preds.
        gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]

        pane_width = 600
        pane_height = frame_img.shape[0]
        header_height = 50
        white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)

        combined_image = plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
        return combined_image
|
|
|
|
|
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay one segmentation mask (and optional class box/label) on a matplotlib axis.

    Args:
        mask: 2-D mask, or 3-D with a singleton leading/trailing channel.
        ax: matplotlib axis, drawn onto in place.
        obj_id: selects a deterministic colormap color; None maps to index 0.
        det_class: when given, also draws the mask's tight bounding rectangle
            with this text as the label.
        random_color: use a random RGB color (alpha 0.8) instead of the colormap.
    """
    mask = np.array(mask)

    # Squeeze away a singleton channel axis so the mask is strictly 2-D.
    if mask.ndim == 3:
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)

    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        # Deterministic per-object color; stride 47 spreads ids across the map.
        cmap = plt.get_cmap("gist_rainbow")
        cmap_idx = 0 if obj_id is None else obj_id
        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5
        color = np.array(color)

    # Broadcast the RGBA color across mask pixels (transparent outside the mask).
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)

    if not det_class is None:
        # Derive a tight bbox from the mask's nonzero pixels for the class label.
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3]
            )
            ax.add_patch(rect)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1
            )
    ax.imshow(mask_image)
|
|
|
|
|
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk."""
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Normalize the frame to a contiguous numpy array.
    if torch.is_tensor(frame_image):
        frame_np = frame_image.detach().cpu().numpy()
    else:
        frame_np = np.asarray(frame_image)
    frame_np = np.ascontiguousarray(frame_np)

    # Accept either {obj_id: mask} or a positional list of masks.
    mask_items = masks.items() if isinstance(masks, dict) else enumerate(masks)
    prepared_masks = {}
    for obj_id, mask in mask_items:
        if torch.is_tensor(mask):
            prepared_masks[obj_id] = mask.detach().cpu().numpy()
        else:
            prepared_masks[obj_id] = np.asarray(mask)

    ax.imshow(frame_np)
    ax.axis("off")

    for obj_id, mask_np in prepared_masks.items():
        show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)

    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path
|
|
|
|
|
def get_video_masks_visualization(video_tensor,
                                  video_masks,
                                  video_id,
                                  video_save_base_dir,
                                  oid_class_pred=None,
                                  sample_rate=1):
    """Render and save per-frame mask visualizations for one video.

    Bug fix: the figure produced by get_mask_one_image was never saved or
    closed, so nothing reached disk and matplotlib figures leaked. The
    sample_rate parameter was also ignored; it now randomly skips frames,
    mirroring save_video_masks_visualization.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        fig, _ax = get_mask_one_image(image, masks, oid_class_pred)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
|
|
|
|
|
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Plot a frame with all of its object masks; returns (fig, ax) without saving."""
    fig, ax = plt.subplots(1, figsize=(6, 6))

    ax.imshow(frame_image)
    ax.axis('off')

    # Positional lists become {index: mask} so labeling below is uniform.
    if type(masks) is list:
        masks = dict(enumerate(masks))

    for obj_id, mask in masks.items():
        if oid_class_pred is not None:
            det_class = f"{obj_id}. {oid_class_pred[obj_id]}"
        else:
            det_class = None
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    return fig, ax
|
|
|
|
|
def save_video(frames, output_filename, output_fps):
    """Encode a sequence of equally-sized frames into an H.264 (avc1) video file.

    Args:
        frames: sequence of frames; assumed RGB (as produced by the render_*
            helpers in this module) — TODO confirm at call sites.
        output_filename: destination video path.
        output_fps: playback frame rate.
    """
    num_frames = len(frames)
    if num_frames == 0:
        print("No frames to save.")
        return
    # Bug fix: frames.shape[:2] on a stacked array read (num_frames, height);
    # the writer needs the size of a single frame.
    frame_h, frame_w = np.asarray(frames[0]).shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))

    print(f"Processing {num_frames} frames...")
    for i in range(num_frames):
        # Bug fix: the original called an undefined get_visualized_frame(i);
        # write the provided frame, converted to the BGR order OpenCV expects.
        vis_frame = cv2.cvtColor(np.asarray(frames[i]), cv2.COLOR_RGB2BGR)
        out.write(vis_frame)
        if i % 10 == 0:
            print(f"Processed frame {i+1}/{num_frames}")

    out.release()
    print(f"Video saved as {output_filename}")
|
|
|
|
|
|
|
|
def list_depth(lst):
    """Calculates the depth of a nested list (tensors count as containers)."""
    if not isinstance(lst, (list, torch.Tensor)):
        return 0
    is_scalar_tensor = isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])
    if is_scalar_tensor or len(lst) == 0:
        return 1
    return 1 + max(list_depth(item) for item in lst)
|
|
|
|
|
def normalize_prompt(points, labels):
    """Stack depth-3 point/label prompts into batched 4-D tensors; pass others through."""
    needs_batch_axis = list_depth(points) == 3
    if needs_batch_axis:
        points = torch.stack([pt.unsqueeze(0) for pt in points])
        labels = torch.stack([lb.unsqueeze(0) for lb in labels])
    return points, labels
|
|
|
|
|
|
|
|
def show_box(box, ax, object_id):
    """Draw one xyxy bounding box on a matplotlib axis, colored by object id."""
    if len(box) == 0:
        return

    # Deterministic per-object color; stride 47 spreads ids across the colormap.
    cmap = plt.get_cmap("gist_rainbow")
    idx = object_id if object_id is not None else 0
    edge_color = list(cmap((idx * 47) % 256))

    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor=edge_color, facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
|
|
|
|
|
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter positive (green plus) and negative (red square) prompt points."""
    if len(labels) == 0:
        return

    cmap = plt.get_cmap("gist_rainbow")
    idx = 0 if object_id is None else object_id
    edge = list(cmap((idx * 47) % 256))

    positives = coords[labels == 1]
    negatives = coords[labels == 0]

    ax.scatter(positives[:, 0], positives[:, 1], color='green', marker='P', s=marker_size, edgecolor=edge, linewidth=1.25)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', marker='s', s=marker_size, edgecolor=edge, linewidth=1.25)
|
|
|
|
|
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Visualize box and point prompts on a frame and save the figure to disk.

    Args:
        frame_image: image to display.
        boxes: stacked tensor indexed by object id, a {obj_id: box} dict, or
            an empty list meaning "no boxes"; anything else raises.
        points, labels: per-object point prompts; normalized via
            normalize_prompt before plotting. Assumed to be tensors with a
            .cpu() method — TODO confirm at call sites.
        save_path: output image path.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    ax.imshow(frame_image)
    ax.axis('off')

    points, labels = normalize_prompt(points, labels)
    # Dispatch on the prompt container shape the callers are known to pass.
    if type(boxes) == torch.Tensor:
        for object_id, box in enumerate(boxes):
            if not box is None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif type(boxes) == dict:
        for object_id, box in boxes.items():
            if not box is None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif type(boxes) == list and len(boxes) == 0:
        pass
    else:
        raise Exception()

    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if not len(point_ls) == 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

    plt.savefig(save_path)
    plt.close()
|
|
|
|
|
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
    """Save a per-frame visualization of box/point prompts for one video."""
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Missing prompt entries for a frame fall back to empty lists.
        boxes = video_boxes[frame_id] if frame_id in video_boxes else []
        points = video_points[frame_id] if frame_id in video_points else []
        labels = video_labels[frame_id] if frame_id in video_labels else []

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, save_path)
|
|
|
|
|
|
|
|
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate=1):
    """Randomly sample frames (at sample_rate) and save their mask overlays to disk."""
    output_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(output_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Probabilistic frame skipping: sample_rate=1 keeps every frame.
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        destination = os.path.join(output_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, video_masks[frame_id], destination)
|
|
|
|
|
|
|
|
|
|
|
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return a deterministic RGBA color (numpy array) for an object id.

    Bug fix: the alpha parameter was previously ignored (the alpha channel was
    hard-coded to 0.5); it is now honored, with the default preserving the
    prior behavior. Stride 47 spreads consecutive ids across the colormap.
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = list(cmap((cmap_idx * 47) % 256))
    color[3] = alpha
    color = np.array(color)
    return color
|
|
|
|
|
|
|
|
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]: |
|
|
return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0) |
|
|
|
|
|
|
|
|
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Returns integer pixel centers suitable for drawing a relation line. For
    coincident boxes, nudges the target center to ensure the segment has span.
    """
    # Box midpoints (the _bbox_center computation, inlined).
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0

    coincident = math.isclose(cx1, cx2, abs_tol=1e-3) and math.isclose(cy1, cy2, abs_tol=1e-3)
    if coincident:
        cx2 += max(1.0, (bbox2[2] - bbox2[0]) * 0.05)

    start = (int(round(cx1)), int(round(cy1)))
    end = (int(round(cx2)), int(round(cy2)))
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end
|
|
|
|
|
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Plot a frame with masks and relation lines between related objects.

    Args:
        frame_image: image to display.
        masks: {obj_id: mask} lookup; only objects mentioned by a relation
            are drawn.
        rel_pred_ls: {(from_obj_id, to_obj_id): relation_text} mapping; None
            is treated as "no relations" (previously it crashed on .items()).

    Returns:
        (fig, ax) for further use; nothing is saved here.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    ax.imshow(frame_image)
    ax.axis('off')

    # Bug fix: rel_pred_ls defaults to None but was iterated unconditionally.
    rel_pred_ls = rel_pred_ls or {}

    all_objs_to_show = set()
    all_lines_to_show = []

    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)

        bbox1 = mask_to_bbox(masks[from_obj_id])
        bbox2 = mask_to_bbox(masks[to_obj_id])

        # Bug fix: shortest_line_between_bboxes is not defined anywhere in this
        # module; relation_line provides the intended center-to-center segment
        # — TODO confirm the original endpoint semantics match.
        c1, c2 = relation_line(bbox1, bbox2)

        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        all_lines_to_show.append((c1, c2, face_color, line_color, rel_text))

    # Draw only the masks that participate in some relation.
    for oid in all_objs_to_show:
        show_mask(masks[oid], ax, obj_id=oid, random_color=False)

    for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
        plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
            alpha=1
        )

    return fig, ax
|
|
|