File size: 1,146 Bytes
f3d0a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# utils/video_utils.py
import cv2
import numpy as np
import imageio
import torch

def load_video(path: str, max_frames: int = 81) -> np.ndarray:
    """
    Read the leading frames of a video file with OpenCV.

    Args:
        path: Location of the video, passed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of frames to decode; reading
            stops earlier if the stream ends first.

    Returns: [T, H, W, 3] uint8 RGB array

    Raises:
        ValueError: If not a single frame could be read (missing or
            unreadable file, empty stream, or ``max_frames <= 0``).
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream, or the capture could not be opened at all.
                break
            # OpenCV decodes to BGR; callers expect RGB channel order.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if decoding or conversion raised.
        cap.release()
    if not frames:
        # np.stack([]) would also raise ValueError, but with a message that
        # does not mention the offending path.
        raise ValueError(
            f"No frames could be read from {path!r} (max_frames={max_frames})"
        )
    return np.stack(frames)

def save_video(frames: np.ndarray, path: str, fps: int = 24):
    """
    Write a sequence of frames to a video file through imageio.

    Args:
        frames: [T, H, W, 3] uint8 RGB
        path: Output file location, passed to ``imageio.get_writer``.
        fps: Frame rate passed to the writer.
    """
    writer = imageio.get_writer(path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always close the writer so the file handle / encoder is released
        # even when appending a frame fails part-way through.
        writer.close()

def frames_to_tensor(frames: np.ndarray) -> torch.Tensor:
    """
    Convert channel-last uint8 frames to a channel-first float tensor.

    [T, H, W, 3] uint8 → [T, 3, H, W] float32 in [-1, 1]

    Args:
        frames: Frame array with the channel axis last. Views with negative
            strides (e.g. ``frames[..., ::-1]``) are accepted.

    Returns:
        float32 tensor with the channel axis moved to position 1; the value
        0 maps to -1.0 and 255 maps to 1.0.
    """
    # torch.from_numpy rejects negatively-strided arrays, which flipped views
    # produce; ascontiguousarray copies only when the layout requires it.
    contiguous = np.ascontiguousarray(frames)
    t = torch.from_numpy(contiguous).float() / 127.5 - 1.0
    return t.permute(0, 3, 1, 2)

def tensor_to_frames(t: torch.Tensor) -> np.ndarray:
    """
    Convert a channel-first float tensor back to channel-last uint8 frames.

    [T, 3, H, W] float32 in [-1, 1] → [T, H, W, 3] uint8

    Args:
        t: Frame tensor with the channel axis at position 1. It may require
            grad or live on a non-CPU device; values outside [-1, 1] are
            clamped to the valid pixel range.

    Returns:
        uint8 NumPy array with the channel axis last.
    """
    # detach() drops autograd tracking, which .numpy() refuses to convert;
    # float() keeps the arithmetic in float32 for reduced-precision inputs.
    scaled = (t.detach().float() + 1.0) * 127.5
    # Round to nearest before the integer cast: .byte() truncates, so float
    # error such as 0.99999994 would otherwise become 0 instead of 1 and the
    # round trip from uint8 frames would lose a level.
    pixels = scaled.round().clamp(0, 255)
    # Cast to uint8 before the host copy so only one byte per value is moved.
    return pixels.permute(0, 2, 3, 1).byte().cpu().numpy()