Upload VINE model - model
Browse files- config.json +6 -2
- flattening.py +124 -0
- model.safetensors +3 -0
- vine_model.py +658 -0
- vis_utils.py +941 -0
config.json
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
{
|
| 2 |
-
"_attn_implementation_autoset": true,
|
| 3 |
"_device": "cuda",
|
| 4 |
"alpha": 0.5,
|
|
|
|
|
|
|
|
|
|
| 5 |
"auto_map": {
|
| 6 |
-
"AutoConfig": "vine_config.VineConfig"
|
|
|
|
| 7 |
},
|
| 8 |
"bbox_min_dim": 5,
|
| 9 |
"box_threshold": 0.35,
|
|
@@ -23,6 +26,7 @@
|
|
| 23 |
"target_fps": 1,
|
| 24 |
"text_threshold": 0.25,
|
| 25 |
"topk_cate": 3,
|
|
|
|
| 26 |
"transformers_version": "4.46.2",
|
| 27 |
"visualization_dir": null,
|
| 28 |
"visualize": false,
|
|
|
|
| 1 |
{
|
|
|
|
| 2 |
"_device": "cuda",
|
| 3 |
"alpha": 0.5,
|
| 4 |
+
"architectures": [
|
| 5 |
+
"VineModel"
|
| 6 |
+
],
|
| 7 |
"auto_map": {
|
| 8 |
+
"AutoConfig": "vine_config.VineConfig",
|
| 9 |
+
"AutoModel": "vine_model.VineModel"
|
| 10 |
},
|
| 11 |
"bbox_min_dim": 5,
|
| 12 |
"box_threshold": 0.35,
|
|
|
|
| 26 |
"target_fps": 1,
|
| 27 |
"text_threshold": 0.25,
|
| 28 |
"topk_cate": 3,
|
| 29 |
+
"torch_dtype": "float32",
|
| 30 |
"transformers_version": "4.46.2",
|
| 31 |
"visualization_dir": null,
|
| 32 |
"visualize": false,
|
flattening.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
MaskType = Union[np.ndarray, torch.Tensor]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _to_numpy_mask(mask: MaskType) -> np.ndarray:
|
| 14 |
+
"""
|
| 15 |
+
Convert assorted mask formats to a 2D numpy boolean array.
|
| 16 |
+
"""
|
| 17 |
+
if isinstance(mask, torch.Tensor):
|
| 18 |
+
mask_np = mask.detach().cpu().numpy()
|
| 19 |
+
else:
|
| 20 |
+
mask_np = np.asarray(mask)
|
| 21 |
+
|
| 22 |
+
# Remove singleton dimensions at the front/back
|
| 23 |
+
while mask_np.ndim > 2 and mask_np.shape[0] == 1:
|
| 24 |
+
mask_np = np.squeeze(mask_np, axis=0)
|
| 25 |
+
if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
|
| 26 |
+
mask_np = np.squeeze(mask_np, axis=-1)
|
| 27 |
+
|
| 28 |
+
if mask_np.ndim != 2:
|
| 29 |
+
raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
|
| 30 |
+
|
| 31 |
+
return mask_np.astype(bool)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
| 35 |
+
"""
|
| 36 |
+
Compute a bounding box for a 2D boolean mask.
|
| 37 |
+
"""
|
| 38 |
+
if not mask.any():
|
| 39 |
+
return None
|
| 40 |
+
rows, cols = np.nonzero(mask)
|
| 41 |
+
y_min, y_max = rows.min(), rows.max()
|
| 42 |
+
x_min, x_max = cols.min(), cols.max()
|
| 43 |
+
return x_min, y_min, x_max, y_max
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten nested segmentation data into batched lists suitable for predicate
    models or downstream visualizations.

    Objects whose bounding box is narrower than ``bbox_min_dim`` in either
    dimension (or whose mask is empty) are dropped. For each frame, every
    ordered pair ``(i, j)`` with ``i != j`` of the surviving objects is
    recorded in ``pairs``.
    """
    object_ids: List[Tuple[int, int, int]] = []
    flat_masks: List[np.ndarray] = []
    flat_bboxes: List[Tuple[int, int, int, int]] = []
    pairs: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, frame_objects in segments.items():
        kept: List[int] = []
        for object_id, raw_mask in frame_objects.items():
            mask = _to_numpy_mask(raw_mask)
            bbox = _mask_to_bbox(mask)
            if bbox is None:
                # Empty mask — nothing to batch for this object.
                continue

            x_min, y_min, x_max, y_max = bbox
            # Drop detections too small to be meaningfully classified.
            if abs(y_max - y_min) < bbox_min_dim or abs(x_max - x_min) < bbox_min_dim:
                continue

            kept.append(object_id)
            object_ids.append((video_id, frame_id, object_id))
            flat_masks.append(mask)
            flat_bboxes.append(bbox)

        # All ordered pairs of surviving objects within this frame.
        pairs.extend(
            (video_id, frame_id, (src, dst))
            for src in kept
            for dst in kept
            if src != dst
        )

    return {
        "object_ids": object_ids,
        "masks": flat_masks,
        "bboxes": flat_bboxes,
        "pairs": pairs,
    }
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Filter object pairs per frame.

    If ``interested_object_pairs`` is provided, only emit those combinations
    when both objects are present in a frame; otherwise emit all permutations
    ``(i, j)`` with ``i != j`` for each frame.

    Args:
        batched_object_ids: ``(video_id, frame_id, object_id)`` triples.
        interested_object_pairs: optional ``(src, dst)`` pairs to restrict to.

    Returns:
        ``(video_id, frame_id, (src, dst))`` triples for every valid pair.
    """
    frame_to_objects: Dict[Tuple[int, int], set] = defaultdict(set)
    for vid, fid, oid in batched_object_ids:
        frame_to_objects[(vid, fid)].add(oid)

    interested = (
        list(interested_object_pairs)
        if interested_object_pairs is not None
        else None
    )

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (vid, fid), object_ids in frame_to_objects.items():
        # BUG FIX: previously this tested truthiness (`if interested:`), so an
        # explicitly supplied EMPTY pair list fell through to the "all
        # permutations" branch instead of emitting no pairs, contradicting the
        # documented contract. Only `None` means "not provided".
        if interested is not None:
            for src, dst in interested:
                if src in object_ids and dst in object_ids:
                    valid_pairs.append((vid, fid, (src, dst)))
        else:
            for src in object_ids:
                for dst in object_ids:
                    if src == dst:
                        continue
                    valid_pairs.append((vid, fid, (src, dst)))

    return valid_pairs
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c91c273c5f61b7f17fc6cc265e14bb78ed134c71d7b54611208420fcbe4f81de
|
| 3 |
+
size 1815491340
|
vine_model.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flax import config
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.utils.checkpoint as cp
|
| 6 |
+
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
|
| 7 |
+
from typing import Dict, List, Tuple, Optional, Any, Union
|
| 8 |
+
import numpy as np
|
| 9 |
+
import os
|
| 10 |
+
import cv2
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
import builtins
|
| 13 |
+
import sys
|
| 14 |
+
from laser.models import llava_clip_model_v3
|
| 15 |
+
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
|
| 16 |
+
import inspect
|
| 17 |
+
from transformers.models.clip import modeling_clip
|
| 18 |
+
import transformers
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from .vine_config import VineConfig
|
| 24 |
+
from laser.models.model_utils import (
|
| 25 |
+
extract_single_object,
|
| 26 |
+
extract_object_subject,
|
| 27 |
+
crop_image_contain_bboxes,
|
| 28 |
+
segment_list
|
| 29 |
+
)
|
| 30 |
+
from .flattening import (
|
| 31 |
+
extract_valid_object_pairs,
|
| 32 |
+
flatten_segments_for_batch,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
from .vis_utils import save_mask_one_image
|
| 36 |
+
|
| 37 |
+
class VineModel(PreTrainedModel):
|
| 38 |
+
"""
|
| 39 |
+
VINE (Video Understanding with Natural Language) Model
|
| 40 |
+
|
| 41 |
+
This model processes videos along with categorical, unary, and binary keywords
|
| 42 |
+
to return probability distributions over those keywords for detected objects
|
| 43 |
+
and their relationships in the video.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
config_class = VineConfig
|
| 47 |
+
|
| 48 |
+
    def __init__(self, config: VineConfig):
        """
        Build the VINE model from its configuration.

        Three independent CLIP backbones are instantiated from
        ``config.model_name`` — one each for categorical, unary and binary
        predicates — then pretrained VINE weights are loaded on top when
        ``config.pretrained_vine_path`` is set, and the whole module is moved
        to ``config._device``.
        """
        super().__init__(config)

        self.config = config
        # Visualization-related toggles; all optional on the config.
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        # No default: a config without `_device` is a hard error here.
        self._device = getattr(config, "_device")

        # Initialize CLIP components

        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; fall back to <unk>
            # (or EOS when <unk> is also absent) so padded batching works.
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        # Three separate CLIP instances so each predicate type can be
        # fine-tuned independently.
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Then try to load pretrained VINE weights if specified
        if config.pretrained_vine_path:
            self._load_pretrained_vine_weights(config.pretrained_vine_path)

        # Move all sub-models to the configured device.
        self.to(self._device)
|
| 80 |
+
|
| 81 |
+
def _load_pretrained_vine_weights(self, pretrained_path: str, epoch: int = 0):
|
| 82 |
+
"""
|
| 83 |
+
Load pretrained VINE weights from a saved .pt file or ensemble format.
|
| 84 |
+
"""
|
| 85 |
+
#try: # simple .pt or .pth checkpoint
|
| 86 |
+
|
| 87 |
+
# x = torch.load(pretrained_path, map_location=self._device, weights_only=False)
|
| 88 |
+
# print(f"Loaded VINE checkpoint type: {type(x)}")
|
| 89 |
+
if pretrained_path == "video-fm/vine_v0":
|
| 90 |
+
self.clip_tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
|
| 91 |
+
self.clip_cate_model = AutoModel.from_pretrained(pretrained_path)
|
| 92 |
+
self.clip_unary_model = AutoModel.from_pretrained(pretrained_path)
|
| 93 |
+
self.clip_binary_model = AutoModel.from_pretrained(pretrained_path)
|
| 94 |
+
|
| 95 |
+
if pretrained_path.endswith(".pkl"):
|
| 96 |
+
print(f"Loading VINE weights from: {pretrained_path}")
|
| 97 |
+
loaded_vine_model = torch.load(pretrained_path, map_location=self._device, weights_only=False)
|
| 98 |
+
|
| 99 |
+
print(f"Loaded state type: {type(loaded_vine_model)}")
|
| 100 |
+
if not isinstance(loaded_vine_model, dict):
|
| 101 |
+
if hasattr(loaded_vine_model, 'clip_cate_model'):
|
| 102 |
+
self.clip_cate_model.load_state_dict(loaded_vine_model.clip_cate_model.state_dict())
|
| 103 |
+
if hasattr(loaded_vine_model, 'clip_unary_model'):
|
| 104 |
+
self.clip_unary_model.load_state_dict(loaded_vine_model.clip_unary_model.state_dict())
|
| 105 |
+
if hasattr(loaded_vine_model, 'clip_binary_model'):
|
| 106 |
+
self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict())
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
elif pretrained_path.endswith(".pt") or pretrained_path.endswith(".pth"):
|
| 110 |
+
state = torch.load(pretrained_path, map_location=self._device, weights_only=True)
|
| 111 |
+
print(f"Loaded state type: {type(state)}")
|
| 112 |
+
self.load_state_dict(state)
|
| 113 |
+
return True
|
| 114 |
+
|
| 115 |
+
# handle directory + epoch format
|
| 116 |
+
if os.path.isdir(pretrained_path):
|
| 117 |
+
model_files = [f for f in os.listdir(pretrained_path) if f.endswith(f'.{epoch}.model')]
|
| 118 |
+
if model_files:
|
| 119 |
+
model_file = os.path.join(pretrained_path, model_files[0])
|
| 120 |
+
print(f"Loading VINE weights from: {model_file}")
|
| 121 |
+
pretrained_model = torch.load(model_file, map_location="cpu")
|
| 122 |
+
|
| 123 |
+
# Conversion from PredicateModel-like object to VineModel
|
| 124 |
+
# Only copy if attributes exist
|
| 125 |
+
if hasattr(pretrained_model, 'clip_cate_model'):
|
| 126 |
+
self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict())
|
| 127 |
+
if hasattr(pretrained_model, 'clip_unary_model'):
|
| 128 |
+
self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict())
|
| 129 |
+
if hasattr(pretrained_model, 'clip_binary_model'):
|
| 130 |
+
self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict())
|
| 131 |
+
print("✓ Loaded all sub-model weights from ensemble format")
|
| 132 |
+
return True
|
| 133 |
+
else:
|
| 134 |
+
print(f"No model file found for epoch {epoch} in {pretrained_path}")
|
| 135 |
+
return False
|
| 136 |
+
|
| 137 |
+
print("Unsupported format for pretrained_vine_path")
|
| 138 |
+
return False
|
| 139 |
+
|
| 140 |
+
# except Exception as e:
|
| 141 |
+
# print(f"✗ Error loading VINE weights: {e}")
|
| 142 |
+
# print("Using base CLIP models instead")
|
| 143 |
+
# return False
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# def _load_pretrained_vine_weights(self, pretrained_path: str, epoch: int = 0):
|
| 148 |
+
# """
|
| 149 |
+
# Load pretrained VINE weights from local ensemble format.
|
| 150 |
+
|
| 151 |
+
# Args:
|
| 152 |
+
# pretrained_path: Path to the pretrained model directory or HF model name
|
| 153 |
+
# epoch: Epoch number to load (for ensemble format)
|
| 154 |
+
# """
|
| 155 |
+
# if pretrained_path == "video-fm/vine_v0":
|
| 156 |
+
# # Try to load from HuggingFace Hubtry:
|
| 157 |
+
# # ✅ TODO FIXED: Added support for loading .pt/.pth checkpoints with state dicts
|
| 158 |
+
# if pretrained_path.endswith(".pt") or pretrained_path.endswith(".pth"):
|
| 159 |
+
# print(f"Loading VINE weights from: {pretrained_path}")
|
| 160 |
+
# state = torch.load(pretrained_path, map_location="cpu")
|
| 161 |
+
|
| 162 |
+
# if "clip_cate_model" in state:
|
| 163 |
+
# self.clip_cate_model.load_state_dict(state["clip_cate_model"])
|
| 164 |
+
# print("✓ Loaded categorical model weights")
|
| 165 |
+
# if "clip_unary_model" in state:
|
| 166 |
+
# self.clip_unary_model.load_state_dict(state["clip_unary_model"])
|
| 167 |
+
# print("✓ Loaded unary model weights")
|
| 168 |
+
# if "clip_binary_model" in state:
|
| 169 |
+
# self.clip_binary_model.load_state_dict(state["clip_binary_model"])
|
| 170 |
+
# print("✓ Loaded binary model weights")
|
| 171 |
+
|
| 172 |
+
# if "clip_tokenizer" in state:
|
| 173 |
+
# self.clip_tokenizer = state["clip_tokenizer"]
|
| 174 |
+
# print("✓ Loaded tokenizer")
|
| 175 |
+
# if "clip_processor" in state:
|
| 176 |
+
# self.clip_processor = state["clip_processor"]
|
| 177 |
+
# print("✓ Loaded processor")
|
| 178 |
+
|
| 179 |
+
# print("✓ All VINE weights loaded successfully")
|
| 180 |
+
# return True
|
| 181 |
+
|
| 182 |
+
# # Load from local ensemble format
|
| 183 |
+
# try:
|
| 184 |
+
# if os.path.isdir(pretrained_path):
|
| 185 |
+
# # Directory format - look for ensemble file
|
| 186 |
+
# model_files = [f for f in os.listdir(pretrained_path) if f.endswith(f'.{epoch}.model')]
|
| 187 |
+
# if model_files:
|
| 188 |
+
# model_file = os.path.join(pretrained_path, model_files[0])
|
| 189 |
+
# else:
|
| 190 |
+
# print(f"No model file found for epoch {epoch} in {pretrained_path}")
|
| 191 |
+
# return False
|
| 192 |
+
# else:
|
| 193 |
+
# # Direct file path
|
| 194 |
+
# model_file = pretrained_path
|
| 195 |
+
|
| 196 |
+
# print(f"Loading VINE weights from: {model_file}")
|
| 197 |
+
|
| 198 |
+
# # Load the ensemble model (PredicateModel instance)
|
| 199 |
+
# # TODO: conversion from PredicateModel to VineModel
|
| 200 |
+
# pretrained_model = torch.load(model_file, map_location='cpu', weights_only=False)
|
| 201 |
+
|
| 202 |
+
# # Transfer weights from the pretrained model to our HuggingFace models
|
| 203 |
+
# if hasattr(pretrained_model, 'clip_cate_model'):
|
| 204 |
+
# self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict())
|
| 205 |
+
# print("✓ Loaded categorical model weights")
|
| 206 |
+
|
| 207 |
+
# if hasattr(pretrained_model, 'clip_unary_model'):
|
| 208 |
+
# self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict())
|
| 209 |
+
# print("✓ Loaded unary model weights")
|
| 210 |
+
|
| 211 |
+
# if hasattr(pretrained_model, 'clip_binary_model'):
|
| 212 |
+
# self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict())
|
| 213 |
+
# print("✓ Loaded binary model weights")
|
| 214 |
+
|
| 215 |
+
# # Also transfer tokenizer and processor if available
|
| 216 |
+
# if hasattr(pretrained_model, 'clip_tokenizer'):
|
| 217 |
+
# self.clip_tokenizer = pretrained_model.clip_tokenizer
|
| 218 |
+
# print("✓ Loaded tokenizer")
|
| 219 |
+
|
| 220 |
+
# if hasattr(pretrained_model, 'clip_processor'):
|
| 221 |
+
# self.clip_processor = pretrained_model.clip_processor
|
| 222 |
+
# print("✓ Loaded processor")
|
| 223 |
+
|
| 224 |
+
# print("✓ Successfully loaded all VINE weights")
|
| 225 |
+
# return True
|
| 226 |
+
|
| 227 |
+
# except Exception as e:
|
| 228 |
+
# print(f"✗ Error loading VINE weights: {e}")
|
| 229 |
+
# print("Using base CLIP models instead")
|
| 230 |
+
# return False
|
| 231 |
+
|
| 232 |
+
@classmethod
|
| 233 |
+
def from_pretrained_vine(
|
| 234 |
+
cls,
|
| 235 |
+
model_path: str,
|
| 236 |
+
config: Optional[VineConfig] = None,
|
| 237 |
+
epoch: int = 0,
|
| 238 |
+
**kwargs
|
| 239 |
+
):
|
| 240 |
+
"""
|
| 241 |
+
Create VineModel from pretrained VINE weights.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
model_path: Path to pretrained VINE model
|
| 245 |
+
config: Optional config, will create default if None
|
| 246 |
+
epoch: Epoch number to load
|
| 247 |
+
**kwargs: Additional arguments
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
VineModel instance with loaded weights
|
| 251 |
+
"""
|
| 252 |
+
if config is None:
|
| 253 |
+
config = VineConfig(pretrained_vine_path=model_path)
|
| 254 |
+
else:
|
| 255 |
+
config.pretrained_vine_path = model_path
|
| 256 |
+
|
| 257 |
+
# Create model instance (will automatically load weights)
|
| 258 |
+
model = cls(config, **kwargs)
|
| 259 |
+
|
| 260 |
+
return model
|
| 261 |
+
|
| 262 |
+
def _text_features_checkpoint(self, model, tokens):
|
| 263 |
+
"""Extract text features with gradient checkpointing."""
|
| 264 |
+
token_keys = list(tokens.keys())
|
| 265 |
+
|
| 266 |
+
def get_text_features_wrapped(*inputs):
|
| 267 |
+
kwargs = {key: value for key, value in zip(token_keys, inputs)}
|
| 268 |
+
return model.get_text_features(**kwargs)
|
| 269 |
+
|
| 270 |
+
token_values = [tokens[key] for key in token_keys]
|
| 271 |
+
return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)
|
| 272 |
+
|
| 273 |
+
def _image_features_checkpoint(self, model, images):
|
| 274 |
+
"""Extract image features with gradient checkpointing."""
|
| 275 |
+
return cp.checkpoint(model.get_image_features, images, use_reentrant=False)
|
| 276 |
+
|
| 277 |
+
def clip_sim(self, model, nl_feat, img_feat):
|
| 278 |
+
img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
|
| 279 |
+
nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
|
| 280 |
+
logits = torch.matmul(img_feat, nl_feat.T)
|
| 281 |
+
if hasattr(model, "logit_scale"):
|
| 282 |
+
logits = logits * model.logit_scale.exp()
|
| 283 |
+
return logits
|
| 284 |
+
|
| 285 |
+
    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Forward pass of the VINE model.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names to classify objects
            unary_keywords: Optional list of unary predicates (actions on single objects)
            binary_keywords: Optional list of binary predicates (relations between objects)
            object_pairs: Optional list of (obj1_id, obj2_id) pairs for binary classification
            return_flattened_segments: Override for config.return_flattened_segments
            return_valid_pairs: Override for config.return_valid_pairs
            interested_object_pairs: Optional (src, dst) pairs used when
                filtering valid pairs; falls back to the config value.
            debug_visualizations: Override for config.debug_visualizations.

        Returns:
            Dict containing probability distributions for categorical, unary, and binary predictions
        """
        # Resolve optional arguments: explicit call-site values win, then
        # config defaults.
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = self.config.return_flattened_segments
        if return_valid_pairs is None:
            return_valid_pairs = self.config.return_valid_pairs
        if interested_object_pairs is None or len(interested_object_pairs) == 0:
            interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or []
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations
        # NOTE(review): debug_visualizations is resolved here but not used
        # anywhere below in this method — confirm whether it should gate
        # save_mask_one_image-style output.

        # Prepare dummy strings for empty categories
        dummy_str = ""

        # Fill empty categories with dummy strings so each CLIP text encoder
        # always receives at least one prompt; dummies are filtered out of
        # the returned probability dicts below.
        if len(categorical_keywords) == 0:
            categorical_keywords = [dummy_str]
        if len(unary_keywords) == 0:
            unary_keywords = [dummy_str]
        if len(binary_keywords) == 0:
            binary_keywords = [dummy_str]

        # Extract text features for all keyword types (one encoder each).
        categorical_features = self._extract_text_features(
            self.clip_cate_model, categorical_keywords
        )
        unary_features = self._extract_text_features(
            self.clip_unary_model, unary_keywords
        )
        binary_features = self._extract_text_features(
            self.clip_binary_model, binary_keywords
        )

        # Process video frames and extract object features
        categorical_probs = {}
        unary_probs = {}
        binary_probs = {}

        # Process each frame: per-object categorical and unary predictions.
        for frame_id, frame_masks in masks.items():
            # Skip mask entries that point past the available frames.
            if frame_id >= len(video_frames):
                continue

            frame = self._frame_to_numpy(video_frames[frame_id])
            frame_bboxes = bboxes.get(frame_id, {})

            # Extract object features for categorical classification
            for obj_id, mask in frame_masks.items():
                # An object without a bbox in this frame cannot be cropped.
                if obj_id not in frame_bboxes:
                    continue

                bbox = frame_bboxes[obj_id]

                # Extract single object image
                mask_np = self._mask_to_numpy(mask)

                obj_image = extract_single_object(
                    frame, mask_np, alpha=self.config.alpha
                )

                # Get image features
                obj_features = self._extract_image_features(
                    self.clip_cate_model, obj_image
                )

                # Compute similarities for categorical classification
                cat_similarities = self.clip_sim(
                    self.clip_cate_model, categorical_features, obj_features
                )
                cat_probs = F.softmax(cat_similarities, dim=-1)

                # Store categorical predictions.
                # NOTE(review): keys are (obj_id, keyword) without frame_id,
                # so an object seen in several frames keeps only the last
                # frame's probabilities — confirm this is intended.
                for i, keyword in enumerate(categorical_keywords):
                    if keyword != dummy_str:
                        categorical_probs[(obj_id, keyword)] = cat_probs[0, i].item()

                # Compute unary predictions (skipped when only the dummy
                # placeholder is present).
                if len(unary_keywords) > 0 and unary_keywords[0] != dummy_str:
                    unary_similarities = self.clip_sim(
                        self.clip_unary_model, unary_features, obj_features
                    )
                    unary_probs_tensor = F.softmax(unary_similarities, dim=-1)

                    for i, keyword in enumerate(unary_keywords):
                        if keyword != dummy_str:
                            unary_probs[(frame_id, obj_id, keyword)] = unary_probs_tensor[0, i].item()

        # Process binary relationships: for each requested object pair,
        # scan all frames where both objects appear with bboxes.
        if len(binary_keywords) > 0 and binary_keywords[0] != dummy_str and len(object_pairs) > 0:
            for obj1_id, obj2_id in object_pairs:
                for frame_id, frame_masks in masks.items():
                    if frame_id >= len(video_frames):
                        continue
                    if (obj1_id in frame_masks and obj2_id in frame_masks and
                        obj1_id in bboxes.get(frame_id, {}) and obj2_id in bboxes.get(frame_id, {})):

                        frame = self._frame_to_numpy(video_frames[frame_id])
                        mask1 = frame_masks[obj1_id]
                        mask2 = frame_masks[obj2_id]

                        mask1_np = self._mask_to_numpy(mask1)
                        mask2_np = self._mask_to_numpy(mask2)

                        # Extract object pair image (masks need a trailing
                        # channel axis here).
                        pair_image = extract_object_subject(
                            frame, mask1_np[..., None], mask2_np[..., None],
                            alpha=self.config.alpha,
                            white_alpha=self.config.white_alpha
                        )

                        # Crop to contain both objects
                        bbox1 = bboxes[frame_id][obj1_id]
                        bbox2 = bboxes[frame_id][obj2_id]

                        # Bounding box overlap check: skip pairs whose boxes
                        # do not intersect in this frame.
                        if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \
                           bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]:
                            continue

                        cropped_image = crop_image_contain_bboxes(
                            pair_image, [bbox1, bbox2], f"frame_{frame_id}"
                        )

                        # Get image features
                        pair_features = self._extract_image_features(
                            self.clip_binary_model, cropped_image
                        )

                        # Compute similarities for binary classification
                        binary_similarities = self.clip_sim(
                            self.clip_binary_model, binary_features, pair_features
                        )
                        binary_probs_tensor = F.softmax(binary_similarities, dim=-1)

                        for i, keyword in enumerate(binary_keywords):
                            if keyword != dummy_str:
                                binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = binary_probs_tensor[0, i].item()

        # Calculate dummy probability (for compatibility): uniform over the
        # largest keyword set.
        dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords))

        result: Dict[str, Any] = {
            "categorical_probs": {0: categorical_probs},  # Video ID 0
            "unary_probs": {0: unary_probs},
            "binary_probs": [binary_probs],  # List format for compatibility
            "dummy_prob": dummy_prob
        }

        # Optionally attach flattened segments and/or the list of valid
        # object pairs derived from them.
        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                # Empty pair list means "no restriction" here.
                interested_pairs = interested_object_pairs if interested_object_pairs else None
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    # Provide all generated pairs for clarity when auto-generated.
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {"pair_source": "filtered", "requested_pairs": interested_pairs}

        return result
|
| 488 |
+
|
| 489 |
+
def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 490 |
+
"""Convert a frame tensor/array to a contiguous numpy array."""
|
| 491 |
+
if torch.is_tensor(frame):
|
| 492 |
+
frame_np = frame.detach().cpu().numpy()
|
| 493 |
+
else:
|
| 494 |
+
frame_np = np.asarray(frame)
|
| 495 |
+
return np.ascontiguousarray(frame_np)
|
| 496 |
+
|
| 497 |
+
def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 498 |
+
"""Convert a mask tensor/array to a 2D boolean numpy array."""
|
| 499 |
+
if torch.is_tensor(mask):
|
| 500 |
+
mask_np = mask.detach().cpu().numpy()
|
| 501 |
+
else:
|
| 502 |
+
mask_np = np.asarray(mask)
|
| 503 |
+
|
| 504 |
+
if mask_np.ndim == 3:
|
| 505 |
+
if mask_np.shape[0] == 1:
|
| 506 |
+
mask_np = mask_np.squeeze(0)
|
| 507 |
+
elif mask_np.shape[2] == 1:
|
| 508 |
+
mask_np = mask_np.squeeze(2)
|
| 509 |
+
|
| 510 |
+
if mask_np.ndim != 2:
|
| 511 |
+
raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
|
| 512 |
+
|
| 513 |
+
return mask_np.astype(bool, copy=False)
|
| 514 |
+
|
| 515 |
+
def _extract_text_features(self, model, keywords):
|
| 516 |
+
"""Extract text features for given keywords."""
|
| 517 |
+
tokens = self.clip_tokenizer(
|
| 518 |
+
keywords,
|
| 519 |
+
return_tensors="pt",
|
| 520 |
+
max_length=75,
|
| 521 |
+
truncation=True,
|
| 522 |
+
padding='max_length'
|
| 523 |
+
).to(self._device)
|
| 524 |
+
|
| 525 |
+
return self._text_features_checkpoint(model, tokens)
|
| 526 |
+
|
| 527 |
+
def _extract_image_features(self, model, image):
|
| 528 |
+
"""Extract image features for given image."""
|
| 529 |
+
# Ensure image is in correct format
|
| 530 |
+
if isinstance(image, np.ndarray):
|
| 531 |
+
if image.dtype != np.uint8:
|
| 532 |
+
image = image.astype(np.uint8)
|
| 533 |
+
# Convert BGR to RGB if needed
|
| 534 |
+
if len(image.shape) == 3 and image.shape[2] == 3:
|
| 535 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 536 |
+
|
| 537 |
+
# Process image with CLIP processor
|
| 538 |
+
inputs = self.clip_processor(
|
| 539 |
+
images=image,
|
| 540 |
+
return_tensors="pt"
|
| 541 |
+
).to(self._device)
|
| 542 |
+
|
| 543 |
+
return self._image_features_checkpoint(model, inputs['pixel_values'])
|
| 544 |
+
#TODO: return masks and bboxes and their corresponding index
|
| 545 |
+
    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        High-level prediction method that returns formatted results.

        Runs :meth:`forward` under ``torch.no_grad`` and reshapes its raw
        probability dicts into per-object / per-frame top-k prediction lists.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names
            unary_keywords: Optional list of unary predicates
            binary_keywords: Optional list of binary predicates
            object_pairs: Optional list of object pairs for binary relations
            return_top_k: Number of top predictions to return
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track

        Returns:
            Dict with "categorical_predictions" (obj_id -> [(prob, label), ...]),
            "unary_predictions" ((frame_id, obj_id) -> [(prob, label), ...]),
            "binary_predictions" ((frame_id, pair) -> [(prob, relation), ...]),
            and "confidence_scores" (max probability per head). Flattened
            segments / valid-pair entries are passed through when present.
        """

        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        # Format categorical results: forward() keys are (obj_id, category)
        # under video id 0 — regroup them per object.
        formatted_categorical = {}
        for (obj_id, category), prob in outputs["categorical_probs"][0].items():
            if obj_id not in formatted_categorical:
                formatted_categorical[obj_id] = []
            formatted_categorical[obj_id].append((prob, category))

        # Sort and take top-k for each object.
        # NOTE: tuples sort by probability first; ties fall back to the label.
        for obj_id in formatted_categorical:
            formatted_categorical[obj_id] = sorted(
                formatted_categorical[obj_id], reverse=True
            )[:return_top_k]

        # Format unary results, keyed by (frame_id, obj_id).
        formatted_unary = {}
        for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
            key = (frame_id, obj_id)
            if key not in formatted_unary:
                formatted_unary[key] = []
            formatted_unary[key].append((prob, predicate))

        # Sort and take top-k
        for key in formatted_unary:
            formatted_unary[key] = sorted(
                formatted_unary[key], reverse=True
            )[:return_top_k]

        # Format binary results, keyed by (frame_id, (subj_id, obj_id)).
        # binary_probs is a list for compatibility; only entry 0 is used here.
        formatted_binary = {}
        if len(outputs["binary_probs"]) > 0:
            for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
                key = (frame_id, obj_pair)
                if key not in formatted_binary:
                    formatted_binary[key] = []
                formatted_binary[key].append((prob, predicate))

            # Sort and take top-k
            for key in formatted_binary:
                formatted_binary[key] = sorted(
                    formatted_binary[key], reverse=True
                )[:return_top_k]

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            # Highest single probability seen per head; 0.0 when a head is empty.
            "confidence_scores": {
                "categorical": max([max([p for p, _ in preds], default=0.0)
                                    for preds in formatted_categorical.values()], default=0.0),
                "unary": max([max([p for p, _ in preds], default=0.0)
                              for preds in formatted_unary.values()], default=0.0),
                "binary": max([max([p for p, _ in preds], default=0.0)
                               for preds in formatted_binary.values()], default=0.0)
            }
        }

        # Pass through optional forward() outputs unchanged.
        if "flattened_segments" in outputs:
            result["flattened_segments"] = outputs["flattened_segments"]
        if "valid_pairs" in outputs:
            result["valid_pairs"] = outputs["valid_pairs"]
        if "valid_pairs_metadata" in outputs:
            result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]

        return result
|
vis_utils.py
ADDED
|
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import torch
|
| 6 |
+
import random
|
| 7 |
+
import math
|
| 8 |
+
from matplotlib.patches import Rectangle
|
| 9 |
+
import itertools
|
| 10 |
+
from typing import Any, Dict, List, Tuple, Optional, Union
|
| 11 |
+
|
| 12 |
+
from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
|
| 13 |
+
|
| 14 |
+
########################################################################################
|
| 15 |
+
########## Visualization Library ########
|
| 16 |
+
########################################################################################
|
| 17 |
+
# This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
|
| 18 |
+
#
|
| 19 |
+
# Conventions (RGB frames, pixel coords):
|
| 20 |
+
# - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
|
| 21 |
+
# - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
|
| 22 |
+
# - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
|
| 23 |
+
#
|
| 24 |
+
# Per-frame stores use one of:
|
| 25 |
+
# - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
|
| 26 |
+
# - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
|
| 27 |
+
#
|
| 28 |
+
# Renderer inputs/outputs:
|
| 29 |
+
# 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
|
| 30 |
+
# - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
|
| 31 |
+
# - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
|
| 32 |
+
#
|
| 33 |
+
# 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
|
| 34 |
+
# - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
|
| 35 |
+
#
|
| 36 |
+
# 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
|
| 37 |
+
# -> List[np.ndarray] (the "all" view)
|
| 38 |
+
# - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
|
| 39 |
+
# - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
|
| 40 |
+
# - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
|
| 41 |
+
# - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
|
| 42 |
+
#
|
| 43 |
+
# Ground-truth helpers used by plotting utilities:
|
| 44 |
+
# - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
|
| 45 |
+
#
|
| 46 |
+
# All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
|
| 47 |
+
########################################################################################
|
| 48 |
+
|
| 49 |
+
def clean_label(label):
    """Replace underscores and slashes with spaces for uniformity."""
    return label.translate(str.maketrans("_/", "  "))
|
| 52 |
+
|
| 53 |
+
# Should be performed somewhere else I believe
|
| 54 |
+
def format_cate_preds(cate_preds):
    """Group categorical predictions by object id, sorted by probability desc.

    *cate_preds* maps (obj_id, label) -> probability; labels are normalized
    by replacing underscores and slashes with spaces.
    """
    grouped = {}
    for (oid, raw_label), prob in cate_preds.items():
        # Normalize the predicted label in the same way as ground truth.
        normalized = raw_label.replace("_", " ").replace("/", " ")
        grouped.setdefault(oid, []).append((normalized, prob))
    for preds in grouped.values():
        preds.sort(key=lambda item: item[1], reverse=True)
    return grouped
|
| 66 |
+
|
| 67 |
+
def format_binary_cate_preds(binary_preds):
    """Flatten binary-relation predictions into sortable tuples.

    *binary_preds* maps (frame_id, (subject, object), relation) -> score.
    Returns a list of (frame_id, subject, object, relation, score) tuples
    sorted by score descending. Malformed keys are skipped with a message.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        # Expect key format: (frame_id, (subject, object), predicted_relation)
        try:
            f_id, (subj, obj), pred_rel = key
        except (TypeError, ValueError):
            print("Skipping key with unexpected format:", key)
            continue
        frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
    # BUGFIX: sort by the confidence score (index 4), not the relation
    # string (index 3) — mirrors the probability sort in format_cate_preds.
    frame_binary_preds.sort(key=lambda x: x[4], reverse=True)
    return frame_binary_preds
|
| 79 |
+
|
| 80 |
+
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
|
| 84 |
+
if mask is None:
|
| 85 |
+
return None
|
| 86 |
+
if isinstance(mask, torch.Tensor):
|
| 87 |
+
mask_np = mask.detach().cpu().numpy()
|
| 88 |
+
else:
|
| 89 |
+
mask_np = np.asarray(mask)
|
| 90 |
+
if mask_np.ndim == 0:
|
| 91 |
+
return None
|
| 92 |
+
if mask_np.ndim == 3:
|
| 93 |
+
mask_np = np.squeeze(mask_np)
|
| 94 |
+
if mask_np.ndim != 2:
|
| 95 |
+
return None
|
| 96 |
+
if mask_np.dtype == bool:
|
| 97 |
+
return mask_np
|
| 98 |
+
return mask_np > 0
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
|
| 102 |
+
if bbox is None:
|
| 103 |
+
return None
|
| 104 |
+
if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
|
| 105 |
+
x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
|
| 106 |
+
elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
|
| 107 |
+
x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
|
| 108 |
+
else:
|
| 109 |
+
return None
|
| 110 |
+
x1 = int(np.clip(round(x1), 0, width - 1))
|
| 111 |
+
y1 = int(np.clip(round(y1), 0, height - 1))
|
| 112 |
+
x2 = int(np.clip(round(x2), 0, width - 1))
|
| 113 |
+
y2 = int(np.clip(round(y2), 0, height - 1))
|
| 114 |
+
if x2 <= x1 or y2 <= y1:
|
| 115 |
+
return None
|
| 116 |
+
return (x1, y1, x2, y2)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Return the object's palette color (from get_color) as a BGR byte tuple."""
    r, g, b = (int(np.clip(channel, 0.0, 1.0) * 255) for channel in get_color(obj_id)[:3])
    return (b, g, r)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
|
| 126 |
+
return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    """Draw a stack of text lines with filled backgrounds, in place on *image*.

    *anchor* is the starting pixel; lines grow downward when ``direction ==
    "down"`` and upward otherwise. Backgrounds use a lightened *color*
    (see _background_color); the text itself is black. Boxes are clipped to
    the image, so drawing may stop early ("down") or pile up at y=0 ("up").
    """
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x, y = anchor
    x = int(np.clip(x, 0, img_w - 1))
    # y_cursor tracks the edge of the last drawn box between iterations.
    y_cursor = int(np.clip(y, 0, img_h - 1))
    bg_color = _background_color(color)

    if direction == "down":
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            # 6px gap below the previous box, then the box itself.
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
            if bottom_y <= top_y:
                # Ran off the bottom of the image; stop drawing further lines.
                break
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = bottom_y
    else:
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            # Box sits just above the cursor; clamped at the top edge.
            top_y = max(y_cursor - th - baseline - 6, 0)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            bottom_y = min(top_y + th + baseline + 6, img_h - 1)
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = top_y
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw *text* on a filled rectangle roughly centered at *center*, in place.

    The box is derived from the text metrics and clipped to the image, so
    labels near a border shift inward rather than being cut off.
    """
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    # Clamp the nominal center, then derive box corners from text size.
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    # Lightened object color behind black text, same scheme as _draw_label_block.
    cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
|
| 200 |
+
if isinstance(store, dict):
|
| 201 |
+
frame_entry = store.get(frame_idx, {})
|
| 202 |
+
elif isinstance(store, list) and 0 <= frame_idx < len(store):
|
| 203 |
+
frame_entry = store[frame_idx]
|
| 204 |
+
else:
|
| 205 |
+
frame_entry = {}
|
| 206 |
+
if isinstance(frame_entry, dict):
|
| 207 |
+
return frame_entry
|
| 208 |
+
if isinstance(frame_entry, list):
|
| 209 |
+
return {i: value for i, value in enumerate(frame_entry)}
|
| 210 |
+
return {}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _label_anchor_and_direction(
|
| 214 |
+
bbox: Tuple[int, int, int, int],
|
| 215 |
+
position: str,
|
| 216 |
+
) -> Tuple[Tuple[int, int], str]:
|
| 217 |
+
x1, y1, x2, y2 = bbox
|
| 218 |
+
if position == "bottom":
|
| 219 |
+
return (x1, y2), "down"
|
| 220 |
+
return (x1, y1), "up"
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw an object's bbox plus a stacked label block on *image*, in place.

    The header line always shows the object id (``#<id>``); a caller-supplied
    *title* is appended after it unless it already starts with ``#``.
    *sub_lines* are rendered below the header in the same block.
    """
    color = _object_color_bgr(obj_id)
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    head = title if title else f"#{obj_id}"
    if not head.startswith("#"):
        head = f"#{obj_id} {head}"
    lines = [head]
    if sub_lines:
        lines.extend(sub_lines)
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, lines, anchor, color, direction=direction)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Render SAM masks as translucent overlays plus mask-derived bboxes.

    Frames are expected in RGB; returned frames are RGB as well. ``None``
    frames are skipped, so the output may be shorter than the input.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        # Work in BGR because all drawing helpers use OpenCV conventions.
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        overlay = frame_bgr.astype(np.float32)
        masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)

        # Pass 1: alpha-blend each object's mask into the float overlay.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            color = _object_color_bgr(obj_id)
            alpha = 0.45
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]

        # Pass 2: derive a bbox from each mask and label it on the blended image.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            bbox = mask_to_bbox(mask_np)
            bbox = _sanitize_bbox(bbox, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Render GroundingDINO bboxes (with optional labels) onto each frame.

    Frames are RGB in and RGB out; ``None`` frames are skipped, so the output
    may be shorter than the input. Boxes are clipped via _sanitize_bbox and
    degenerate boxes are dropped.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        # Drawing helpers operate on BGR (OpenCV convention).
        annotated = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = annotated.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)

        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def render_vine_frame_sets(
|
| 319 |
+
frames: Union[np.ndarray, List[np.ndarray]],
|
| 320 |
+
bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
|
| 321 |
+
cat_label_lookup: Dict[int, Tuple[str, float]],
|
| 322 |
+
unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
|
| 323 |
+
binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
|
| 324 |
+
masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
|
| 325 |
+
) -> Dict[str, List[np.ndarray]]:
|
| 326 |
+
frame_groups: Dict[str, List[np.ndarray]] = {
|
| 327 |
+
"object": [],
|
| 328 |
+
"unary": [],
|
| 329 |
+
"binary": [],
|
| 330 |
+
"all": [],
|
| 331 |
+
}
|
| 332 |
+
frames_iterable = frames if isinstance(frames, list) else list(frames)
|
| 333 |
+
|
| 334 |
+
for frame_idx, frame in enumerate(frames_iterable):
|
| 335 |
+
if frame is None:
|
| 336 |
+
continue
|
| 337 |
+
frame_rgb = np.asarray(frame)
|
| 338 |
+
base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
| 339 |
+
frame_h, frame_w = base_bgr.shape[:2]
|
| 340 |
+
frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
|
| 341 |
+
frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}
|
| 342 |
+
|
| 343 |
+
objects_bgr = base_bgr.copy()
|
| 344 |
+
unary_bgr = base_bgr.copy()
|
| 345 |
+
binary_bgr = base_bgr.copy()
|
| 346 |
+
all_bgr = base_bgr.copy()
|
| 347 |
+
|
| 348 |
+
bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
|
| 349 |
+
unary_lines_lookup: Dict[int, List[str]] = {}
|
| 350 |
+
titles_lookup: Dict[int, Optional[str]] = {}
|
| 351 |
+
|
| 352 |
+
for obj_id, bbox_values in frame_bboxes.items():
|
| 353 |
+
bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
|
| 354 |
+
if not bbox:
|
| 355 |
+
continue
|
| 356 |
+
bbox_lookup[obj_id] = bbox
|
| 357 |
+
cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
|
| 358 |
+
title_parts = []
|
| 359 |
+
if cat_label:
|
| 360 |
+
if cat_prob is not None:
|
| 361 |
+
title_parts.append(f"{cat_label} {cat_prob:.2f}")
|
| 362 |
+
else:
|
| 363 |
+
title_parts.append(cat_label)
|
| 364 |
+
titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
|
| 365 |
+
unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
|
| 366 |
+
unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
|
| 367 |
+
unary_lines_lookup[obj_id] = unary_lines
|
| 368 |
+
|
| 369 |
+
for obj_id, bbox in bbox_lookup.items():
|
| 370 |
+
unary_lines = unary_lines_lookup.get(obj_id, [])
|
| 371 |
+
if not unary_lines:
|
| 372 |
+
continue
|
| 373 |
+
mask_raw = frame_masks.get(obj_id)
|
| 374 |
+
mask_np = _to_numpy_mask(mask_raw)
|
| 375 |
+
if mask_np is None or not np.any(mask_np):
|
| 376 |
+
continue
|
| 377 |
+
color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
|
| 378 |
+
alpha = 0.45
|
| 379 |
+
for target in (unary_bgr, all_bgr):
|
| 380 |
+
target_vals = target[mask_np].astype(np.float32)
|
| 381 |
+
blended = (1.0 - alpha) * target_vals + alpha * color
|
| 382 |
+
target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)
|
| 383 |
+
|
| 384 |
+
for obj_id, bbox in bbox_lookup.items():
|
| 385 |
+
title = titles_lookup.get(obj_id)
|
| 386 |
+
unary_lines = unary_lines_lookup.get(obj_id, [])
|
| 387 |
+
_draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
|
| 388 |
+
_draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
|
| 389 |
+
if unary_lines:
|
| 390 |
+
anchor, direction = _label_anchor_and_direction(bbox, "bottom")
|
| 391 |
+
_draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
|
| 392 |
+
_draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
|
| 393 |
+
_draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
|
| 394 |
+
if unary_lines:
|
| 395 |
+
anchor, direction = _label_anchor_and_direction(bbox, "bottom")
|
| 396 |
+
_draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
|
| 397 |
+
|
| 398 |
+
for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
|
| 399 |
+
if len(obj_pair) != 2 or not relation_preds:
|
| 400 |
+
continue
|
| 401 |
+
subj_id, obj_id = obj_pair
|
| 402 |
+
subj_bbox = bbox_lookup.get(subj_id)
|
| 403 |
+
obj_bbox = bbox_lookup.get(obj_id)
|
| 404 |
+
if not subj_bbox or not obj_bbox:
|
| 405 |
+
continue
|
| 406 |
+
start, end = relation_line(subj_bbox, obj_bbox)
|
| 407 |
+
color = tuple(int(c) for c in np.clip(
|
| 408 |
+
(np.array(_object_color_bgr(subj_id), dtype=np.float32) +
|
| 409 |
+
np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
|
| 410 |
+
0, 255
|
| 411 |
+
))
|
| 412 |
+
prob, relation = relation_preds[0]
|
| 413 |
+
label_text = f"{relation} {prob:.2f}"
|
| 414 |
+
mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
|
| 415 |
+
cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
|
| 416 |
+
cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
|
| 417 |
+
_draw_centered_label(binary_bgr, label_text, mid_point, color)
|
| 418 |
+
_draw_centered_label(all_bgr, label_text, mid_point, color)
|
| 419 |
+
|
| 420 |
+
frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
|
| 421 |
+
frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
|
| 422 |
+
frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
|
| 423 |
+
frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))
|
| 424 |
+
|
| 425 |
+
return frame_groups
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> List[np.ndarray]:
    """Convenience wrapper: render every annotation layer but return only the
    fully-annotated ("all") frame sequence from render_vine_frame_sets."""
    frame_sets = render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks,
    )
    # Missing "all" key degrades to an empty list rather than raising.
    return frame_sets.get("all", [])
|
| 444 |
+
|
| 445 |
+
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Assign a BGR box color and a label string per ground-truth object.

    Color code: green = top-1 prediction matches the GT label, orange = GT is
    within the top-k predictions, red = no match (or no prediction at all).

    Args:
        obj_pred_dict: obj_id -> list of (label, probability) pairs, best first.
        gt_labels: iterable of (obj_id, bbox, gt_label) triples.
        topk_object: how many top predictions count as a partial match.

    Returns:
        (colors, texts): parallel lists, one entry per ground-truth object.
    """
    GREEN, ORANGE, RED = (0, 255, 0), (0, 165, 255), (0, 0, 255)
    colors = []
    texts = []
    for obj_id, _bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if not preds:
            # No prediction at all for this object.
            top1 = "N/A"
            box_color = RED
        else:
            top1 = preds[0][0]
            gt_lower = gt_label.lower()
            topk_lower = [p[0].lower() for p in preds[:topk_object]]
            # Compare case-insensitively: exact top-1 hit, then top-k hit.
            if top1.lower() == gt_lower:
                box_color = GREEN
            elif gt_lower in topk_lower:
                box_color = ORANGE
            else:
                box_color = RED
        colors.append(box_color)
        texts.append(f"ID:{obj_id}/P:{top1}/GT:{gt_label}")
    return colors, texts
|
| 468 |
+
|
| 469 |
+
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw one labelled bounding box per object onto frame_img (in place).

    Args:
        frame_img: BGR image to annotate.
        gt_labels: iterable of (obj_id, bbox, gt_label) triples; bbox is x1,y1,x2,y2.
        all_colors: per-object BGR box colors (parallel to gt_labels).
        all_texts: per-object label strings (parallel to gt_labels).

    Returns:
        The annotated frame_img.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
        x1, y1, x2, y2 = (int(v) for v in bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        # Filled banner above the box keeps the black label text readable.
        (text_w, text_h), baseline = cv2.getTextSize(label_text, font, 0.5, 1)
        cv2.rectangle(frame_img, (x1, y1 - text_h - baseline - 4), (x1 + text_w, y1), box_color, -1)
        cv2.putText(frame_img, label_text, (x1, y1 - 2), font,
                    0.5, (0, 0, 0), 1, cv2.LINE_AA)
    return frame_img
|
| 480 |
+
|
| 481 |
+
def get_white_pane(pane_height,
                   pane_width=600,
                   header_height=50,
                   header_font=cv2.FONT_HERSHEY_SIMPLEX,
                   header_font_scale=0.7,
                   header_thickness=2,
                   header_color=(0, 0, 0)):
    """Create a white side pane with the two column headers drawn on it.

    The pane is split 60/40 into a "Binary Predictions" column and a
    "Ground Truth" column (the same split that plot_binary_sg uses).

    Args:
        pane_height: pane height in pixels (usually the frame height).
        pane_width: total pane width in pixels.
        header_height: vertical space reserved for the header row.
        header_font, header_font_scale, header_thickness, header_color:
            cv2.putText styling for the header text.

    Returns:
        (pane_height, pane_width, 3) uint8 BGR image with headers rendered.
    """
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)

    # Predictions column is wider (60% vs. 40%).
    left_width = int(pane_width * 0.6)

    # BUG FIX: the headers used to be drawn onto throwaway `.copy()` slices of
    # the pane, so the returned pane came back completely blank. Draw directly
    # on the pane, offsetting the right-hand header by the left column width.
    cv2.putText(white_pane, "Binary Predictions", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
    cv2.putText(white_pane, "Ground Truth", (left_width + 10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)

    return white_pane
|
| 503 |
+
|
| 504 |
+
# This is for ploting binary prediction results with frame-based scene graphs
|
| 505 |
+
def plot_binary_sg(frame_img,
|
| 506 |
+
white_pane,
|
| 507 |
+
bin_preds,
|
| 508 |
+
gt_relations,
|
| 509 |
+
topk_binary,
|
| 510 |
+
header_height=50,
|
| 511 |
+
indicator_size=20,
|
| 512 |
+
pane_width=600):
|
| 513 |
+
# Leave vertical space for the headers.
|
| 514 |
+
line_height = 30 # vertical spacing per line
|
| 515 |
+
x_text = 10 # left margin for text
|
| 516 |
+
y_text_left = header_height + 10 # starting y for left pane text
|
| 517 |
+
y_text_right = header_height + 10 # starting y for right pane text
|
| 518 |
+
|
| 519 |
+
# Left section: top-k binary predictions.
|
| 520 |
+
left_width = int(pane_width * 0.6)
|
| 521 |
+
right_width = pane_width - left_width
|
| 522 |
+
left_pane = white_pane[:, :left_width, :].copy()
|
| 523 |
+
right_pane = white_pane[:, left_width:, :].copy()
|
| 524 |
+
|
| 525 |
+
for (subj, pred_rel, obj, score) in bin_preds[:topk_binary]:
|
| 526 |
+
correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
|
| 527 |
+
for gt in gt_relations)
|
| 528 |
+
indicator_color = (0, 255, 0) if correct else (0, 0, 255)
|
| 529 |
+
cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
|
| 530 |
+
(x_text + indicator_size, y_text_left + 5), indicator_color, -1)
|
| 531 |
+
text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
|
| 532 |
+
cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
|
| 533 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
|
| 534 |
+
y_text_left += line_height
|
| 535 |
+
|
| 536 |
+
# Right section: ground truth binary relations.
|
| 537 |
+
for gt in gt_relations:
|
| 538 |
+
if len(gt) != 3:
|
| 539 |
+
continue
|
| 540 |
+
text = f"{gt[0]} - {gt[2]} - {gt[1]}"
|
| 541 |
+
cv2.putText(right_pane, text, (x_text, y_text_right + 5),
|
| 542 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
|
| 543 |
+
y_text_right += line_height
|
| 544 |
+
|
| 545 |
+
# Combine the two text panes and then with the frame image.
|
| 546 |
+
combined_pane = np.hstack((left_pane, right_pane))
|
| 547 |
+
combined_image = np.hstack((frame_img, combined_pane))
|
| 548 |
+
return combined_image
|
| 549 |
+
|
| 550 |
+
def visualized_frame(frame_img,
                     bboxes,
                     object_ids,
                     gt_labels,
                     cate_preds,
                     binary_preds,
                     gt_relations,
                     topk_object,
                     topk_binary,
                     phase="unary"):
    """Return the combined annotated frame as an image (in BGR).

    phase == "unary": draws per-object boxes colored by category-prediction
    correctness. Any other phase: renders a side pane comparing predicted
    binary relations against ground truth.
    """
    if phase == "unary":
        # Build (obj_id, bbox, cleaned_gt_label) triples — the exact format
        # that color_for_cate_correctness and plot_unary expect.
        objs = []
        for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
            objs.append((obj_id, bbox, clean_label(gt_label)))

        formatted_cate_preds = format_cate_preds(cate_preds)
        # BUG FIX: pass the assembled triples. Previously the raw gt_labels
        # list (plain label strings) was forwarded while `objs` was built and
        # discarded, which does not match the helpers' expected
        # (obj_id, bbox, label) format.
        all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
        return plot_unary(frame_img, objs, all_colors, all_texts)

    # --- Binary phase: text pane of predictions vs. ground truth ---
    formatted_binary_preds = format_binary_cate_preds(binary_preds)

    # Clean ground-truth relation triples (subject, object, relation).
    gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]

    pane_width = 600      # increased pane width for more horizontal space
    pane_height = frame_img.shape[0]
    header_height = 50    # extra space for the column headers
    white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)

    return plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
|
| 594 |
+
|
| 595 |
+
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay a single segmentation mask on a matplotlib axis.

    Args:
        mask: mask array; accepted shapes are (H, W), (1, H, W) or (H, W, 1).
        ax: matplotlib axis to draw on.
        obj_id: object index used to pick a deterministic colormap color.
        det_class: optional caption; when given, a rectangle plus text label is
            drawn around the mask's nonzero region.
        random_color: if True, use a random RGB color (alpha 0.8) instead of
            the deterministic colormap color.
    """
    # Ensure mask is a numpy array
    mask = np.array(mask)
    # Handle different mask shapes: collapse singleton axes down to (H, W).
    if mask.ndim == 3:
        # (1, H, W) -> (H, W)
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        # (H, W, 1) -> (H, W)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)
    # Now mask should be (H, W)
    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if random_color:
        # Random RGB with a fixed 0.8 alpha channel.
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        # Deterministic per-object RGBA color; the *47 stride presumably
        # spreads consecutive ids across the colormap (matches get_color).
        cmap = plt.get_cmap("gist_rainbow")
        cmap_idx = 0 if obj_id is None else obj_id
        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5  # semi-transparent overlay
        color = np.array(color)

    # Expand mask to (H, W, 1) for broadcasting against the (1, 1, 4) color.
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)

    # draw a box around the mask with the det_class as the label
    if not det_class is None:
        # Find the bounding box coordinates of the mask's nonzero pixels.
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3]
            )
            ax.add_patch(rect)
            # Caption just above the box, on a background of the mask color.
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1
            )
    ax.imshow(mask_image)
|
| 649 |
+
|
| 650 |
+
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk.

    Args:
        frame_image: frame as a torch tensor or anything np.asarray accepts.
        masks: dict obj_id -> mask, or a sequence of masks (ids = positions);
            each mask may be a torch tensor or array-like.
        save_path: destination image path.

    Returns:
        save_path, for caller convenience.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Normalize the frame to a contiguous numpy array.
    if torch.is_tensor(frame_image):
        frame_np = frame_image.detach().cpu().numpy()
    else:
        frame_np = np.asarray(frame_image)
    frame_np = np.ascontiguousarray(frame_np)

    # Accept either an id->mask mapping or a plain sequence of masks.
    mask_items = masks.items() if isinstance(masks, dict) else enumerate(masks)
    prepared_masks = {}
    for obj_id, mask in mask_items:
        if torch.is_tensor(mask):
            prepared_masks[obj_id] = mask.detach().cpu().numpy()
        else:
            prepared_masks[obj_id] = np.asarray(mask)

    ax.imshow(frame_np)
    ax.axis("off")

    for obj_id, mask_np in prepared_masks.items():
        show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)

    # Persist and release the figure to avoid leaking matplotlib state.
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path
|
| 684 |
+
|
| 685 |
+
def get_video_masks_visualization(video_tensor,
                                  video_masks,
                                  video_id,
                                  video_save_base_dir,
                                  oid_class_pred=None,
                                  sample_rate=1):
    """Render per-frame mask overlays for a video and save them to disk.

    Frames are written to <video_save_base_dir>/<video_id>/<frame_id>.jpg.

    Args:
        video_tensor: iterable of frame images.
        video_masks: frame_id -> {obj_id: mask} mapping.
        video_id: subdirectory name for this video.
        video_save_base_dir: root output directory.
        oid_class_pred: optional obj_id -> class-name mapping for captions.
        sample_rate: probability of rendering each frame (1 = every frame).
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # BUG FIX: sample_rate was accepted but ignored; honor it the same way
        # save_video_masks_visualization does (random per-frame sampling).
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        fig, ax = get_mask_one_image(image, masks, oid_class_pred)
        # BUG FIX: save_path was computed but nothing was ever written, and
        # the figure returned by get_mask_one_image leaked. Persist and close.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
|
| 704 |
+
|
| 705 |
+
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Build a matplotlib figure of frame_image with all masks overlaid.

    Args:
        frame_image: frame to display.
        masks: dict obj_id -> mask, or a list of masks (ids = list positions).
        oid_class_pred: optional obj_id -> class-name mapping used for captions.

    Returns:
        (fig, ax): the figure and axis holding the visualization.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    ax.imshow(frame_image)
    ax.axis('off')

    # Lists are converted to an index-keyed dict so both inputs draw the same way.
    if type(masks) == list:
        masks = {i: m for i, m in enumerate(masks)}

    for obj_id, mask in masks.items():
        if oid_class_pred is not None:
            det_class = f"{obj_id}. {oid_class_pred[obj_id]}"
        else:
            det_class = None
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    return fig, ax
|
| 723 |
+
|
| 724 |
+
def save_video(frames, output_filename, output_fps):
    """Write a sequence of BGR frames to a video file.

    Args:
        frames: list of (H, W, 3) uint8 BGR frames, or an (N, H, W, 3) array.
        output_filename: destination video path.
        output_fps: playback frame rate.
    """
    num_frames = len(frames)
    if num_frames == 0:
        print("No frames to save.")
        return

    # BUG FIX: frame size must come from a single frame. `frames.shape[:2]`
    # returned (N, H) for a stacked array and failed outright for a list.
    frame_h, frame_w = frames[0].shape[:2]

    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))

    print(f"Processing {num_frames} frames...")
    for i in range(num_frames):
        # BUG FIX: write the provided frame; the previous call to an undefined
        # get_visualized_frame(i) raised NameError on first use.
        out.write(frames[i])
        if i % 10 == 0:
            print(f"Processed frame {i+1}/{num_frames}")

    out.release()
    print(f"Video saved as {output_filename}")
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def list_depth(lst):
    """Calculates the depth of a nested list.

    Scalars count as depth 0; an empty list or a 0-dim tensor counts as 1;
    each level of list/tensor nesting adds one.
    """
    nestable = isinstance(lst, (list, torch.Tensor))
    if not nestable:
        # Plain scalar (int, float, str, ...).
        return 0
    if isinstance(lst, torch.Tensor):
        if lst.shape == torch.Size([]):
            # 0-dim tensor: one level, nothing to iterate.
            return 1
    elif len(lst) == 0:
        # Empty list: one level, nothing to iterate.
        return 1
    # Non-empty container: one level plus the deepest child.
    return 1 + max(list_depth(item) for item in lst)
|
| 753 |
+
|
| 754 |
+
def normalize_prompt(points, labels):
    """Lift depth-3 point/label prompts into batched tensors.

    When `points` is a list of per-object point tensors (depth 3), each entry
    gains a leading singleton dim and the list is stacked; otherwise both
    inputs pass through unchanged.
    """
    needs_batching = list_depth(points) == 3
    if needs_batching:
        points = torch.stack([pt.unsqueeze(0) for pt in points])
        labels = torch.stack([lb.unsqueeze(0) for lb in labels])
    return points, labels
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def show_box(box, ax, object_id):
    """Draw one (x1, y1, x2, y2) box on the axis, colored by object id."""
    if len(box) == 0:
        return

    # Deterministic per-object color, consistent with show_mask/get_color.
    palette = plt.get_cmap("gist_rainbow")
    idx = 0 if object_id is None else object_id
    edge_color = list(palette((idx * 47) % 256))

    x0, y0 = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    patch = plt.Rectangle((x0, y0), width, height,
                          edgecolor=edge_color, facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(patch)
|
| 772 |
+
|
| 773 |
+
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter prompt points: positives as green plus marks, negatives as red squares."""
    if len(labels) == 0:
        return

    positives = coords[labels == 1]
    negatives = coords[labels == 0]

    # Edge color is the object's deterministic palette color.
    palette = plt.get_cmap("gist_rainbow")
    idx = 0 if object_id is None else object_id
    edge = list(palette((idx * 47) % 256))

    ax.scatter(positives[:, 0], positives[:, 1], color='green', marker='P',
               s=marker_size, edgecolor=edge, linewidth=1.25)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', marker='s',
               s=marker_size, edgecolor=edge, linewidth=1.25)
|
| 786 |
+
|
| 787 |
+
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Overlay box and point prompts on a frame and save the figure.

    Args:
        frame_image: frame to display.
        boxes: tensor of boxes (ids = row index), dict obj_id -> box tensor,
            or an empty list for "no boxes".
        points, labels: per-object point prompts; normalized via normalize_prompt.
        save_path: destination image path.

    Raises:
        TypeError: if `boxes` is not a tensor, dict, or empty list.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    ax.imshow(frame_image)
    ax.axis('off')

    points, labels = normalize_prompt(points, labels)

    # isinstance (rather than exact type comparison) also accepts Tensor/dict
    # subclasses such as torch.nn.Parameter.
    if isinstance(boxes, torch.Tensor):
        for object_id, box in enumerate(boxes):
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, dict):
        for object_id, box in boxes.items():
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, list) and len(boxes) == 0:
        pass
    else:
        # BUG FIX: the bare `raise Exception()` gave no clue what went wrong.
        # TypeError subclasses Exception, so existing handlers still match.
        raise TypeError(f"Unsupported boxes container: {type(boxes)!r}")

    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if len(point_ls) != 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

    # Persist and release the figure.
    plt.savefig(save_path)
    plt.close()
|
| 818 |
+
|
| 819 |
+
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
    """Save one prompt-overlay image per frame to <base_dir>/<video_id>/<frame_id>.jpg.

    Frames with no boxes/points/labels recorded simply get empty prompt lists.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Missing frame ids degrade to empty prompt containers.
        boxes = video_boxes[frame_id] if frame_id in video_boxes else []
        points = video_points[frame_id] if frame_id in video_points else []
        labels = video_labels[frame_id] if frame_id in video_labels else []

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, save_path)
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate=1):
    """Randomly sample frames at `sample_rate` and save their mask overlays.

    Outputs go to <video_save_base_dir>/<video_id>/<frame_id>.jpg. Frames with
    no recorded masks are reported and skipped.
    """
    out_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Keep each frame with probability `sample_rate`.
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        destination = os.path.join(out_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, video_masks[frame_id], destination)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return a deterministic RGBA color (numpy array) for an object id.

    Args:
        obj_id: object index; None maps to index 0. The *47 stride spreads
            consecutive ids across the colormap.
        cmap_name: matplotlib colormap name.
        alpha: opacity written into the color's alpha channel.
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = list(cmap((cmap_idx * 47) % 256))
    # BUG FIX: honor the `alpha` argument — it was previously accepted but
    # the alpha channel was hard-coded to 0.5. Default keeps old behavior.
    color[3] = alpha
    return np.array(color)
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
    """Return the (x, y) midpoint of an (x1, y1, x2, y2) box."""
    x1, y1, x2, y2 = bbox
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Returns integer pixel centers suitable for drawing a relation line. For
    coincident boxes, nudges the target center to ensure the segment has span.
    """
    # Box midpoints (inlined; same arithmetic as _bbox_center).
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0

    same_x = math.isclose(cx1, cx2, abs_tol=1e-3)
    same_y = math.isclose(cy1, cy2, abs_tol=1e-3)
    if same_x and same_y:
        # Coincident centers: push the target right by 5% of its width (>= 1px).
        cx2 += max(1.0, (bbox2[2] - bbox2[0]) * 0.05)

    start = (int(round(cx1)), int(round(cy1)))
    end = (int(round(cx2)), int(round(cy2)))
    # Rounding can still collapse the segment; force at least one pixel of span.
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end
|
| 887 |
+
|
| 888 |
+
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Visualize binary relations: masks plus labelled lines between objects.

    Args:
        frame_image: frame to display as the background.
        masks: obj_id -> mask mapping covering every id referenced below.
        rel_pred_ls: mapping (from_obj_id, to_obj_id) -> relation text.
            NOTE(review): despite the None default, this is iterated
            unconditionally — passing None would raise; confirm callers
            always supply it.

    Returns:
        (fig, ax): the matplotlib figure and axis with the visualization.
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    # Only objects that participate in a relation are drawn.
    all_objs_to_show = set()
    all_lines_to_show = []

    # print(rel_pred_ls[0])
    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)

        # Endpoints come from the boxes around each object's mask.
        from_mask = masks[from_obj_id]
        bbox1 = mask_to_bbox(from_mask)
        to_mask = masks[to_obj_id]
        bbox2 = mask_to_bbox(to_mask)

        c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)

        # Line takes the source object's color, label face the target's.
        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        line = c1, c2, face_color, line_color, rel_text
        all_lines_to_show.append(line)

    masks_to_show = {}
    for oid in all_objs_to_show:
        masks_to_show[oid] = masks[oid]

    # Add the masks (drawn first so lines/labels render on top)
    for obj_id, mask in masks_to_show.items():
        show_mask(mask, ax, obj_id=obj_id, random_color=False)

    for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:

        plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
        # Relation label sits at the segment midpoint.
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
            alpha=1
        )

    # Show the plot
    return fig, ax
|