Spaces:
Paused
Paused
| from __future__ import annotations | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch_geometric.data import Data | |
def video_to_graph(video_path: str, patch_size: int = 16, max_frames: int = 32) -> Data:
    """Convert a video into a patch-grid graph for a torch_geometric model.

    Up to ``max_frames`` frames are sampled across the video (the last decoded
    frame is repeated if fewer are available), each frame is tiled into square
    ``patch_size`` patches, and every patch position becomes one node. Nodes
    are linked to their right/left/up/down grid neighbours and to themselves.

    Args:
        video_path: Path to the video, passed to ``cv2.VideoCapture``.
        patch_size: Side length in pixels of each square patch.
        max_frames: Number of frames sampled from the video.

    Returns:
        A ``Data`` object with ``x`` (8 summary features per patch),
        ``x_temporal`` (per-patch mean value of each sampled frame) and
        ``edge_index`` (grid connectivity, shape ``(2, num_edges)``).

    Raises:
        ValueError: if ``patch_size`` or ``max_frames`` is not positive, if no
            frame could be decoded, or if ``patch_size`` is larger than the
            decoded frames so that no whole patch fits.
    """
    # Validate before touching the video so bad arguments fail fast instead of
    # after a full decode (or as a ZeroDivisionError deep in patch splitting).
    if patch_size < 1:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")
    if max_frames < 1:
        raise ValueError(f"max_frames must be a positive integer, got {max_frames}")

    frames = _extract_frames(video_path, max_frames=max_frames)
    if not frames:
        raise ValueError("Could not extract frames from video")
    frames = _pad_frames(frames, max_frames)

    node_features, temporal_features, rows, cols = _patch_features(frames, patch_size)
    if rows * cols == 0:
        # An empty grid would produce an x tensor of shape (0,) rather than
        # (0, num_features), which downstream GNN layers cannot consume.
        height, width = frames[0].shape[:2]
        raise ValueError(
            f"patch_size={patch_size} is larger than the decoded frames "
            f"({height}x{width}); no patch fits"
        )

    edge_index = _grid_edges(rows, cols)
    return Data(
        x=torch.tensor(node_features, dtype=torch.float32),
        x_temporal=torch.tensor(temporal_features, dtype=torch.float32),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )
def _extract_frames(
    video_path: str, max_frames: int, frame_size: int = 128
) -> list[np.ndarray]:
    """Decode up to ``max_frames`` frames sampled evenly across a video.

    When the container reports a positive frame count, the sampled indices are
    ``np.linspace(0, total - 1, max_frames)`` cast to int (duplicates collapse
    for videos shorter than ``max_frames``). When the count is not reported
    (``<= 0``), the first ``max_frames`` frames are taken. Each kept frame is
    converted from OpenCV's BGR order to RGB and resized to
    ``frame_size x frame_size``.

    Args:
        video_path: Path to the video, passed to ``cv2.VideoCapture``.
        max_frames: Maximum number of frames to return.
        frame_size: Side length in pixels of the square output frames.

    Returns:
        The decoded RGB frames. The list may be shorter than ``max_frames``:
        it is empty if the capture cannot be opened or ``max_frames <= 0``,
        and short if the video has fewer decodable frames than requested.
        NOTE: if the reported frame count overshoots what can actually be
        decoded, the later sample indices are never reached.
    """
    if max_frames <= 0:
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            wanted = set(np.linspace(0, total - 1, max_frames, dtype=int).tolist())
        else:
            wanted = set(range(max_frames))
        # No sampled frame lies beyond this index, so stop decoding there.
        last_wanted = max(wanted)

        frames: list[np.ndarray] = []
        current = 0
        while cap.isOpened() and len(frames) < max_frames and current <= last_wanted:
            ok, frame = cap.read()
            if not ok:
                break
            if current in wanted:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb, (frame_size, frame_size)))
            current += 1
        return frames
    finally:
        # Release the capture even if decoding or colour conversion raises.
        cap.release()
| def _pad_frames(frames: list[np.ndarray], max_frames: int) -> list[np.ndarray]: | |
| if len(frames) >= max_frames: | |
| return frames[:max_frames] | |
| return frames + [frames[-1]] * (max_frames - len(frames)) | |
| def _patch_features(frames: list[np.ndarray], patch_size: int): | |
| stack = np.stack(frames, axis=0).astype(np.float32) / 255.0 | |
| frame_count, height, width, _ = stack.shape | |
| rows = height // patch_size | |
| cols = width // patch_size | |
| node_features = [] | |
| temporal_features = [] | |
| for row in range(rows): | |
| for col in range(cols): | |
| patch = stack[ | |
| :, | |
| row * patch_size : (row + 1) * patch_size, | |
| col * patch_size : (col + 1) * patch_size, | |
| :, | |
| ] | |
| means = patch.mean(axis=(0, 1, 2)) | |
| stds = patch.std(axis=(0, 1, 2)) | |
| diff = np.abs(np.diff(patch, axis=0)).mean() if frame_count > 1 else 0.0 | |
| node_features.append( | |
| [ | |
| float(means[0]), | |
| float(means[1]), | |
| float(means[2]), | |
| float(stds[0]), | |
| float(stds[1]), | |
| float(stds[2]), | |
| float(diff), | |
| float((row * cols + col) / max(rows * cols - 1, 1)), | |
| ] | |
| ) | |
| temporal = patch.mean(axis=(1, 2, 3)) | |
| temporal_features.append(temporal.astype(np.float32)) | |
| return np.array(node_features), np.array(temporal_features), rows, cols | |
| def _grid_edges(rows: int, cols: int) -> list[list[int]]: | |
| src = [] | |
| dst = [] | |
| def nid(row: int, col: int) -> int: | |
| return row * cols + col | |
| for row in range(rows): | |
| for col in range(cols): | |
| current = nid(row, col) | |
| src.append(current) | |
| dst.append(current) | |
| if col + 1 < cols: | |
| right = nid(row, col + 1) | |
| src.extend([current, right]) | |
| dst.extend([right, current]) | |
| if row + 1 < rows: | |
| down = nid(row + 1, col) | |
| src.extend([current, down]) | |
| dst.extend([down, current]) | |
| return [src, dst] | |