KevinX-Penn28 commited on
Commit
7a409cf
·
verified ·
1 Parent(s): 3ca196f

Upload VINE model - model

Browse files
Files changed (5) hide show
  1. config.json +6 -2
  2. flattening.py +124 -0
  3. model.safetensors +3 -0
  4. vine_model.py +658 -0
  5. vis_utils.py +941 -0
config.json CHANGED
@@ -1,9 +1,12 @@
1
  {
2
- "_attn_implementation_autoset": true,
3
  "_device": "cuda",
4
  "alpha": 0.5,
 
 
 
5
  "auto_map": {
6
- "AutoConfig": "vine_config.VineConfig"
 
7
  },
8
  "bbox_min_dim": 5,
9
  "box_threshold": 0.35,
@@ -23,6 +26,7 @@
23
  "target_fps": 1,
24
  "text_threshold": 0.25,
25
  "topk_cate": 3,
 
26
  "transformers_version": "4.46.2",
27
  "visualization_dir": null,
28
  "visualize": false,
 
1
  {
 
2
  "_device": "cuda",
3
  "alpha": 0.5,
4
+ "architectures": [
5
+ "VineModel"
6
+ ],
7
  "auto_map": {
8
+ "AutoConfig": "vine_config.VineConfig",
9
+ "AutoModel": "vine_model.VineModel"
10
  },
11
  "bbox_min_dim": 5,
12
  "box_threshold": 0.35,
 
26
  "target_fps": 1,
27
  "text_threshold": 0.25,
28
  "topk_cate": 3,
29
+ "torch_dtype": "float32",
30
  "transformers_version": "4.46.2",
31
  "visualization_dir": null,
32
  "visualize": false,
flattening.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ MaskType = Union[np.ndarray, torch.Tensor]
11
+
12
+
13
+ def _to_numpy_mask(mask: MaskType) -> np.ndarray:
14
+ """
15
+ Convert assorted mask formats to a 2D numpy boolean array.
16
+ """
17
+ if isinstance(mask, torch.Tensor):
18
+ mask_np = mask.detach().cpu().numpy()
19
+ else:
20
+ mask_np = np.asarray(mask)
21
+
22
+ # Remove singleton dimensions at the front/back
23
+ while mask_np.ndim > 2 and mask_np.shape[0] == 1:
24
+ mask_np = np.squeeze(mask_np, axis=0)
25
+ if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
26
+ mask_np = np.squeeze(mask_np, axis=-1)
27
+
28
+ if mask_np.ndim != 2:
29
+ raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
30
+
31
+ return mask_np.astype(bool)
32
+
33
+
34
+ def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
35
+ """
36
+ Compute a bounding box for a 2D boolean mask.
37
+ """
38
+ if not mask.any():
39
+ return None
40
+ rows, cols = np.nonzero(mask)
41
+ y_min, y_max = rows.min(), rows.max()
42
+ x_min, x_max = cols.min(), cols.max()
43
+ return x_min, y_min, x_max, y_max
44
+
45
+
46
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten nested per-frame segmentation data into batched lists suitable
    for predicate models or downstream visualizations.

    Objects whose mask is empty, or whose bounding box is narrower than
    `bbox_min_dim` in either dimension, are dropped. For every frame, each
    ordered pair (i, j) of surviving objects with i != j is also recorded.
    """
    object_ids: List[Tuple[int, int, int]] = []
    flat_masks: List[np.ndarray] = []
    flat_bboxes: List[Tuple[int, int, int, int]] = []
    pairs: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, frame_objects in segments.items():
        kept: List[int] = []
        for object_id, raw_mask in frame_objects.items():
            mask = _to_numpy_mask(raw_mask)
            bbox = _mask_to_bbox(mask)
            if bbox is None:
                continue

            x_min, y_min, x_max, y_max = bbox
            # Reject degenerate detections below the minimum box size.
            if abs(y_max - y_min) < bbox_min_dim or abs(x_max - x_min) < bbox_min_dim:
                continue

            kept.append(object_id)
            object_ids.append((video_id, frame_id, object_id))
            flat_masks.append(mask)
            flat_bboxes.append(bbox)

        # All ordered pairs of surviving objects within this frame.
        pairs.extend(
            (video_id, frame_id, (src, dst))
            for src in kept
            for dst in kept
            if src != dst
        )

    return {
        "object_ids": object_ids,
        "masks": flat_masks,
        "bboxes": flat_bboxes,
        "pairs": pairs,
    }
90
+
91
+
92
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Enumerate object pairs for every (video, frame).

    When `interested_object_pairs` is given (and non-empty), only those
    (src, dst) combinations are emitted, and only for frames where both
    objects are present. Otherwise all ordered permutations (i, j) with
    i != j are emitted per frame.
    """
    # Group object ids by (video, frame) so pairs never cross frames.
    objects_per_frame: Dict[Tuple[int, int], set] = defaultdict(set)
    for video_id, frame_id, object_id in batched_object_ids:
        objects_per_frame[(video_id, frame_id)].add(object_id)

    requested = (
        None
        if interested_object_pairs is None
        else list(interested_object_pairs)
    )

    result: List[Tuple[int, int, Tuple[int, int]]] = []
    for (video_id, frame_id), present in objects_per_frame.items():
        if requested:
            result.extend(
                (video_id, frame_id, (src, dst))
                for src, dst in requested
                if src in present and dst in present
            )
        else:
            result.extend(
                (video_id, frame_id, (src, dst))
                for src in present
                for dst in present
                if src != dst
            )

    return result
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c91c273c5f61b7f17fc6cc265e14bb78ed134c71d7b54611208420fcbe4f81de
3
+ size 1815491340
vine_model.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax import config
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as cp
6
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
7
+ from typing import Dict, List, Tuple, Optional, Any, Union
8
+ import numpy as np
9
+ import os
10
+ import cv2
11
+ from collections import defaultdict
12
+ import builtins
13
+ import sys
14
+ from laser.models import llava_clip_model_v3
15
+ sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
16
+ import inspect
17
+ from transformers.models.clip import modeling_clip
18
+ import transformers
19
+
20
+
21
+
22
+
23
+ from .vine_config import VineConfig
24
+ from laser.models.model_utils import (
25
+ extract_single_object,
26
+ extract_object_subject,
27
+ crop_image_contain_bboxes,
28
+ segment_list
29
+ )
30
+ from .flattening import (
31
+ extract_valid_object_pairs,
32
+ flatten_segments_for_batch,
33
+ )
34
+
35
+ from .vis_utils import save_mask_one_image
36
+
37
class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model

    This model processes videos along with categorical, unary, and binary
    keywords to return probability distributions over those keywords for
    detected objects and their relationships in the video.

    Three independent CLIP towers are used: one for object categories, one
    for unary predicates (single-object actions), and one for binary
    predicates (relations between object pairs). A single tokenizer and
    image processor are shared across the towers.
    """

    config_class = VineConfig

    def __init__(self, config: VineConfig):
        super().__init__(config)

        self.config = config
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        # Deliberately no default: a config without `_device` is an error.
        self._device = getattr(config, "_device")

        # Initialize CLIP components (shared tokenizer/processor, three towers).
        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            # Some CLIP tokenizers ship without a pad token; fall back to unk,
            # then eos, so that padding='max_length' tokenization works.
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Then try to load pretrained VINE weights if specified.
        if config.pretrained_vine_path:
            self._load_pretrained_vine_weights(config.pretrained_vine_path)

        # Move all sub-models to the configured device.
        self.to(self._device)

    def _load_pretrained_vine_weights(self, pretrained_path: str, epoch: int = 0) -> bool:
        """
        Load pretrained VINE weights from a saved checkpoint.

        Supported formats:
          * the HuggingFace hub id ``"video-fm/vine_v0"``
          * a pickled model object (``.pkl``) exposing ``clip_*_model`` attributes
          * a plain state dict (``.pt`` / ``.pth``)
          * a directory containing ``*.{epoch}.model`` ensemble checkpoints

        Args:
            pretrained_path: Checkpoint path or hub id.
            epoch: Epoch number to select in the directory/ensemble format.

        Returns:
            True if weights were loaded, False otherwise.
        """
        if pretrained_path == "video-fm/vine_v0":
            self.clip_tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
            self.clip_cate_model = AutoModel.from_pretrained(pretrained_path)
            self.clip_unary_model = AutoModel.from_pretrained(pretrained_path)
            self.clip_binary_model = AutoModel.from_pretrained(pretrained_path)
            # Bug fix: previously this branch fell through to the bottom of the
            # method and reported "Unsupported format" / returned False even
            # though loading had succeeded.
            return True

        if pretrained_path.endswith(".pkl"):
            print(f"Loading VINE weights from: {pretrained_path}")
            # weights_only=False: the .pkl stores a whole model object.
            # NOTE(security): this unpickles arbitrary code — only load
            # checkpoints from trusted sources.
            loaded_vine_model = torch.load(pretrained_path, map_location=self._device, weights_only=False)

            print(f"Loaded state type: {type(loaded_vine_model)}")
            if not isinstance(loaded_vine_model, dict):
                # Copy each sub-model's weights only if the attribute exists.
                if hasattr(loaded_vine_model, 'clip_cate_model'):
                    self.clip_cate_model.load_state_dict(loaded_vine_model.clip_cate_model.state_dict())
                if hasattr(loaded_vine_model, 'clip_unary_model'):
                    self.clip_unary_model.load_state_dict(loaded_vine_model.clip_unary_model.state_dict())
                if hasattr(loaded_vine_model, 'clip_binary_model'):
                    self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict())
                return True

        elif pretrained_path.endswith(".pt") or pretrained_path.endswith(".pth"):
            # Plain state dict for this model as a whole.
            state = torch.load(pretrained_path, map_location=self._device, weights_only=True)
            print(f"Loaded state type: {type(state)}")
            self.load_state_dict(state)
            return True

        # Handle directory + epoch format.
        if os.path.isdir(pretrained_path):
            model_files = [f for f in os.listdir(pretrained_path) if f.endswith(f'.{epoch}.model')]
            if model_files:
                model_file = os.path.join(pretrained_path, model_files[0])
                print(f"Loading VINE weights from: {model_file}")
                pretrained_model = torch.load(model_file, map_location="cpu")

                # Conversion from a PredicateModel-like object to VineModel;
                # only copy attributes that actually exist on the checkpoint.
                if hasattr(pretrained_model, 'clip_cate_model'):
                    self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict())
                if hasattr(pretrained_model, 'clip_unary_model'):
                    self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict())
                if hasattr(pretrained_model, 'clip_binary_model'):
                    self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict())
                print("✓ Loaded all sub-model weights from ensemble format")
                return True
            else:
                print(f"No model file found for epoch {epoch} in {pretrained_path}")
                return False

        print("Unsupported format for pretrained_vine_path")
        return False

    @classmethod
    def from_pretrained_vine(
        cls,
        model_path: str,
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs
    ):
        """
        Create a VineModel from pretrained VINE weights.

        Args:
            model_path: Path to pretrained VINE model
            config: Optional config, will create default if None
            epoch: Epoch number to load (unused here; __init__ loads epoch 0 —
                   TODO confirm whether it should be forwarded)
            **kwargs: Additional arguments passed to the constructor

        Returns:
            VineModel instance with loaded weights
        """
        if config is None:
            config = VineConfig(pretrained_vine_path=model_path)
        else:
            config.pretrained_vine_path = model_path

        # Create model instance (__init__ automatically loads the weights).
        model = cls(config, **kwargs)

        return model

    def _text_features_checkpoint(self, model, tokens):
        """Extract text features with gradient checkpointing to save memory."""
        token_keys = list(tokens.keys())

        # torch.utils.checkpoint passes positional tensors only, so rebuild
        # the keyword dict inside the checkpointed function.
        def get_text_features_wrapped(*inputs):
            kwargs = {key: value for key, value in zip(token_keys, inputs)}
            return model.get_text_features(**kwargs)

        token_values = [tokens[key] for key in token_keys]
        return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)

    def _image_features_checkpoint(self, model, images):
        """Extract image features with gradient checkpointing."""
        return cp.checkpoint(model.get_image_features, images, use_reentrant=False)

    def clip_sim(self, model, nl_feat, img_feat):
        """CLIP-style cosine similarity logits between text and image features."""
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logits = torch.matmul(img_feat, nl_feat.T)
        # Scale by the learned CLIP temperature when the model exposes one.
        if hasattr(model, "logit_scale"):
            logits = logits * model.logit_scale.exp()
        return logits

    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Forward pass of the VINE model.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names to classify objects
            unary_keywords: Optional list of unary predicates (actions on single objects)
            binary_keywords: Optional list of binary predicates (relations between objects)
            object_pairs: Optional list of (obj1_id, obj2_id) pairs for binary classification
            return_flattened_segments: Include flattened mask/bbox lists in the output
                (defaults to config.return_flattened_segments)
            return_valid_pairs: Include per-frame valid object pairs in the output
                (defaults to config.return_valid_pairs)
            interested_object_pairs: Optional subset of pairs to keep when
                computing valid pairs
            debug_visualizations: Resolved from config when None.
                NOTE(review): currently resolved but not used below — confirm
                whether visualization hooks were intended here.

        Returns:
            Dict containing probability distributions for categorical, unary,
            and binary predictions (plus optional flattened segments / pairs).
        """
        # Resolve optional arguments against config-level defaults.
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = self.config.return_flattened_segments
        if return_valid_pairs is None:
            return_valid_pairs = self.config.return_valid_pairs
        if interested_object_pairs is None or len(interested_object_pairs) == 0:
            interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or []
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations

        # Empty keyword lists are replaced by a single dummy entry so the
        # CLIP text towers always receive at least one prompt; dummy entries
        # are filtered back out of the results below.
        dummy_str = ""
        if len(categorical_keywords) == 0:
            categorical_keywords = [dummy_str]
        if len(unary_keywords) == 0:
            unary_keywords = [dummy_str]
        if len(binary_keywords) == 0:
            binary_keywords = [dummy_str]

        # Extract text features once per keyword type.
        categorical_features = self._extract_text_features(
            self.clip_cate_model, categorical_keywords
        )
        unary_features = self._extract_text_features(
            self.clip_unary_model, unary_keywords
        )
        binary_features = self._extract_text_features(
            self.clip_binary_model, binary_keywords
        )

        categorical_probs = {}
        unary_probs = {}
        binary_probs = {}

        # Per-frame, per-object categorical and unary classification.
        for frame_id, frame_masks in masks.items():
            if frame_id >= len(video_frames):
                # Mask dict may reference frames beyond the sampled clip.
                continue

            frame = self._frame_to_numpy(video_frames[frame_id])
            frame_bboxes = bboxes.get(frame_id, {})

            for obj_id, mask in frame_masks.items():
                if obj_id not in frame_bboxes:
                    continue

                bbox = frame_bboxes[obj_id]

                # Build a single-object crop, highlighted via alpha blending.
                mask_np = self._mask_to_numpy(mask)
                obj_image = extract_single_object(
                    frame, mask_np, alpha=self.config.alpha
                )

                obj_features = self._extract_image_features(
                    self.clip_cate_model, obj_image
                )

                # Categorical classification over the keyword set.
                cat_similarities = self.clip_sim(
                    self.clip_cate_model, categorical_features, obj_features
                )
                cat_probs = F.softmax(cat_similarities, dim=-1)

                for i, keyword in enumerate(categorical_keywords):
                    if keyword != dummy_str:
                        categorical_probs[(obj_id, keyword)] = cat_probs[0, i].item()

                # Unary predicates (skipped entirely when only the dummy exists).
                if len(unary_keywords) > 0 and unary_keywords[0] != dummy_str:
                    unary_similarities = self.clip_sim(
                        self.clip_unary_model, unary_features, obj_features
                    )
                    unary_probs_tensor = F.softmax(unary_similarities, dim=-1)

                    for i, keyword in enumerate(unary_keywords):
                        if keyword != dummy_str:
                            unary_probs[(frame_id, obj_id, keyword)] = unary_probs_tensor[0, i].item()

        # Binary relationships over requested object pairs.
        if len(binary_keywords) > 0 and binary_keywords[0] != dummy_str and len(object_pairs) > 0:
            for obj1_id, obj2_id in object_pairs:
                for frame_id, frame_masks in masks.items():
                    if frame_id >= len(video_frames):
                        continue
                    if (obj1_id in frame_masks and obj2_id in frame_masks and
                            obj1_id in bboxes.get(frame_id, {}) and obj2_id in bboxes.get(frame_id, {})):

                        frame = self._frame_to_numpy(video_frames[frame_id])
                        mask1_np = self._mask_to_numpy(frame_masks[obj1_id])
                        mask2_np = self._mask_to_numpy(frame_masks[obj2_id])

                        # Compose an image emphasizing subject and object.
                        pair_image = extract_object_subject(
                            frame, mask1_np[..., None], mask2_np[..., None],
                            alpha=self.config.alpha,
                            white_alpha=self.config.white_alpha
                        )

                        bbox1 = bboxes[frame_id][obj1_id]
                        bbox2 = bboxes[frame_id][obj2_id]

                        # Only consider pairs whose bounding boxes overlap.
                        if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \
                           bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]:
                            continue

                        cropped_image = crop_image_contain_bboxes(
                            pair_image, [bbox1, bbox2], f"frame_{frame_id}"
                        )

                        pair_features = self._extract_image_features(
                            self.clip_binary_model, cropped_image
                        )

                        binary_similarities = self.clip_sim(
                            self.clip_binary_model, binary_features, pair_features
                        )
                        binary_probs_tensor = F.softmax(binary_similarities, dim=-1)

                        for i, keyword in enumerate(binary_keywords):
                            if keyword != dummy_str:
                                binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = binary_probs_tensor[0, i].item()

        # Uniform fallback probability kept for downstream compatibility.
        dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords))

        result: Dict[str, Any] = {
            "categorical_probs": {0: categorical_probs},  # Video ID 0
            "unary_probs": {0: unary_probs},
            "binary_probs": [binary_probs],  # List format for compatibility
            "dummy_prob": dummy_prob
        }

        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                interested_pairs = interested_object_pairs if interested_object_pairs else None
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    # Provide all generated pairs for clarity when auto-generated.
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {"pair_source": "filtered", "requested_pairs": interested_pairs}

        return result

    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a frame tensor/array to a contiguous numpy array."""
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a mask tensor/array to a 2D boolean numpy array."""
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)

        # Squeeze a single leading or trailing channel axis, e.g. (1, H, W).
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)

        if mask_np.ndim != 2:
            raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")

        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords):
        """Tokenize keywords and extract CLIP text features (checkpointed)."""
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding='max_length'
        ).to(self._device)

        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        """Preprocess an image and extract CLIP image features (checkpointed)."""
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)
            # NOTE(review): assumes 3-channel numpy inputs are BGR (OpenCV
            # convention) — confirm upstream producers before changing.
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        inputs = self.clip_processor(
            images=image,
            return_tensors="pt"
        ).to(self._device)

        return self._image_features_checkpoint(model, inputs['pixel_values'])

    # TODO: return masks and bboxes and their corresponding index
    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        High-level prediction method that returns formatted results.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names
            unary_keywords: Optional list of unary predicates
            binary_keywords: Optional list of binary predicates
            object_pairs: Optional list of object pairs for binary relations
            return_top_k: Number of top predictions to return
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Passed through to forward()

        Returns:
            Formatted prediction results with per-object top-k lists and
            overall confidence scores.
        """
        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        # Group categorical results per object as (prob, category) lists.
        formatted_categorical = {}
        for (obj_id, category), prob in outputs["categorical_probs"][0].items():
            if obj_id not in formatted_categorical:
                formatted_categorical[obj_id] = []
            formatted_categorical[obj_id].append((prob, category))

        # Sort and take top-k for each object.
        for obj_id in formatted_categorical:
            formatted_categorical[obj_id] = sorted(
                formatted_categorical[obj_id], reverse=True
            )[:return_top_k]

        # Group unary results per (frame, object).
        formatted_unary = {}
        for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
            key = (frame_id, obj_id)
            if key not in formatted_unary:
                formatted_unary[key] = []
            formatted_unary[key].append((prob, predicate))

        for key in formatted_unary:
            formatted_unary[key] = sorted(
                formatted_unary[key], reverse=True
            )[:return_top_k]

        # Group binary results per (frame, object pair).
        formatted_binary = {}
        if len(outputs["binary_probs"]) > 0:
            for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
                key = (frame_id, obj_pair)
                if key not in formatted_binary:
                    formatted_binary[key] = []
                formatted_binary[key].append((prob, predicate))

            for key in formatted_binary:
                formatted_binary[key] = sorted(
                    formatted_binary[key], reverse=True
                )[:return_top_k]

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            "confidence_scores": {
                # Max probability seen in each prediction family; 0.0 when empty.
                "categorical": max([max([p for p, _ in preds], default=0.0)
                                    for preds in formatted_categorical.values()], default=0.0),
                "unary": max([max([p for p, _ in preds], default=0.0)
                              for preds in formatted_unary.values()], default=0.0),
                "binary": max([max([p for p, _ in preds], default=0.0)
                               for preds in formatted_binary.values()], default=0.0)
            }
        }

        # Propagate optional extras from forward() when requested.
        if "flattened_segments" in outputs:
            result["flattened_segments"] = outputs["flattened_segments"]
        if "valid_pairs" in outputs:
            result["valid_pairs"] = outputs["valid_pairs"]
        if "valid_pairs_metadata" in outputs:
            result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]

        return result
vis_utils.py ADDED
@@ -0,0 +1,941 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import random
7
+ import math
8
+ from matplotlib.patches import Rectangle
9
+ import itertools
10
+ from typing import Any, Dict, List, Tuple, Optional, Union
11
+
12
+ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
13
+
14
+ ########################################################################################
15
+ ########## Visualization Library ########
16
+ ########################################################################################
17
+ # This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
18
+ #
19
+ # Conventions (RGB frames, pixel coords):
20
+ # - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
21
+ # - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
22
+ # - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
23
+ #
24
+ # Per-frame stores use one of:
25
+ # - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
26
+ # - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
27
+ #
28
+ # Renderer inputs/outputs:
29
+ # 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
30
+ # - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
31
+ # - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
32
+ #
33
+ # 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
34
+ # - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
35
+ #
36
+ # 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
37
+ # -> List[np.ndarray] (the "all" view)
38
+ # - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
39
+ # - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
40
+ # - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
41
+ # - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
42
+ #
43
+ # Ground-truth helpers used by plotting utilities:
44
+ # - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
45
+ #
46
+ # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
47
+ ########################################################################################
48
+
49
def clean_label(label):
    """Normalize a label: underscores and slashes become spaces."""
    for separator in ("_", "/"):
        label = label.replace(separator, " ")
    return label
52
+
53
# Should be performed somewhere else I believe
def format_cate_preds(cate_preds):
    """Group (obj_id, label) -> prob predictions per object, sorted by prob descending."""
    obj_pred_dict = {}
    for (oid, label), prob in cate_preds.items():
        # Clean the predicted label as well.
        obj_pred_dict.setdefault(oid, []).append((clean_label(label), prob))
    for preds in obj_pred_dict.values():
        preds.sort(key=lambda pair: pair[1], reverse=True)
    return obj_pred_dict
66
+
67
def format_binary_cate_preds(binary_preds):
    """Flatten binary predictions into (frame_id, subj, obj, relation, score) tuples.

    Keys are expected as (frame_id, (subject, object), predicted_relation);
    malformed keys are skipped with a warning. Results are sorted by score,
    highest first, so callers can slice the top-k directly.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        # Expect key format: (frame_id, (subject, object), predicted_relation)
        try:
            f_id, (subj, obj), pred_rel = key
        except (TypeError, ValueError):
            print("Skipping key with unexpected format:", key)
            continue
        frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
    # Bug fix: sort by the confidence score (index 4), not the relation name
    # (index 3) — sorting by the relation string made [:top_k] meaningless.
    frame_binary_preds.sort(key=lambda x: x[4], reverse=True)
    return frame_binary_preds
79
+
80
+ _FONT = cv2.FONT_HERSHEY_SIMPLEX
81
+
82
+
83
+ def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
84
+ if mask is None:
85
+ return None
86
+ if isinstance(mask, torch.Tensor):
87
+ mask_np = mask.detach().cpu().numpy()
88
+ else:
89
+ mask_np = np.asarray(mask)
90
+ if mask_np.ndim == 0:
91
+ return None
92
+ if mask_np.ndim == 3:
93
+ mask_np = np.squeeze(mask_np)
94
+ if mask_np.ndim != 2:
95
+ return None
96
+ if mask_np.dtype == bool:
97
+ return mask_np
98
+ return mask_np > 0
99
+
100
+
101
+ def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
102
+ if bbox is None:
103
+ return None
104
+ if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
105
+ x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
106
+ elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
107
+ x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
108
+ else:
109
+ return None
110
+ x1 = int(np.clip(round(x1), 0, width - 1))
111
+ y1 = int(np.clip(round(y1), 0, height - 1))
112
+ x2 = int(np.clip(round(x2), 0, width - 1))
113
+ y2 = int(np.clip(round(y2), 0, height - 1))
114
+ if x2 <= x1 or y2 <= y1:
115
+ return None
116
+ return (x1, y1, x2, y2)
117
+
118
+
119
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Look up the object's palette color and return it as a BGR uint8 triple."""
    r, g, b = (int(np.clip(channel, 0.0, 1.0) * 255) for channel in get_color(obj_id)[:3])
    return (b, g, r)
123
+
124
+
125
+ def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
126
+ return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
127
+
128
+
129
def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    """Draw a vertical stack of text lines with filled backgrounds onto `image` in place.

    Args:
        image: BGR image to annotate (mutated in place).
        lines: Text lines to render; an empty list is a no-op.
        anchor: (x, y) pixel where the stack starts; clipped into the image.
        color: Base BGR color; the background is a lightened version of it.
        direction: "down" stacks lines below the anchor; any other value
            stacks them upward from the anchor.
    """
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x, y = anchor
    x = int(np.clip(x, 0, img_w - 1))
    y_cursor = int(np.clip(y, 0, img_h - 1))
    # Lightened background keeps the black text legible over the frame.
    bg_color = _background_color(color)

    if direction == "down":
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
            # Stop once the next box would be fully clipped off the bottom edge.
            if bottom_y <= top_y:
                break
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = bottom_y
    else:
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            # Upward stacking clamps at the top edge instead of breaking,
            # so remaining lines pile up at y=0.
            top_y = max(y_cursor - th - baseline - 6, 0)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            bottom_y = min(top_y + th + baseline + 6, img_h - 1)
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
            y_cursor = top_y
174
+
175
+
176
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw one text label centered on a point, with a lightened filled background.

    Mutates `image` (BGR) in place; the label box is clipped to the image bounds.
    """
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
197
+
198
+
199
+ def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
200
+ if isinstance(store, dict):
201
+ frame_entry = store.get(frame_idx, {})
202
+ elif isinstance(store, list) and 0 <= frame_idx < len(store):
203
+ frame_entry = store[frame_idx]
204
+ else:
205
+ frame_entry = {}
206
+ if isinstance(frame_entry, dict):
207
+ return frame_entry
208
+ if isinstance(frame_entry, list):
209
+ return {i: value for i, value in enumerate(frame_entry)}
210
+ return {}
211
+
212
+
213
+ def _label_anchor_and_direction(
214
+ bbox: Tuple[int, int, int, int],
215
+ position: str,
216
+ ) -> Tuple[Tuple[int, int], str]:
217
+ x1, y1, x2, y2 = bbox
218
+ if position == "bottom":
219
+ return (x1, y2), "down"
220
+ return (x1, y1), "up"
221
+
222
+
223
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw a colored bbox plus a "#<id> <title>" label block onto `image` in place.

    `sub_lines` are extra lines stacked after the head line; `label_position`
    ("top"/"bottom") picks the anchor corner and stacking direction.
    """
    color = _object_color_bgr(obj_id)
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    head = title if title else f"#{obj_id}"
    # Ensure the head line always carries the "#<id>" prefix, even when a
    # custom title was supplied.
    if not head.startswith("#"):
        head = f"#{obj_id} {head}"
    lines = [head]
    if sub_lines:
        lines.extend(sub_lines)
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, lines, anchor, color, direction=direction)
241
+
242
+
243
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Overlay translucent SAM masks (and mask-derived labeled boxes) on each frame.

    Args:
        frames: RGB frames as a list of (H, W, 3) arrays or a (T, H, W, 3) array.
        sam_masks: Per-frame {obj_id: mask} store (dict- or list-indexed by frame).
        dino_labels: Optional {obj_id: label} used as box titles.

    Returns:
        List of annotated RGB frames; None frames are skipped.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        # Blend in float to avoid uint8 rounding artifacts while compositing.
        overlay = frame_bgr.astype(np.float32)
        masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)

        # Pass 1: paint translucent per-object mask overlays.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            color = _object_color_bgr(obj_id)
            alpha = 0.45
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]

        # Pass 2: draw boxes derived from each mask, on top of all overlays.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            bbox = mask_to_bbox(mask_np)
            bbox = _sanitize_bbox(bbox, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
286
+
287
+
288
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Draw labeled GroundingDINO boxes on each RGB frame.

    Args:
        frames: RGB frames as a list of (H, W, 3) arrays or a (T, H, W, 3) array.
        bboxes: Per-frame {obj_id: [x1, y1, x2, y2]} store (dict- or list-indexed).
        dino_labels: Optional {obj_id: label} used as box titles.

    Returns:
        List of annotated RGB frames; None frames are skipped.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        annotated = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = annotated.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)

        for obj_id, bbox_values in frame_bboxes.items():
            # Out-of-range or degenerate boxes are dropped silently.
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

    return results
316
+
317
+
318
def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> Dict[str, List[np.ndarray]]:
    """Render four annotation views per frame: "object", "unary", "binary", "all".

    - "object": boxes + categorical titles only.
    - "unary": adds per-object unary predicate lines and (when `masks` is
      given) translucent mask overlays for objects that have unary labels.
    - "binary": adds relation lines between object pairs with the top relation
      label at the midpoint.
    - "all": everything combined.

    Returns a dict of view name -> list of annotated RGB frames (None frames
    are skipped).
    """
    frame_groups: Dict[str, List[np.ndarray]] = {
        "object": [],
        "unary": [],
        "binary": [],
        "all": [],
    }
    frames_iterable = frames if isinstance(frames, list) else list(frames)

    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}

        # One independent canvas per view.
        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()
        binary_bgr = base_bgr.copy()
        all_bgr = base_bgr.copy()

        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}

        # Stage 1: sanitize boxes and precompute titles / unary text lines.
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox
            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            title_parts = []
            if cat_label:
                if cat_prob is not None:
                    title_parts.append(f"{cat_label} {cat_prob:.2f}")
                else:
                    title_parts.append(cat_label)
            titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
            unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
            unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
            unary_lines_lookup[obj_id] = unary_lines

        # Stage 2: translucent mask overlays, only for objects with unary labels.
        for obj_id, bbox in bbox_lookup.items():
            unary_lines = unary_lines_lookup.get(obj_id, [])
            if not unary_lines:
                continue
            mask_raw = frame_masks.get(obj_id)
            mask_np = _to_numpy_mask(mask_raw)
            if mask_np is None or not np.any(mask_np):
                continue
            color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            alpha = 0.45
            for target in (unary_bgr, all_bgr):
                target_vals = target[mask_np].astype(np.float32)
                blended = (1.0 - alpha) * target_vals + alpha * color
                target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)

        # Stage 3: boxes + titles on every view; unary lines below the box
        # on the "unary" and "all" views.
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            _draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
            _draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
            _draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)

        # Stage 4: relation lines; only the top-scoring relation is labeled.
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            if not subj_bbox or not obj_bbox:
                continue
            # NOTE(review): relation_line is defined elsewhere in this module;
            # presumably returns the two pixel endpoints connecting the boxes
            # — confirm against its definition.
            start, end = relation_line(subj_bbox, obj_bbox)
            # Relation color is the average of the two objects' colors.
            color = tuple(int(c) for c in np.clip(
                (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
                 np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
                0, 255
            ))
            prob, relation = relation_preds[0]
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
            cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
            _draw_centered_label(binary_bgr, label_text, mid_point, color)
            _draw_centered_label(all_bgr, label_text, mid_point, color)

        frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))

    return frame_groups
426
+
427
+
428
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
) -> List[np.ndarray]:
    """Convenience wrapper: render all VINE views and return only the combined "all" view."""
    frame_sets = render_vine_frame_sets(
        frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks
    )
    return frame_sets.get("all", [])
444
+
445
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Assign a BGR box color and caption string to each ground-truth object.

    Green = exact top-1 match, orange = ground truth appears in the top-k
    predictions, red = miss or no prediction at all. Returns parallel lists
    of colors and "ID:.../P:.../GT:..." caption strings.
    """
    all_colors = []
    all_texts = []
    for obj_id, bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if not preds:
            top1 = "N/A"
            box_color = (0, 0, 255)  # bright red: nothing predicted
        else:
            top1 = preds[0][0]
            # Compare cleaned labels case-insensitively.
            candidates = {p[0].lower() for p in preds[:topk_object]}
            if top1.lower() == gt_label.lower():
                box_color = (0, 255, 0)  # bright green: correct
            elif gt_label.lower() in candidates:
                box_color = (0, 165, 255)  # bright orange: partial (top-k) match
            else:
                box_color = (0, 0, 255)  # bright red: incorrect
        all_colors.append(box_color)
        all_texts.append(f"ID:{obj_id}/P:{top1}/GT:{gt_label}")
    return all_colors, all_texts
468
+
469
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw one colored box + caption per object onto `frame_img` (mutated in place).

    `gt_labels` entries are (obj_id, bbox, gt_label) triples; `all_colors` and
    `all_texts` are the parallel lists produced by color_for_cate_correctness.
    Returns the annotated image.
    """
    for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        # Filled background strip just above the box keeps the caption legible.
        (tw, th), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1)
        cv2.putText(frame_img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 0), 1, cv2.LINE_AA)

    return frame_img
480
+
481
def get_white_pane(pane_height,
                   pane_width=600,
                   header_height = 50,
                   header_font = cv2.FONT_HERSHEY_SIMPLEX,
                   header_font_scale = 0.7,
                   header_thickness = 2,
                   header_color = (0, 0, 0)):
    """Create a white side pane with "Binary Predictions" / "Ground Truth" headers.

    The pane is split 60/40, predictions on the wider left column. Returns a
    (pane_height, pane_width, 3) uint8 BGR image with both headers drawn in.
    """
    # Create an expanded white pane to display text info.
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)

    # Predictions column is wider (60% vs. 40%).
    left_width = int(pane_width * 0.6)

    # Bug fix: draw directly onto the pane. The previous version rendered the
    # headers onto .copy() slices and discarded them, so the returned pane was
    # blank.
    cv2.putText(white_pane, "Binary Predictions", (10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
    cv2.putText(white_pane, "Ground Truth", (left_width + 10, header_height - 30),
                header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)

    return white_pane
503
+
504
# This is for plotting binary prediction results with frame-based scene graphs
505
def plot_binary_sg(frame_img,
                   white_pane,
                   bin_preds,
                   gt_relations,
                   topk_binary,
                   header_height=50,
                   indicator_size=20,
                   pane_width=600):
    """Render top-k binary predictions and ground-truth relations onto a side pane.

    Args:
        frame_img: BGR frame to show on the left.
        white_pane: Pre-built pane image (see get_white_pane).
        bin_preds: Sorted predictions; 5-tuples (frame_id, subj, obj, rel, score)
            as produced by format_binary_cate_preds, or legacy 4-tuples
            (subj, rel, obj, score).
        gt_relations: List of (subject, object, relation) triples.
        topk_binary: Number of predictions to list.

    Returns:
        The frame and pane stacked horizontally as one BGR image.
    """
    # Leave vertical space for the headers.
    line_height = 30  # vertical spacing per line
    x_text = 10  # left margin for text
    y_text_left = header_height + 10  # starting y for left pane text
    y_text_right = header_height + 10  # starting y for right pane text

    # Left section: top-k binary predictions (60% of the pane).
    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    for entry in bin_preds[:topk_binary]:
        # Bug fix: format_binary_cate_preds emits 5-tuples
        # (frame_id, subj, obj, relation, score); the old code unpacked
        # 4-tuples and raised ValueError. Accept both shapes.
        if len(entry) == 5:
            _f_id, subj, obj, pred_rel, score = entry
        else:
            subj, pred_rel, obj, score = entry
        correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
                      for gt in gt_relations)
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
                      (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_left += line_height

    # Right section: ground truth binary relations.
    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(right_pane, text, (x_text, y_text_right + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
        y_text_right += line_height

    # Combine the two text panes and then with the frame image.
    combined_pane = np.hstack((left_pane, right_pane))
    combined_image = np.hstack((frame_img, combined_pane))
    return combined_image
549
+
550
def visualized_frame(frame_img,
                     bboxes,
                     object_ids,
                     gt_labels,
                     cate_preds,
                     binary_preds,
                     gt_relations,
                     topk_object,
                     topk_binary,
                     phase="unary"):
    """Return the combined annotated frame for one frame as an image (in BGR).

    In the "unary" phase, boxes colored by categorical correctness are drawn
    on the frame; otherwise a side pane comparing binary predictions against
    ground truth is appended to the right of the frame.
    """
    if phase == "unary":
        # Assemble (obj_id, bbox, cleaned_gt_label) triples for the helpers.
        objs = []
        for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
            objs.append((obj_id, bbox, clean_label(gt_label)))

        formatted_cate_preds = format_cate_preds(cate_preds)
        # Bug fix: pass the assembled triples (objs), not raw gt_labels —
        # both helpers unpack (obj_id, bbox, gt_label) per entry, and `objs`
        # was previously built but never used.
        all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
        return plot_unary(frame_img, objs, all_colors, all_texts)

    # --- Process Binary Predictions & Ground Truth for the Text Pane ---
    formatted_binary_preds = format_binary_cate_preds(binary_preds)

    # Clean ground truth relations for display/comparison.
    gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]

    pane_width = 600  # increased pane width for more horizontal space
    pane_height = frame_img.shape[0]
    header_height = 50  # increased header space
    white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)

    return plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
594
+
595
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay one translucent mask (and optional labeled bbox) on a matplotlib axis.

    Args:
        mask: 2D array, or 3D with a singleton dim ((1, H, W) or (H, W, 1)).
        ax: Matplotlib axis to draw on.
        obj_id: Index into the deterministic rainbow palette (None -> 0).
        det_class: Optional text; when given, a bbox around the mask is drawn
            and labeled with it.
        random_color: Use a random RGBA color instead of the palette.
    """
    # Ensure mask is a numpy array
    mask = np.array(mask)
    # Handle different mask shapes
    if mask.ndim == 3:
        # (1, H, W) -> (H, W)
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        # (H, W, 1) -> (H, W)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)
    # Now mask should be (H, W)
    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        # Deterministic per-object color: stride 47 through the rainbow
        # colormap spreads consecutive ids far apart.
        cmap = plt.get_cmap("gist_rainbow")
        cmap_idx = 0 if obj_id is None else obj_id
        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5
        color = np.array(color)

    # Expand mask to (H, W, 1) for broadcasting
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)

    # draw a box around the mask with the det_class as the label
    if not det_class is None:
        # Find the bounding box coordinates
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3]
            )
            ax.add_patch(rect)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1
            )
    ax.imshow(mask_image)
649
+
650
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk.

    Args:
        frame_image: RGB frame as a numpy array or torch tensor.
        masks: {obj_id: mask} dict or a list of masks (indexed by position);
            masks may be numpy arrays or torch tensors.
        save_path: Destination image path.

    Returns:
        The save_path that was written.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Move tensors to CPU numpy so matplotlib can render them.
    frame_np = (
        frame_image.detach().cpu().numpy()
        if torch.is_tensor(frame_image)
        else np.asarray(frame_image)
    )
    frame_np = np.ascontiguousarray(frame_np)

    # Accept both dict- and list-shaped mask containers.
    if isinstance(masks, dict):
        mask_iter = masks.items()
    else:
        mask_iter = enumerate(masks)

    prepared_masks = {
        obj_id: (
            mask.detach().cpu().numpy()
            if torch.is_tensor(mask)
            else np.asarray(mask)
        )
        for obj_id, mask in mask_iter
    }

    ax.imshow(frame_np)
    ax.axis("off")

    for obj_id, mask_np in prepared_masks.items():
        show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)

    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path
684
+
685
def get_video_masks_visualization(video_tensor,
                                  video_masks,
                                  video_id,
                                  video_save_base_dir,
                                  oid_class_pred=None,
                                  sample_rate = 1):
    """Save per-frame mask visualizations for a video under <base_dir>/<video_id>/.

    Args:
        video_tensor: Iterable of frames.
        video_masks: {frame_id: {obj_id: mask}} store; frames without masks
            are skipped with a message.
        video_id: Subdirectory name for this video.
        video_save_base_dir: Root output directory.
        oid_class_pred: Optional {obj_id: class} labels passed to the renderer.
        sample_rate: Visualize every `sample_rate`-th frame (1 = every frame).
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Bug fix: sample_rate was accepted but ignored.
        if sample_rate > 1 and frame_id % sample_rate != 0:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        # Bug fix: the rendered figure was never written to save_path and
        # never closed, leaking matplotlib figures.
        fig, _ax = get_mask_one_image(image, masks, oid_class_pred)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
704
+
705
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Overlay object masks (and optional class labels) on a single frame.

    Args:
        frame_image: RGB frame to display.
        masks: {obj_id: mask} dict or a list of masks (indexed by position).
        oid_class_pred: Optional {obj_id: class}; when given, each mask's box
            is labeled "<obj_id>. <class>".

    Returns:
        The matplotlib (fig, ax) pair; the caller saves/closes the figure.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    # Idiom fix: isinstance instead of type(...) == list.
    if isinstance(masks, list):
        masks = {i: m for i, m in enumerate(masks)}

    # Add the masks
    for obj_id, mask in masks.items():
        det_class = f"{obj_id}. {oid_class_pred[obj_id]}" if oid_class_pred is not None else None
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    return fig, ax
723
+
724
def save_video(frames, output_filename, output_fps):
    """Write a sequence of BGR frames to an H.264 video file.

    Args:
        frames: List of (H, W, 3) arrays or a (T, H, W, 3) array; all frames
            must share the first frame's size.
        output_filename: Destination video path.
        output_fps: Output frame rate.
    """
    num_frames = len(frames)
    if num_frames == 0:
        print("No frames to write; skipping", output_filename)
        return
    # Bug fix: the frame size must come from an individual frame —
    # frames.shape[:2] read (T, H) on a 4-D array and crashed on lists.
    frame_h, frame_w = np.asarray(frames[0]).shape[:2]

    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*'avc1')
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))

    print(f"Processing {num_frames} frames...")
    for i in range(num_frames):
        # Bug fix: write the provided frames; the old code called an
        # undefined get_visualized_frame() helper.
        out.write(np.asarray(frames[i]))
        if i % 10 == 0:
            print(f"Processed frame {i+1}/{num_frames}")

    out.release()
    print(f"Video saved as {output_filename}")
743
+
744
+
745
def list_depth(lst):
    """Calculates the nesting depth of a list/tensor structure (non-containers are 0)."""
    if not isinstance(lst, (list, torch.Tensor)):
        return 0
    is_empty_list = isinstance(lst, list) and len(lst) == 0
    is_scalar_tensor = isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])
    if is_empty_list or is_scalar_tensor:
        return 1
    return 1 + max(list_depth(item) for item in lst)
753
+
754
def normalize_prompt(points, labels):
    """Add a per-object batch dim to depth-3 point/label prompts; pass others through."""
    if list_depth(points) != 3:
        return points, labels
    stacked_points = torch.stack([point.unsqueeze(0) for point in points])
    stacked_labels = torch.stack([label.unsqueeze(0) for label in labels])
    return stacked_points, stacked_labels
759
+
760
+
761
def show_box(box, ax, object_id):
    """Draw one [x1, y1, x2, y2] box on a matplotlib axis in the object's palette color.

    Empty boxes are a no-op. Uses the same stride-47 rainbow palette as show_mask
    so boxes and masks for the same object match.
    """
    if len(box) == 0:
        return

    cmap = plt.get_cmap("gist_rainbow")
    cmap_idx = 0 if object_id is None else object_id
    color = list(cmap((cmap_idx * 47) % 256))

    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=2))
772
+
773
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter point prompts: positives as green plus, negatives as red square."""
    if len(labels) == 0:
        return

    positives = coords[labels == 1]
    negatives = coords[labels == 0]

    # Edge color is deterministic per object id, matching the other overlays.
    cmap = plt.get_cmap("gist_rainbow")
    idx = object_id if object_id is not None else 0
    edge_color = list(cmap((idx * 47) % 256))

    ax.scatter(positives[:, 0], positives[:, 1], color='green', marker='P',
               s=marker_size, edgecolor=edge_color, linewidth=1.25)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', marker='s',
               s=marker_size, edgecolor=edge_color, linewidth=1.25)
786
+
787
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Render box/point prompts on one frame and save the figure to disk.

    Args:
        frame_image: Image array displayable by matplotlib.
        boxes: Per-object boxes as a tensor (indexed by position), a dict
            mapping object id -> box, or an empty list for "no boxes".
        points: Per-object point prompts (run through normalize_prompt).
        labels: Per-object point labels matching ``points``.
        save_path: Destination image file path.

    Raises:
        TypeError: If ``boxes`` is none of the supported forms (previously a
            bare ``Exception()`` with no message; TypeError is still an
            Exception subclass, so existing handlers keep working).
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    points, labels = normalize_prompt(points, labels)
    if isinstance(boxes, torch.Tensor):
        # Object id is the row index within the tensor.
        for object_id, box in enumerate(boxes):
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, dict):
        for object_id, box in boxes.items():
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, list) and len(boxes) == 0:
        # Explicitly allowed: no boxes for this frame.
        pass
    else:
        raise TypeError(f"Unsupported boxes type: {type(boxes)!r}")

    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if len(point_ls) != 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

    # Save and close to avoid leaking figures.
    plt.savefig(save_path)
    plt.close()
818
+
819
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
    """Save one prompt-visualization image per frame under <base>/<video_id>/.

    ``video_boxes``/``video_points``/``video_labels`` are dicts keyed by
    frame index; frames without an entry are rendered with no prompts.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Missing entries fall back to "no prompts" for this frame.
        boxes = video_boxes.get(frame_id, [])
        points = video_points.get(frame_id, [])
        labels = video_labels.get(frame_id, [])

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, save_path)
837
+
838
+
839
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate=1):
    """Save mask-overlay images for a (possibly subsampled) set of frames.

    ``sample_rate`` in [0, 1] is the probability that any given frame is
    rendered; frames with no entry in ``video_masks`` are skipped with a
    notice. Note: ``oid_class_pred`` is accepted for interface parity but
    not forwarded by the current implementation.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        # Randomly drop frames to keep the output directory small.
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, video_masks[frame_id], save_path)
853
+
854
+
855
+
856
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return a deterministic RGBA color for an object id.

    Args:
        obj_id: Object id (None maps to 0); spread over the colormap by
            multiplying by 47 mod 256.
        cmap_name: Matplotlib colormap name.
        alpha: Opacity of the returned color. Bug fix: this parameter was
            previously accepted but ignored (alpha was hard-coded to 0.5);
            the default value preserves the old behavior for all callers.

    Returns:
        numpy array of 4 floats (RGBA).
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = list(cmap((cmap_idx * 47) % 256))
    color[3] = alpha  # was hard-coded to 0.5, ignoring the parameter
    return np.array(color)
863
+
864
+
865
+ def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
866
+ return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
867
+
868
+
869
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Returns integer pixel centers suitable for drawing a relation line. For
    coincident boxes, nudges the target center to ensure the segment has span.
    """
    # Midpoints of both boxes (xyxy convention), computed inline.
    cx1 = (bbox1[0] + bbox1[2]) / 2.0
    cy1 = (bbox1[1] + bbox1[3]) / 2.0
    cx2 = (bbox2[0] + bbox2[2]) / 2.0
    cy2 = (bbox2[1] + bbox2[3]) / 2.0

    # Coincident centers: push the target right by 5% of its width, >= 1px.
    same_x = math.isclose(cx1, cx2, abs_tol=1e-3)
    same_y = math.isclose(cy1, cy2, abs_tol=1e-3)
    if same_x and same_y:
        cx2 += max(1.0, (bbox2[2] - bbox2[0]) * 0.05)

    start = (int(round(cx1)), int(round(cy1)))
    end = (int(round(cx2)), int(round(cy2)))
    # Rounding can still collapse the segment; force a 1px horizontal span.
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end
887
+
888
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Overlay relation-annotated masks on one frame.

    Only objects participating in at least one predicted relation are drawn.
    Each relation is rendered as a line between the two objects' bounding
    boxes with the relation text placed at the line midpoint.

    Args:
        frame_image: Image array displayable by matplotlib.
        masks: Mapping object id -> binary mask.
        rel_pred_ls: Mapping (from_obj_id, to_obj_id) -> relation text.
            Robustness fix: ``None`` (the declared default) is now treated
            as "no relations" instead of crashing on ``None.items()``.

    Returns:
        (fig, ax): The matplotlib figure/axes with the overlays drawn.
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis('off')

    if rel_pred_ls is None:
        rel_pred_ls = {}

    all_objs_to_show = set()
    all_lines_to_show = []

    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)

        # Endpoints: shortest segment between the two objects' boxes.
        bbox1 = mask_to_bbox(masks[from_obj_id])
        bbox2 = mask_to_bbox(masks[to_obj_id])
        c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)

        # Line colored by the source object, label face by the target.
        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        all_lines_to_show.append((c1, c2, face_color, line_color, rel_text))

    # Draw only the masks that participate in at least one relation.
    for obj_id in all_objs_to_show:
        show_mask(masks[obj_id], ax, obj_id=obj_id, random_color=False)

    for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
        plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y],
                 color=line_color, linestyle='-', linewidth=3)
        # Relation text sits at the midpoint, nudged 5px left.
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
            alpha=1
        )

    return fig, ax