"""Convert a video file into a grid graph of spatio-temporal patches.

Frames are sampled from the video, resized to a square, and cut into a grid of
square patches.  Each patch becomes a node whose features summarise its colour
statistics and frame-to-frame change; edges join every node to itself and to
its right/down grid neighbours in both directions.
"""

from __future__ import annotations

import cv2
import numpy as np
import torch
from torch_geometric.data import Data

# Width of each row of ``x``: RGB means (3) + RGB stds (3)
# + mean absolute frame difference (1) + normalised grid position (1).
_NUM_NODE_FEATURES = 8


def video_to_graph(
    video_path: str,
    patch_size: int = 16,
    max_frames: int = 32,
    frame_size: int = 128,
) -> Data:
    """Build a patch-grid graph from a video.

    Args:
        video_path: Source string passed to ``cv2.VideoCapture``.
        patch_size: Side length, in pixels, of each square patch.
        max_frames: Number of frames sampled from the video; shorter results
            are padded by repeating the last frame.
        frame_size: Side length every frame is resized to before patching
            (defaults to the previously hard-coded 128).

    Returns:
        ``Data`` with ``x`` of shape ``(num_nodes, 8)``, ``x_temporal`` of
        shape ``(num_nodes, max_frames)`` and ``edge_index`` of shape
        ``(2, num_edges)``.  ``num_nodes`` is ``(frame_size // patch_size) ** 2``
        and may be 0 when ``patch_size > frame_size``.

    Raises:
        ValueError: If a size argument is not positive or no frame could be read.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if max_frames < 1:
        raise ValueError(f"max_frames must be positive, got {max_frames}")
    if frame_size < 1:
        raise ValueError(f"frame_size must be positive, got {frame_size}")

    frames = _extract_frames(video_path, max_frames=max_frames, frame_size=frame_size)
    if not frames:
        raise ValueError("Could not extract frames from video")
    frames = _pad_frames(frames, max_frames)
    node_features, temporal_features, rows, cols = _patch_features(frames, patch_size)
    edge_index = _grid_edges(rows, cols)
    return Data(
        x=torch.tensor(node_features, dtype=torch.float32),
        x_temporal=torch.tensor(temporal_features, dtype=torch.float32),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )


def _extract_frames(
    video_path: str, max_frames: int, frame_size: int = 128
) -> list[np.ndarray]:
    """Sample up to ``max_frames`` evenly spaced frames as RGB ``frame_size`` squares.

    When the reported frame count is positive, ``max_frames`` indices are spread
    evenly over it; duplicate indices (count < ``max_frames``) collapse, so fewer
    frames may be returned.  Otherwise the first ``max_frames`` frames are taken.
    Reading stops early if the capture runs out of frames.
    """
    frames: list[np.ndarray] = []
    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            wanted = set(np.linspace(0, total - 1, max_frames, dtype=int).tolist())
        else:
            wanted = set(range(max_frames))
        last_wanted = max(wanted, default=-1)

        current = 0
        # Stop once past the last sampled index instead of scanning to EOF.
        while cap.isOpened() and len(frames) < max_frames and current <= last_wanted:
            # grab() advances without decoding; only sampled frames are decoded.
            if not cap.grab():
                break
            if current in wanted:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb, (frame_size, frame_size)))
            current += 1
    finally:
        # Always free the capture handle, even if a cv2 call raises.
        cap.release()
    return frames


def _pad_frames(frames: list[np.ndarray], max_frames: int) -> list[np.ndarray]:
    """Return exactly ``max_frames`` frames, truncating or repeating the last one.

    Raises:
        ValueError: If padding is needed but ``frames`` is empty.
    """
    if len(frames) >= max_frames:
        return frames[:max_frames]
    if not frames:
        raise ValueError("Cannot pad an empty frame list")
    return frames + [frames[-1]] * (max_frames - len(frames))


def _patch_features(frames: list[np.ndarray], patch_size: int):
    """Compute per-patch node features and per-frame patch means.

    Args:
        frames: Equally sized frames; assumed ``(H, W, 3)`` uint8 RGB as produced
            by ``_extract_frames``.
        patch_size: Side length of each square patch; trailing pixels that do not
            fill a whole patch are ignored.

    Returns:
        ``(node_features, temporal_features, rows, cols)`` where
        ``node_features`` is ``(rows * cols, 8)`` and ``temporal_features`` is
        ``(rows * cols, frame_count)``; both keep their 2-D shape when empty.
    """
    stack = np.stack(frames, axis=0).astype(np.float32) / 255.0
    frame_count, height, width, _ = stack.shape
    rows = height // patch_size
    cols = width // patch_size
    num_nodes = rows * cols
    # Normalises the row-major node index into [0, 1]; guards the 1-node case.
    position_scale = max(num_nodes - 1, 1)

    # Preallocated so an empty grid still yields correctly shaped 2-D arrays.
    node_features = np.empty((num_nodes, _NUM_NODE_FEATURES), dtype=np.float64)
    temporal_features = np.empty((num_nodes, frame_count), dtype=np.float32)

    for row in range(rows):
        for col in range(cols):
            node = row * cols + col
            patch = stack[
                :,
                row * patch_size : (row + 1) * patch_size,
                col * patch_size : (col + 1) * patch_size,
                :,
            ]
            if frame_count > 1:
                diff = float(np.abs(np.diff(patch, axis=0)).mean())
            else:
                diff = 0.0
            node_features[node, 0:3] = patch.mean(axis=(0, 1, 2))
            node_features[node, 3:6] = patch.std(axis=(0, 1, 2))
            node_features[node, 6] = diff
            node_features[node, 7] = node / position_scale
            # Mean intensity of this patch in each frame.
            temporal_features[node] = patch.mean(axis=(1, 2, 3))

    return node_features, temporal_features, rows, cols


def _grid_edges(rows: int, cols: int) -> list[list[int]]:
    """Return ``[src, dst]`` for a ``rows x cols`` grid in row-major node order.

    Every node gets a self-loop, and each right/down neighbour pair is linked in
    both directions.
    """
    src: list[int] = []
    dst: list[int] = []

    def nid(row: int, col: int) -> int:
        return row * cols + col

    for row in range(rows):
        for col in range(cols):
            current = nid(row, col)
            src.append(current)
            dst.append(current)
            if col + 1 < cols:
                right = nid(row, col + 1)
                src.extend([current, right])
                dst.extend([right, current])
            if row + 1 < rows:
                down = nid(row + 1, col)
                src.extend([current, down])
                dst.extend([down, current])
    return [src, dst]