import os
import tempfile
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from transformers import Pipeline

from laser.loading import load_video
from laser.preprocess.mask_generation_grounding_dino import generate_masks_grounding_dino

from .vine_config import VineConfig
from .vine_model import VineModel
from .vis_utils import render_dino_frames, render_sam_frames, render_vine_frame_sets


class VinePipeline(Pipeline):
    """
    Pipeline for the VINE model that handles end-to-end video understanding.

    The pipeline takes a video file or an array of frames, together with a
    segmentation method and keyword lists, and returns probability
    distributions over the keywords.

    Segmentation model configuration:
        The pipeline requires SAM2 and GroundingDINO models for mask
        generation. Custom paths can be supplied via constructor kwargs:

        - sam_config_path: Path to the SAM2 config (e.g., "configs/sam2.1/sam2.1_hiera_b+.yaml")
        - sam_checkpoint_path: Path to the SAM2 checkpoint (e.g., "checkpoints/sam2.1_hiera_base_plus.pt")
        - gd_config_path: Path to the GroundingDINO config (e.g., "groundingdino/config/GroundingDINO_SwinT_OGC.py")
        - gd_checkpoint_path: Path to the GroundingDINO checkpoint (e.g., "checkpoints/groundingdino_swint_ogc.pth")

        Alternatively, use set_segmentation_models() to provide
        pre-initialized model instances.
    """

    def __init__(
        self,
        sam_config_path: Optional[str] = None,
        sam_checkpoint_path: Optional[str] = None,
        gd_config_path: Optional[str] = None,
        gd_checkpoint_path: Optional[str] = None,
        **kwargs,
    ):
        # Segmentation models are created lazily on first use, or injected
        # via set_segmentation_models().
        self.grounding_model = None
        self.sam_predictor = None
        self.mask_generator = None

        self.sam_config_path = sam_config_path
        self.sam_checkpoint_path = sam_checkpoint_path
        self.gd_config_path = gd_config_path
        self.gd_checkpoint_path = gd_checkpoint_path

        super().__init__(**kwargs)

        # Pull defaults from the model config; per-call kwargs may override
        # them through _sanitize_parameters().
        self.segmentation_method = getattr(self.model.config, 'segmentation_method', 'grounding_dino_sam2')
        self.box_threshold = getattr(self.model.config, 'box_threshold', 0.35)
        self.text_threshold = getattr(self.model.config, 'text_threshold', 0.25)
        self.target_fps = getattr(self.model.config, 'target_fps', 1)
        self.visualize = getattr(self.model.config, 'visualize', False)
        self.visualization_dir = getattr(self.model.config, 'visualization_dir', None)
        self.debug_visualizations = getattr(self.model.config, 'debug_visualizations', False)
        self._device = getattr(self.model.config, '_device', None)
        if kwargs.get("device") is not None:
            self._device = kwargs["device"]

    def set_segmentation_models(
        self,
        *,
        sam_predictor=None,
        mask_generator=None,
        grounding_model=None,
    ):
        """
        Set pre-initialized segmentation models, bypassing automatic
        initialization. Models that are not provided keep their current
        values.

        Args:
            sam_predictor: Pre-built SAM2 video predictor
            mask_generator: Pre-built SAM2 automatic mask generator
            grounding_model: Pre-built GroundingDINO model
        """
        if sam_predictor is not None:
            self.sam_predictor = sam_predictor
        if mask_generator is not None:
            self.mask_generator = mask_generator
        if grounding_model is not None:
            self.grounding_model = grounding_model
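
    # Sketch of sharing already-built models across pipelines (the
    # `shared_*` names are hypothetical):
    #
    #   pipe.set_segmentation_models(
    #       sam_predictor=shared_sam_predictor,
    #       mask_generator=shared_mask_generator,
    #       grounding_model=shared_grounding_dino,
    #   )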

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the preprocess/forward/postprocess stages."""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}

        for key in (
            "segmentation_method",
            "target_fps",
            "box_threshold",
            "text_threshold",
            "categorical_keywords",
        ):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]

        for key in (
            "categorical_keywords",
            "unary_keywords",
            "binary_keywords",
            "object_pairs",
            "return_flattened_segments",
            "return_valid_pairs",
            "interested_object_pairs",
        ):
            if key in kwargs:
                forward_kwargs[key] = kwargs[key]

        if "debug_visualizations" in kwargs:
            forward_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
            postprocess_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]

        if "return_top_k" in kwargs:
            postprocess_kwargs["return_top_k"] = kwargs["return_top_k"]
        if "visualize" in kwargs:
            postprocess_kwargs["visualize"] = kwargs["visualize"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
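
    # Routing example (illustrative): a call such as
    #
    #   pipe("video.mp4", target_fps=2, categorical_keywords=["person"],
    #        return_top_k=5, visualize=True)
    #
    # sends target_fps and categorical_keywords to preprocess(),
    # categorical_keywords (again) to _forward(), and return_top_k and
    # visualize to postprocess().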

    def preprocess(
        self,
        video_input: Union[str, np.ndarray, torch.Tensor],
        segmentation_method: Optional[str] = None,
        target_fps: Optional[int] = None,
        box_threshold: Optional[float] = None,
        text_threshold: Optional[float] = None,
        categorical_keywords: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Preprocess the video input and generate masks.

        Args:
            video_input: Path to a video file, or a video tensor/array
            segmentation_method: "sam2" or "grounding_dino_sam2"
            target_fps: Target FPS for video processing
            box_threshold: Box threshold for Grounding DINO
            text_threshold: Text threshold for Grounding DINO
            categorical_keywords: Keywords for Grounding DINO segmentation

        Returns:
            Dict containing video frames, masks, and bboxes
        """
        # Fall back to config-level defaults for any argument not supplied.
        if segmentation_method is None:
            segmentation_method = self.segmentation_method
        if target_fps is None:
            target_fps = self.target_fps
        if box_threshold is None:
            box_threshold = self.box_threshold
        if text_threshold is None:
            text_threshold = self.text_threshold
        if categorical_keywords is None:
            categorical_keywords = ["object"]

        # Normalize the input to a NumPy array of frames.
        if isinstance(video_input, str):
            video_tensor = load_video(video_input, target_fps=target_fps)
            if isinstance(video_tensor, list):
                video_tensor = np.array(video_tensor)
            elif isinstance(video_tensor, torch.Tensor):
                video_tensor = video_tensor.cpu().numpy()
        elif isinstance(video_input, (np.ndarray, torch.Tensor)):
            if isinstance(video_input, torch.Tensor):
                video_tensor = video_input.cpu().numpy()
            else:
                video_tensor = video_input
        else:
            raise ValueError(f"Unsupported video input type: {type(video_input)}")

        if not isinstance(video_tensor, np.ndarray):
            video_tensor = np.array(video_tensor)

        if video_tensor.ndim != 4:
            raise ValueError(
                f"Expected video tensor shape (frames, height, width, channels), got {video_tensor.shape}"
            )

        visualization_data: Dict[str, Any] = {}
        print(f"Segmentation method: {segmentation_method}")
        if segmentation_method == "sam2":
            masks, bboxes, vis_data = self._generate_sam2_masks(video_tensor)
        elif segmentation_method == "grounding_dino_sam2":
            masks, bboxes, vis_data = self._generate_grounding_dino_sam2_masks(
                video_tensor, categorical_keywords, box_threshold, text_threshold, video_input
            )
        else:
            raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
        if vis_data:
            visualization_data.update(vis_data)
        visualization_data.setdefault("sam_masks", masks)

        return {
            "video_frames": torch.tensor(video_tensor),
            "masks": masks,
            "bboxes": bboxes,
            "num_frames": len(video_tensor),
            "visualization_data": visualization_data,
        }
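
    # Shape conventions for the returned dict (a sketch of how the values
    # are consumed downstream):
    #
    #   out = pipe.preprocess("video.mp4")
    #   out["video_frames"]   # torch.Tensor, (num_frames, H, W, 3)
    #   out["masks"]          # {frame_id: {obj_id: mask tensor}}
    #   out["bboxes"]         # {frame_id: {obj_id: [x1, y1, x2, y2]}}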

    def _generate_sam2_masks(self, video_tensor: np.ndarray) -> Tuple[Dict, Dict, Dict[str, Any]]:
        """Generate masks using SAM2 automatic mask generation."""
        print("Generating SAM2 masks...")
        if self.mask_generator is None:
            self._initialize_segmentation_models()

        if self.mask_generator is None:
            raise ValueError("SAM2 mask generator not available")

        masks: Dict[int, Dict[int, torch.Tensor]] = {}
        bboxes: Dict[int, Dict[int, List[int]]] = {}

        for frame_id, frame in enumerate(video_tensor):
            # SAM2 expects uint8 frames; rescale floats in [0, 1] if needed.
            if isinstance(frame, np.ndarray) and frame.dtype != np.uint8:
                frame = (frame * 255).astype(np.uint8) if frame.max() <= 1 else frame.astype(np.uint8)

            frame_masks = self.mask_generator.generate(frame)

            masks[frame_id] = {}
            bboxes[frame_id] = {}

            for obj_id, mask_data in enumerate(frame_masks):
                mask = mask_data["segmentation"]
                if isinstance(mask, np.ndarray):
                    mask = torch.from_numpy(mask)

                # Normalize the mask layout to (H, W, 1).
                if mask.ndim == 2:
                    mask = mask.unsqueeze(-1)
                elif mask.ndim == 3 and mask.shape[0] == 1:
                    mask = mask.permute(1, 2, 0)

                # Object ids are 1-based downstream.
                wrapped_id = obj_id + 1
                masks[frame_id][wrapped_id] = mask

                mask_np = mask.squeeze().numpy() if isinstance(mask, torch.Tensor) else mask.squeeze()

                # Derive a tight bounding box from the mask pixels.
                coords = np.where(mask_np > 0)
                if len(coords[0]) > 0:
                    y1, y2 = coords[0].min(), coords[0].max()
                    x1, x2 = coords[1].min(), coords[1].max()
                    bboxes[frame_id][wrapped_id] = [x1, y1, x2, y2]

        return masks, bboxes, {"sam_masks": masks}

    def _generate_grounding_dino_sam2_masks(
        self,
        video_tensor: np.ndarray,
        categorical_keywords: List[str],
        box_threshold: float,
        text_threshold: float,
        video_path: Optional[Union[str, np.ndarray, torch.Tensor]],
    ) -> Tuple[Dict, Dict, Dict[str, Any]]:
        """Generate masks using Grounding DINO + SAM2."""
        print("Generating Grounding DINO + SAM2 masks...")
        if self.grounding_model is None or self.sam_predictor is None:
            self._initialize_segmentation_models()

        if self.grounding_model is None or self.sam_predictor is None:
            raise ValueError("GroundingDINO or SAM2 models not available")

        # Mask generation operates on a video file; write the tensor to a
        # temporary clip if the caller did not provide a path.
        temp_video_path = None
        if video_path is None or not isinstance(video_path, str):
            temp_video_path = self._create_temp_video(video_tensor)
            video_path = temp_video_path

        # Query Grounding DINO with at most five keywords per prompt.
        CHUNK = 5
        classes_ls = [categorical_keywords[i:i + CHUNK] for i in range(0, len(categorical_keywords), CHUNK)]
        video_segments, oid_class_pred, _ = generate_masks_grounding_dino(
            self.grounding_model,
            box_threshold,
            text_threshold,
            self.sam_predictor,
            self.mask_generator,
            video_tensor,
            video_path,
            "temp_video",
            out_dir=tempfile.gettempdir(),
            classes_ls=classes_ls,
            target_fps=self.target_fps,
            visualize=self.debug_visualizations,
            frames=None,
            max_prop_time=10,
        )

        masks: Dict[int, Dict[int, torch.Tensor]] = {}
        bboxes: Dict[int, Dict[int, List[int]]] = {}

        for frame_id, frame_masks in video_segments.items():
            masks[frame_id] = {}
            bboxes[frame_id] = {}

            for obj_id, mask in frame_masks.items():
                if not isinstance(mask, torch.Tensor):
                    mask = torch.tensor(mask)
                masks[frame_id][obj_id] = mask
                mask_np = mask.cpu().numpy()
                if mask_np.ndim == 3 and mask_np.shape[0] == 1:
                    mask_np = np.squeeze(mask_np, axis=0)

                # Derive a tight bounding box from the mask pixels.
                coords = np.where(mask_np > 0)
                if len(coords[0]) > 0:
                    y1, y2 = coords[0].min(), coords[0].max()
                    x1, x2 = coords[1].min(), coords[1].max()
                    bboxes[frame_id][obj_id] = [x1, y1, x2, y2]

        # Remove the temporary clip if one was created.
        if temp_video_path and os.path.exists(temp_video_path):
            os.remove(temp_video_path)

        vis_data: Dict[str, Any] = {
            "sam_masks": masks,
            "dino_labels": oid_class_pred,
        }
        return masks, bboxes, vis_data
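
    # Keyword chunking example (illustrative): with
    #   categorical_keywords = ["person", "car", "dog", "cat", "bike", "bus"]
    # the slicing above yields
    #   classes_ls = [["person", "car", "dog", "cat", "bike"], ["bus"]]
    # so each Grounding DINO query stays within five classes.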

    def _initialize_segmentation_models(self):
        """Initialize segmentation models based on the requested method and configured paths."""
        if self.sam_predictor is None or self.mask_generator is None:
            self._initialize_sam2_models()

        if self.grounding_model is None:
            self._initialize_grounding_dino_model()

    def _initialize_sam2_models(self):
        """Initialize the SAM2 video predictor and automatic mask generator."""
        try:
            from sam2.build_sam import build_sam2_video_predictor, build_sam2
            from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        except ImportError as e:
            print(f"Warning: Could not import SAM2: {e}")
            return

        config_path, checkpoint_path = self._resolve_sam2_paths()

        # Fail loudly when explicitly configured paths are wrong.
        if self.sam_config_path is not None and not os.path.exists(config_path):
            raise ValueError(f"SAM2 config path not found: {config_path}")
        if self.sam_checkpoint_path is not None and not os.path.exists(checkpoint_path):
            raise ValueError(f"SAM2 checkpoint path not found: {checkpoint_path}")

        if not os.path.exists(checkpoint_path):
            print(f"Warning: SAM2 checkpoint not found at {checkpoint_path}")
            print("SAM2 functionality will be unavailable")
            return

        try:
            device = self._device

            self.sam_predictor = build_sam2_video_predictor(
                config_path, checkpoint_path, device=device
            )

            sam2_model = build_sam2(config_path, checkpoint_path, device=device, apply_postprocessing=False)
            self.mask_generator = SAM2AutomaticMaskGenerator(
                model=sam2_model,
                points_per_side=32,
                points_per_batch=32,
                pred_iou_thresh=0.7,
                stability_score_thresh=0.8,
                crop_n_layers=2,
                box_nms_thresh=0.6,
                crop_n_points_downscale_factor=2,
                min_mask_region_area=100,
                use_m2m=True,
            )
            print("✓ SAM2 models initialized successfully")

        except Exception as e:
            raise ValueError(f"Failed to initialize SAM2: {e}")

    def _initialize_grounding_dino_model(self):
        """Initialize the GroundingDINO model."""
        try:
            from groundingdino.util.inference import Model as gd_Model
        except ImportError as e:
            print(f"Warning: Could not import GroundingDINO: {e}")
            return

        config_path, checkpoint_path = self._resolve_grounding_dino_paths()

        # Fail loudly when explicitly configured paths are wrong.
        if self.gd_config_path is not None and not os.path.exists(config_path):
            raise ValueError(f"GroundingDINO config path not found: {config_path}")
        if self.gd_checkpoint_path is not None and not os.path.exists(checkpoint_path):
            raise ValueError(f"GroundingDINO checkpoint path not found: {checkpoint_path}")

        if not (os.path.exists(config_path) and os.path.exists(checkpoint_path)):
            print(f"Warning: GroundingDINO models not found at {config_path} / {checkpoint_path}")
            print("GroundingDINO functionality will be unavailable")
            return

        try:
            self.grounding_model = gd_Model(
                model_config_path=config_path,
                model_checkpoint_path=checkpoint_path,
                device=self._device,
            )
            print("✓ GroundingDINO model initialized successfully")

        except Exception as e:
            raise ValueError(f"Failed to initialize GroundingDINO: {e}")

    def _resolve_sam2_paths(self):
        """Resolve the SAM2 config and checkpoint paths."""
        if self.sam_config_path and self.sam_checkpoint_path:
            return self.sam_config_path, self.sam_checkpoint_path
        raise ValueError(
            "SAM2 paths are not configured; pass sam_config_path and "
            "sam_checkpoint_path to the constructor, or inject models via "
            "set_segmentation_models()."
        )

    def _resolve_grounding_dino_paths(self):
        """Resolve the GroundingDINO config and checkpoint paths."""
        if self.gd_config_path and self.gd_checkpoint_path:
            return self.gd_config_path, self.gd_checkpoint_path
        raise ValueError(
            "GroundingDINO paths are not configured; pass gd_config_path and "
            "gd_checkpoint_path to the constructor, or inject models via "
            "set_segmentation_models()."
        )

    def _prepare_visualization_dir(self, name: str, enabled: bool) -> Optional[str]:
        """
        Ensure a directory exists for visualization artifacts and return it.
        Returns None if visualization is disabled.
        """
        if not enabled:
            return None

        if self.visualization_dir:
            target_dir = os.path.join(self.visualization_dir, name) if name else self.visualization_dir
            os.makedirs(target_dir, exist_ok=True)
            return target_dir

        return tempfile.mkdtemp(prefix=f"vine_{name}_")

    def _create_temp_video(self, video_tensor: np.ndarray, base_dir: Optional[str] = None, prefix: str = "temp_video") -> str:
        """Write a video tensor to a temporary .mp4 file and return its path."""
        if base_dir is None:
            base_dir = tempfile.mkdtemp(prefix=f"vine_{prefix}_")
        else:
            os.makedirs(base_dir, exist_ok=True)
        file_name = f"{prefix}_{uuid.uuid4().hex}.mp4"
        temp_path = os.path.join(base_dir, file_name)

        height, width = video_tensor.shape[1:3]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_path, fourcc, self.target_fps, (width, height))

        for frame in video_tensor:
            # OpenCV expects BGR channel order.
            if frame.ndim == 3 and frame.shape[2] == 3:
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            else:
                frame_bgr = frame
            out.write(frame_bgr.astype(np.uint8))

        out.release()
        return temp_path

    def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs) -> Dict[str, Any]:
        """Forward pass through the model."""
        outputs = self.model.predict(
            video_frames=model_inputs["video_frames"],
            masks=model_inputs["masks"],
            bboxes=model_inputs["bboxes"],
            **forward_kwargs,
        )
        # Carry inputs through so postprocess() can render visualizations.
        outputs.setdefault("video_frames", model_inputs.get("video_frames"))
        outputs.setdefault("bboxes", model_inputs.get("bboxes"))
        outputs.setdefault("masks", model_inputs.get("masks"))
        outputs.setdefault("visualization_data", model_inputs.get("visualization_data"))
        return outputs

    def postprocess(
        self,
        model_outputs: Dict[str, Any],
        return_top_k: int = 3,
        visualize: Optional[bool] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Postprocess model outputs into a user-friendly format.

        Args:
            model_outputs: Raw model outputs
            return_top_k: Number of top predictions to return
            visualize: Whether to include visualization data
                (defaults to the config-level setting)

        Returns:
            Formatted results
        """
        if visualize is None:
            visualize = self.visualize

        results = {
            "categorical_predictions": model_outputs.get("categorical_predictions", {}),
            "unary_predictions": model_outputs.get("unary_predictions", {}),
            "binary_predictions": model_outputs.get("binary_predictions", {}),
            "confidence_scores": model_outputs.get("confidence_scores", {}),
            "summary": self._generate_summary(model_outputs),
        }
        if "flattened_segments" in model_outputs:
            results["flattened_segments"] = model_outputs["flattened_segments"]
        if "valid_pairs" in model_outputs:
            results["valid_pairs"] = model_outputs["valid_pairs"]
        if "valid_pairs_metadata" in model_outputs:
            results["valid_pairs_metadata"] = model_outputs["valid_pairs_metadata"]
        if "visualization_data" in model_outputs:
            results["visualization_data"] = model_outputs["visualization_data"]

        if visualize and "video_frames" in model_outputs and "bboxes" in model_outputs:
            frames_tensor = model_outputs["video_frames"]
            if isinstance(frames_tensor, torch.Tensor):
                frames_np = frames_tensor.detach().cpu().numpy()
            else:
                frames_np = np.asarray(frames_tensor)
            # Convert frames to uint8, rescaling floats in [0, 1] if needed.
            if frames_np.dtype != np.uint8:
                if np.issubdtype(frames_np.dtype, np.floating):
                    max_val = frames_np.max() if frames_np.size else 0.0
                    scale = 255.0 if max_val <= 1.0 else 1.0
                    frames_np = (frames_np * scale).clip(0, 255).astype(np.uint8)
                else:
                    frames_np = frames_np.clip(0, 255).astype(np.uint8)

            # Top-1 category per object.
            cat_label_lookup: Dict[int, Tuple[str, float]] = {}
            for obj_id, preds in model_outputs.get("categorical_predictions", {}).items():
                if preds:
                    prob, label = preds[0]
                    cat_label_lookup[obj_id] = (label, prob)

            # Unary predictions grouped by frame and object.
            unary_preds = model_outputs.get("unary_predictions", {})
            unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]] = {}
            for (frame_id, obj_id), preds in unary_preds.items():
                if preds:
                    unary_lookup.setdefault(frame_id, {})[obj_id] = preds

            # Binary predictions grouped by frame.
            binary_preds = model_outputs.get("binary_predictions", {})
            binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]] = {}
            for (frame_id, obj_pair), preds in binary_preds.items():
                if preds:
                    binary_lookup.setdefault(frame_id, []).append((obj_pair, preds))

            bboxes = model_outputs["bboxes"]
            visualization_data = model_outputs.get("visualization_data", {})
            visualizations: Dict[str, Dict[str, Any]] = {}
            debug_visualizations = kwargs.get("debug_visualizations")
            if debug_visualizations is None:
                debug_visualizations = self.debug_visualizations

            vine_frame_sets = render_vine_frame_sets(
                frames_np,
                bboxes,
                cat_label_lookup,
                unary_lookup,
                binary_lookup,
                visualization_data.get("sam_masks"),
            )

            vine_visuals: Dict[str, Dict[str, Any]] = {}
            final_frames = vine_frame_sets.get("all", [])
            if final_frames:
                final_entry: Dict[str, Any] = {"frames": final_frames, "video_path": None}
                final_dir = self._prepare_visualization_dir("all", enabled=visualize)
                final_entry["video_path"] = self._create_temp_video(
                    np.stack(final_frames, axis=0),
                    base_dir=final_dir,
                    prefix="all_visualization",
                )
                vine_visuals["all"] = final_entry

            if debug_visualizations:
                sam_masks = visualization_data.get("sam_masks")
                if sam_masks:
                    sam_frames = render_sam_frames(frames_np, sam_masks, visualization_data.get("dino_labels"))
                    sam_entry = {"frames": sam_frames, "video_path": None}
                    if sam_frames:
                        sam_dir = self._prepare_visualization_dir("sam", enabled=visualize)
                        sam_entry["video_path"] = self._create_temp_video(
                            np.stack(sam_frames, axis=0),
                            base_dir=sam_dir,
                            prefix="sam_visualization",
                        )
                    visualizations["sam"] = sam_entry

                dino_labels = visualization_data.get("dino_labels")
                if dino_labels:
                    dino_frames = render_dino_frames(frames_np, bboxes, dino_labels)
                    dino_entry = {"frames": dino_frames, "video_path": None}
                    if dino_frames:
                        dino_dir = self._prepare_visualization_dir("dino", enabled=visualize)
                        dino_entry["video_path"] = self._create_temp_video(
                            np.stack(dino_frames, axis=0),
                            base_dir=dino_dir,
                            prefix="dino_visualization",
                        )
                    visualizations["dino"] = dino_entry

                for name in ("object", "unary", "binary"):
                    frames_list = vine_frame_sets.get(name, [])
                    entry: Dict[str, Any] = {"frames": frames_list, "video_path": None}
                    if frames_list:
                        vine_dir = self._prepare_visualization_dir(name, enabled=visualize)
                        entry["video_path"] = self._create_temp_video(
                            np.stack(frames_list, axis=0),
                            base_dir=vine_dir,
                            prefix=f"{name}_visualization",
                        )
                    vine_visuals[name] = entry

            if vine_visuals:
                visualizations["vine"] = vine_visuals

            if visualizations:
                results["visualizations"] = visualizations

        return results
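
    # Result layout produced above (illustrative):
    #
    #   results["categorical_predictions"]   # {obj_id: [(prob, label), ...]}
    #   results["unary_predictions"]         # {(frame_id, obj_id): [(prob, label), ...]}
    #   results["binary_predictions"]        # {(frame_id, (id_a, id_b)): [(prob, label), ...]}
    #   results["visualizations"]["vine"]["all"]["video_path"]  # rendered .mp4 when visualize=True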

    def _generate_summary(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a summary of the predictions."""
        categorical_preds = model_outputs.get("categorical_predictions", {})
        unary_preds = model_outputs.get("unary_predictions", {})
        binary_preds = model_outputs.get("binary_predictions", {})

        summary = {
            "num_objects_detected": len(categorical_preds),
            "num_unary_predictions": len(unary_preds),
            "num_binary_predictions": len(binary_preds),
            "top_categories": [],
            "top_actions": [],
            "top_relations": [],
        }

        all_categories = []
        for obj_preds in categorical_preds.values():
            if obj_preds:
                all_categories.extend(obj_preds)

        if all_categories:
            sorted_categories = sorted(all_categories, reverse=True)
            summary["top_categories"] = [(cat, prob) for prob, cat in sorted_categories[:3]]

        all_actions = []
        for action_preds in unary_preds.values():
            if action_preds:
                all_actions.extend(action_preds)

        if all_actions:
            sorted_actions = sorted(all_actions, reverse=True)
            summary["top_actions"] = [(act, prob) for prob, act in sorted_actions[:3]]

        all_relations = []
        for rel_preds in binary_preds.values():
            if rel_preds:
                all_relations.extend(rel_preds)

        if all_relations:
            sorted_relations = sorted(all_relations, reverse=True)
            summary["top_relations"] = [(rel, prob) for prob, rel in sorted_relations[:3]]

        return summary
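
    # Worked example (a sketch): with
    #   categorical_predictions = {1: [(0.9, "person")], 2: [(0.7, "car")]}
    # the summary reports num_objects_detected == 2 and
    #   top_categories == [("person", 0.9), ("car", 0.7)]
    # since (prob, label) tuples sort descending by probability.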