deepdetection / utils /graph.py
akagtag's picture
Implement ZeroGPU Space runtime
eff3d67
from __future__ import annotations
import cv2
import numpy as np
import torch
from torch_geometric.data import Data
def video_to_graph(video_path: str, patch_size: int = 16, max_frames: int = 32) -> Data:
    """Turn a video file into a patch-grid graph.

    Frames are sampled from the video, padded to exactly ``max_frames``
    entries, and cut into a grid of ``patch_size`` x ``patch_size`` patches.
    Each patch becomes one graph node; nodes are linked to themselves and to
    their horizontal/vertical grid neighbours.

    Args:
        video_path: Path handed to the frame extractor.
        patch_size: Side length, in pixels, of each square patch.
        max_frames: Number of frames the clip is sampled/padded to.

    Returns:
        A ``Data`` object with ``x`` (per-patch features), ``x_temporal``
        (per-patch, per-frame mean intensity) and ``edge_index`` (2 x E).

    Raises:
        ValueError: If no frame could be read from the video.
    """
    sampled = _extract_frames(video_path, max_frames=max_frames)
    if len(sampled) == 0:
        raise ValueError("Could not extract frames from video")

    clip = _pad_frames(sampled, max_frames)
    patch_stats, patch_series, grid_rows, grid_cols = _patch_features(clip, patch_size)
    adjacency = _grid_edges(grid_rows, grid_cols)

    x = torch.tensor(patch_stats, dtype=torch.float32)
    x_temporal = torch.tensor(patch_series, dtype=torch.float32)
    edge_index = torch.tensor(adjacency, dtype=torch.long)
    return Data(x=x, x_temporal=x_temporal, edge_index=edge_index)
def _extract_frames(video_path: str, max_frames: int) -> list[np.ndarray]:
    """Read up to ``max_frames`` evenly spaced frames from a video.

    When the container reports a positive frame count, the sampled indices
    are spread evenly between the first and last frame. Otherwise the first
    ``max_frames`` frames are taken. Each kept frame is converted from BGR to
    RGB and resized to 128 x 128.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        The kept frames in playback order. The list is empty when the video
        cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames: list[np.ndarray] = []
    # try/finally guarantees the capture handle is released even if a
    # colour conversion or resize raises part-way through decoding.
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            # Duplicated indices (clip shorter than max_frames) collapse in
            # the set, so such clips simply yield every frame once.
            indices = set(np.linspace(0, total - 1, max_frames, dtype=int).tolist())
        else:
            # Frame count unavailable: fall back to the leading frames.
            indices = set(range(max_frames))

        current = 0
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if current in indices:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb, (128, 128)))
            current += 1
    finally:
        cap.release()
    return frames
def _pad_frames(frames: list[np.ndarray], max_frames: int) -> list[np.ndarray]:
if len(frames) >= max_frames:
return frames[:max_frames]
return frames + [frames[-1]] * (max_frames - len(frames))
def _patch_features(frames: list[np.ndarray], patch_size: int):
stack = np.stack(frames, axis=0).astype(np.float32) / 255.0
frame_count, height, width, _ = stack.shape
rows = height // patch_size
cols = width // patch_size
node_features = []
temporal_features = []
for row in range(rows):
for col in range(cols):
patch = stack[
:,
row * patch_size : (row + 1) * patch_size,
col * patch_size : (col + 1) * patch_size,
:,
]
means = patch.mean(axis=(0, 1, 2))
stds = patch.std(axis=(0, 1, 2))
diff = np.abs(np.diff(patch, axis=0)).mean() if frame_count > 1 else 0.0
node_features.append(
[
float(means[0]),
float(means[1]),
float(means[2]),
float(stds[0]),
float(stds[1]),
float(stds[2]),
float(diff),
float((row * cols + col) / max(rows * cols - 1, 1)),
]
)
temporal = patch.mean(axis=(1, 2, 3))
temporal_features.append(temporal.astype(np.float32))
return np.array(node_features), np.array(temporal_features), rows, cols
def _grid_edges(rows: int, cols: int) -> list[list[int]]:
src = []
dst = []
def nid(row: int, col: int) -> int:
return row * cols + col
for row in range(rows):
for col in range(cols):
current = nid(row, col)
src.append(current)
dst.append(current)
if col + 1 < cols:
right = nid(row, col + 1)
src.extend([current, right])
dst.extend([right, current])
if row + 1 < rows:
down = nid(row + 1, col)
src.extend([current, down])
dst.extend([down, current])
return [src, dst]