|
|
from __future__ import annotations |
|
|
|
|
|
from collections import defaultdict |
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Accepted raw mask input types: a numpy array or a torch tensor. Values are
# normalised to a 2D boolean numpy array by `_to_numpy_mask`.
MaskType = Union[np.ndarray, torch.Tensor]
|
|
|
|
|
|
|
|
def _to_numpy_mask(mask: MaskType) -> np.ndarray:
    """
    Convert assorted mask formats to a 2D numpy boolean array.

    Accepts a ``torch.Tensor`` (any device, including ``bfloat16`` tensors)
    or anything ``np.asarray`` understands. Singleton axes are stripped from
    the front first and then from the back until the array is 2D, so inputs
    such as ``(1, H, W)``, ``(1, 1, H, W)``, ``(H, W, 1)``, ``(H, W, 1, 1)``
    and ``(1, H, W, 1)`` all become ``(H, W)``. Every non-zero element maps
    to ``True``.

    Args:
        mask: The raw mask.

    Returns:
        A 2D ``bool`` numpy array.

    Raises:
        ValueError: If the array is not 2D after squeezing singleton axes.
    """
    if isinstance(mask, torch.Tensor):
        tensor = mask.detach().cpu()
        # NumPy has no bfloat16, so ``Tensor.numpy()`` raises for it. The
        # upcast to float32 is exact, so the boolean result is unchanged.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        mask_np = tensor.numpy()
    else:
        mask_np = np.asarray(mask)

    # Strip leading batch/channel singleton axes first, e.g. (1, 1, H, W).
    while mask_np.ndim > 2 and mask_np.shape[0] == 1:
        mask_np = np.squeeze(mask_np, axis=0)

    # Then strip trailing singleton axes, e.g. (H, W, 1) or (H, W, 1, 1).
    while mask_np.ndim > 2 and mask_np.shape[-1] == 1:
        mask_np = np.squeeze(mask_np, axis=-1)

    if mask_np.ndim != 2:
        raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")

    return mask_np.astype(bool)
|
|
|
|
|
|
|
|
def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Compute the tight bounding box of the foreground pixels of a 2D mask.

    Args:
        mask: 2D array. Any non-zero / ``True`` element counts as foreground.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain Python ints. ``x`` indexes
        columns and ``y`` indexes rows. Both max coordinates are inclusive:
        they are the index of the last foreground column/row. Returns ``None``
        when the mask has no foreground pixels.

    Raises:
        ValueError: If ``mask`` is not 2D.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2D mask, got shape {mask.shape}")

    # Project onto each axis rather than materialising the (row, col) index
    # arrays of every foreground pixel, which np.nonzero would allocate.
    row_hits = np.flatnonzero(mask.any(axis=1))
    if row_hits.size == 0:
        return None
    col_hits = np.flatnonzero(mask.any(axis=0))

    # int() so callers get native ints (matching the annotation, and
    # JSON-serialisable) rather than numpy integer scalars.
    return (
        int(col_hits[0]),
        int(row_hits[0]),
        int(col_hits[-1]),
        int(row_hits[-1]),
    )
|
|
|
|
|
|
|
|
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten a ``{frame_id: {object_id: mask}}`` mapping into parallel batch lists.

    Each mask is normalised with ``_to_numpy_mask`` and boxed with
    ``_mask_to_bbox``. An object is skipped when its mask is empty, or when
    either box side, measured as the ``max - min`` coordinate difference, is
    below ``bbox_min_dim``. With the inclusive boxes returned by
    ``_mask_to_bbox``, the default of 5 therefore requires an object to
    cover at least 6 rows and 6 columns.

    Args:
        video_id: Stamped onto every emitted id and pair tuple.
        segments: Per-frame dict of per-object masks.
        bbox_min_dim: Minimum ``max - min`` extent along each axis.

    Returns:
        A dict of four lists, in frame and object iteration order:
          - ``"object_ids"``: ``(video_id, frame_id, object_id)`` per kept object.
          - ``"masks"``: the matching 2D boolean masks.
          - ``"bboxes"``: the matching ``(x_min, y_min, x_max, y_max)`` boxes.
          - ``"pairs"``: ``(video_id, frame_id, (src, dst))`` for every ordered
            pair of distinct kept objects within the same frame.
    """
    result: Dict[str, List] = {
        "object_ids": [],
        "masks": [],
        "bboxes": [],
        "pairs": [],
    }

    for frame_id, objects_in_frame in segments.items():
        kept: List[int] = []

        for object_id, raw_mask in objects_in_frame.items():
            normalized = _to_numpy_mask(raw_mask)
            box = _mask_to_bbox(normalized)
            if box is None:
                continue  # empty mask

            left, top, right, bottom = box
            if (bottom - top) < bbox_min_dim or (right - left) < bbox_min_dim:
                continue  # box too thin along at least one axis

            kept.append(object_id)
            result["object_ids"].append((video_id, frame_id, object_id))
            result["masks"].append(normalized)
            result["bboxes"].append(box)

        # Every ordered pair of distinct surviving objects in this frame.
        result["pairs"].extend(
            (video_id, frame_id, (src, dst))
            for src in kept
            for dst in kept
            if src != dst
        )

    return result
|
|
|
|
|
|
|
|
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Build per-frame object pairs from flattened ``(video_id, frame_id, object_id)`` ids.

    Args:
        batched_object_ids: Flattened ids. Duplicate entries are ignored.
        interested_object_pairs: If ``None`` (the default), emit every ordered
            pair ``(src, dst)`` with ``src != dst`` for each frame. Otherwise,
            for each frame, emit only the listed ``(src, dst)`` pairs, in the
            given order, whose two objects both appear in that frame. An
            explicitly passed empty iterable therefore yields no pairs.

    Returns:
        ``(video_id, frame_id, (src, dst))`` tuples. Frames appear in order of
        their first occurrence. In the all-pairs mode, pairs within a frame
        follow the order in which object ids first appear.
    """
    # A dict serves as an insertion-ordered set: it deduplicates like a set
    # but keeps first-seen order, so the output is deterministic regardless
    # of the id type's hashing.
    frame_to_objects: Dict[Tuple[int, int], Dict[int, None]] = defaultdict(dict)
    for vid, fid, oid in batched_object_ids:
        frame_to_objects[(vid, fid)][oid] = None

    # Materialise once: the pairs are rescanned for every frame, and the
    # caller may have passed a one-shot iterator.
    interested = (
        list(interested_object_pairs)
        if interested_object_pairs is not None
        else None
    )

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (vid, fid), object_ids in frame_to_objects.items():
        # Test against None, not truthiness: an empty filter means "no pairs",
        # not "all pairs".
        if interested is not None:
            valid_pairs.extend(
                (vid, fid, (src, dst))
                for src, dst in interested
                if src in object_ids and dst in object_ids
            )
        else:
            valid_pairs.extend(
                (vid, fid, (src, dst))
                for src in object_ids
                for dst in object_ids
                if src != dst
            )

    return valid_pairs
|
|
|