File size: 3,640 Bytes
cf54850
 
eff3d67
cf54850
eff3d67
 
cf54850
 
eff3d67
 
cf54850
 
 
eff3d67
 
 
 
 
 
 
 
cf54850
 
 
eff3d67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf54850
eff3d67
 
 
cf54850
eff3d67
 
cf54850
eff3d67
 
 
 
 
 
 
 
 
 
 
 
 
cf54850
eff3d67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations

from collections import Counter

import cv2
import numpy as np
import torch
from torch_geometric.data import Data


def video_to_graph(video_path: str, patch_size: int = 16, max_frames: int = 32) -> Data:
    """Convert a video file into a patch-grid graph.

    Up to ``max_frames`` frames are sampled from the video (converted to RGB
    and resized to 128x128 by ``_extract_frames``), padded to exactly
    ``max_frames`` frames by repeating the last one, and cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches. Each patch is one
    graph node.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        patch_size: Side length in pixels of each square patch. Must be positive.
        max_frames: Number of frames in each node's temporal series. Must be
            positive.

    Returns:
        A ``Data`` object with:
          * ``x``: float32 node features (per-channel mean and std, mean
            absolute frame-to-frame difference, normalised raster position);
          * ``x_temporal``: float32 per-node, per-frame mean patch intensity;
          * ``edge_index``: int64 COO edges of the 4-connected patch grid in
            both directions, plus one self-loop per node.

    Raises:
        ValueError: if ``patch_size`` or ``max_frames`` is not positive, or if
            no frame could be decoded from ``video_path``.
    """
    # Reject bad sizes before touching the video. Otherwise patch_size == 0
    # surfaces as a ZeroDivisionError deep in feature extraction, and a negative
    # patch_size silently produces an empty graph.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if max_frames <= 0:
        raise ValueError(f"max_frames must be positive, got {max_frames}")

    frames = _extract_frames(video_path, max_frames=max_frames)
    if not frames:
        raise ValueError("Could not extract frames from video")

    # Guarantee a fixed temporal length so every x_temporal row has max_frames entries.
    frames = _pad_frames(frames, max_frames)
    node_features, temporal_features, rows, cols = _patch_features(frames, patch_size)
    edge_index = _grid_edges(rows, cols)

    return Data(
        x=torch.tensor(node_features, dtype=torch.float32),
        x_temporal=torch.tensor(temporal_features, dtype=torch.float32),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )


def _extract_frames(video_path: str, max_frames: int) -> list[np.ndarray]:
    """Decode up to ``max_frames`` evenly spaced frames from ``video_path``.

    Frame positions are chosen with ``np.linspace`` over the frame count that
    OpenCV reports (``CAP_PROP_FRAME_COUNT``). When the video has fewer frames
    than ``max_frames``, some source frames are selected more than once. They
    are emitted once per selection so the output stays uniformly stretched in
    time instead of ending in a run of padding. If the frame count is not
    available (reported as <= 0), the first ``max_frames`` frames are taken.

    Each kept frame is converted from BGR to RGB and resized to 128x128.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        max_frames: Number of sample positions to draw.

    Returns:
        The decoded frames in temporal order. The list may be shorter than
        ``max_frames`` (or empty) if the video cannot be opened or ends earlier
        than its reported frame count. Repeated selections share one array
        object.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            sampled = np.linspace(0, total - 1, max_frames, dtype=int).tolist()
        else:
            sampled = list(range(max_frames))

        # How many output slots each source frame fills. The multiplicity is kept,
        # whereas a set would drop duplicates for short videos.
        repeats = Counter(sampled)
        last_needed = max(repeats, default=-1)

        frames: list[np.ndarray] = []
        current = 0
        # Stop as soon as the last selected frame has been decoded.
        while cap.isOpened() and current <= last_needed:
            ret, frame = cap.read()
            if not ret:
                break
            count = repeats.get(current, 0)
            if count:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                resized = cv2.resize(rgb, (128, 128))
                frames.extend([resized] * count)
            current += 1
        return frames
    finally:
        # Always free the decoder, even if decoding or resizing raised.
        cap.release()


def _pad_frames(frames: list[np.ndarray], max_frames: int) -> list[np.ndarray]:
    if len(frames) >= max_frames:
        return frames[:max_frames]
    return frames + [frames[-1]] * (max_frames - len(frames))


def _patch_features(frames: list[np.ndarray], patch_size: int):
    stack = np.stack(frames, axis=0).astype(np.float32) / 255.0
    frame_count, height, width, _ = stack.shape
    rows = height // patch_size
    cols = width // patch_size

    node_features = []
    temporal_features = []
    for row in range(rows):
        for col in range(cols):
            patch = stack[
                :,
                row * patch_size : (row + 1) * patch_size,
                col * patch_size : (col + 1) * patch_size,
                :,
            ]
            means = patch.mean(axis=(0, 1, 2))
            stds = patch.std(axis=(0, 1, 2))
            diff = np.abs(np.diff(patch, axis=0)).mean() if frame_count > 1 else 0.0
            node_features.append(
                [
                    float(means[0]),
                    float(means[1]),
                    float(means[2]),
                    float(stds[0]),
                    float(stds[1]),
                    float(stds[2]),
                    float(diff),
                    float((row * cols + col) / max(rows * cols - 1, 1)),
                ]
            )

            temporal = patch.mean(axis=(1, 2, 3))
            temporal_features.append(temporal.astype(np.float32))

    return np.array(node_features), np.array(temporal_features), rows, cols


def _grid_edges(rows: int, cols: int) -> list[list[int]]:
    src = []
    dst = []

    def nid(row: int, col: int) -> int:
        return row * cols + col

    for row in range(rows):
        for col in range(cols):
            current = nid(row, col)
            src.append(current)
            dst.append(current)
            if col + 1 < cols:
                right = nid(row, col + 1)
                src.extend([current, right])
                dst.extend([right, current])
            if row + 1 < rows:
                down = nid(row + 1, col)
                src.extend([current, down])
                dst.extend([down, current])

    return [src, dst]