# WorldMem_Repro/datasets/video/minecraft_video_dataset.py
# Last commit 681f346 (BonanDing): "Reproduce Training & Fix distributed eval"
import os
import io
import tarfile
import numpy as np
import torch
from typing import Sequence, Mapping
from omegaconf import DictConfig
from pytorchvideo.data.encoded_video import EncodedVideo
import random
from .base_video_dataset import BaseVideoDataset
# Ordered names of the 25 action channels. convert_action_space allocates one
# output column per entry, so an entry's position here is its column index.
ACTION_KEYS = [
    # 0-1: GUI
    "inventory", "ESC",
    # 2-10: hotbar slots 1..9
    *(f"hotbar.{slot}" for slot in range(1, 10)),
    # 11-14: movement
    "forward", "back", "left", "right",
    # 15-16: camera
    "cameraY", "cameraX",
    # 17-19: movement modifiers
    "jump", "sneak", "sprint",
    # 20-24: interaction
    "swapHands", "attack", "use", "pickItem", "drop",
]
def convert_action_space(actions):
    """
    Expand categorical action codes into a dense per-key action vector.
    Args:
        actions: 2-D array of categorical codes, one row per frame. Columns
            0, 1, 3, 4, 5 and 7 are read; it is indexed as ``actions[:, c]``
            and compared against integer codes.
    Returns:
        torch.Tensor: float tensor of shape ``(len(actions), len(ACTION_KEYS))``
        that is zero except where one of the rules below fires.
    """
    # Each rule reads: where actions[:, src_col] == src_code,
    # write `value` into output column `dst_col` (an index into ACTION_KEYS).
    rules = (
        (0, 1, 11, 1),    # -> "forward"
        (0, 2, 12, 1),    # -> "back"
        (4, 11, 16, -1),  # -> "cameraX", negative direction
        (4, 13, 16, 1),   # -> "cameraX", positive direction
        (3, 11, 15, -1),  # -> "cameraY", negative direction
        (3, 13, 15, 1),   # -> "cameraY", positive direction
        (5, 6, 24, 1),    # -> "drop" column
        (5, 1, 24, 1),    # -> "drop" column
        (1, 1, 13, 1),    # -> "left"
        (1, 2, 14, 1),    # -> "right"
        (7, 1, 2, 1),     # -> "hotbar.1"
    )
    converted = torch.zeros(len(actions), len(ACTION_KEYS))
    for src_col, src_code, dst_col, value in rules:
        row_mask = actions[:, src_col] == src_code
        converted[row_mask, dst_col] = value
    return converted
# Dataset class
class MinecraftVideoDataset(BaseVideoDataset):
    """
    Minecraft video dataset for training and validation.

    Every sample is a clip of ``n_frames`` consecutive frames with the matching
    actions and poses. For the training split (when
    ``memory_condition_length > 0``) an additional ``memory_condition_length``
    "memory" frames from the same recording are appended after the clip.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split ("training", "validation" or "test").
    """

    # Number of leading frames of each recording that are skipped
    # (the initial frames are low quality).
    _WARMUP_FRAMES = 100
    # Frame rate used to convert frame indices into seconds when seeking.
    _FPS = 10
    # Column of the converted action vector that marks "event" frames
    # (index 24 is the last entry of ACTION_KEYS).
    _EVENT_COLUMN = 24

    def __init__(self, cfg: DictConfig, split: str = "training"):
        # Set before super().__init__ -- presumably the base constructor calls
        # get_data_paths(), which reads this flag (confirm in BaseVideoDataset).
        self.wo_updown = getattr(cfg, "wo_updown", False)
        super().__init__(cfg, split)

        # Evaluation splits may use a dedicated clip length. The hasattr guard
        # must apply to BOTH "validation" and "test"; the previous
        # `a or b and c` expression only guarded "test".
        if split in ("validation", "test") and hasattr(cfg, "n_frames_valid"):
            self.n_frames = cfg.n_frames_valid
        else:
            self.n_frames = cfg.n_frames

        self.memory_condition_length = getattr(cfg, "memory_condition_length", 8)
        self.customized_validation = cfg.customized_validation
        if split == "training":
            # Thresholds used when matching memory frames by pose.
            self.angle_range = cfg.angle_range
            self.pos_range = cfg.pos_range
        self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
        # Probability that a training sample gets one random frame repeated
        # as memory instead of pose-matched memory frames.
        self.training_dropout = 0.1
        self.sample_more_event = getattr(cfg, "sample_more_event", False)
        self.causal_frame = getattr(cfg, "causal_frame", False)

    def get_data_paths(self, split: str):
        """
        Retrieve all video file paths for the given split.

        Args:
            split (str): Dataset split ("training", "validation" or "test").
        Returns:
            List[Path]: List of video file paths, sorted by file name.
        """
        data_dir = self.save_dir / split
        paths = sorted(data_dir.glob("**/*.mp4"), key=lambda p: p.name)

        if self.wo_updown:
            # Drop every recording whose path contains "w_updown".
            paths = [p for p in paths if "w_updown" not in str(p)]
        elif split in ("validation", "test"):
            # Evaluation without wo_updown keeps only the "w_updown" recordings.
            paths = [p for p in paths if "w_updown" in str(p)]

        if not paths:
            # Fallback: glob each direct child of data_dir separately.
            # NOTE(review): the "w_updown" filters above are not re-applied
            # here -- confirm that this is intended.
            # Sorted so that every process sees the same file order
            # (os.listdir returns entries in arbitrary order).
            for sub_dir in sorted(os.listdir(data_dir)):
                sub_path = data_dir / sub_dir
                paths += sorted(sub_path.glob("**/*.mp4"), key=lambda p: p.name)
        return paths

    def download_dataset(self):
        """No-op: nothing is downloaded by this dataset."""
        pass

    def __getitem__(self, idx: int):
        """
        Retrieve a single data sample by index.

        If loading fails, the next index (wrapping around) is tried, up to
        1000 attempts in total.

        Args:
            idx (int): Index of the data sample.
        Returns:
            Tuple[torch.Tensor, np.ndarray, np.ndarray, np.ndarray]:
            Video, actions, poses, and timestamps (see ``load_data``).
        Raises:
            RuntimeError: If every attempt failed; chained to the last error.
        """
        max_retries = 1000
        last_error = None
        for _ in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                # Best-effort: a broken sample must not abort the whole run.
                print(f"Retrying due to error: {e}")
                last_error = e
                idx = (idx + 1) % len(self)
        # Previously the loop fell through and returned None silently.
        raise RuntimeError(
            f"Failed to load a sample after {max_retries} attempts"
        ) from last_error

    @staticmethod
    def _normalize_pose(pose, ref_pose):
        """Shift positions relative to ``ref_pose[0]`` and wrap angles into [0, 360); modifies ``pose`` in place."""
        pose[:, :3] -= ref_pose[:1, :3]
        pose[:, -1] = -pose[:, -1]  # flip the sign of the last pose column
        pose[:, 3:] %= 360
        return pose

    def _sample_memory_indices(self, poses, poses_pool, actions_pool, frame_idx):
        """Pick ``memory_condition_length`` frame indices of the recording to use as memory frames."""
        if random.uniform(0, 1) <= self.training_dropout:
            # Memory dropout: one random frame index, repeated.
            return np.full(self.memory_condition_length, random.randint(0, frame_idx))

        # Pose distance between every clip frame and every frame of the
        # recording: shape (clip frames, pool frames, pose dims).
        dis = np.abs(poses[:, None] - poses_pool[None, :])
        pos_dis = dis[..., :3]
        angle_dis = dis[..., 3:]  # view into `dis`
        wrapped = angle_dis > 180
        angle_dis[wrapped] = 360 - angle_dis[wrapped]  # shortest angular distance

        spatial_match = (pos_dis <= self.pos_range).sum(-1) >= 3  # X, Y, Z axis all within range
        angular_match = (angle_dis <= self.angle_range).sum(-1) >= 2  # Pitch, Yaw all within range
        # The pose must differ from the clip frame in at least one component.
        not_exact_match = (pos_dis > 0).any(-1) | (angle_dis > 0).any(-1)

        # For each pool frame: number of clip frames it matches.
        valid_index = (spatial_match & angular_match & not_exact_match).sum(0)
        valid_index[: self._WARMUP_FRAMES] = 0  # skip unstable early frames

        # Exclude future frames if causality and timestamps are enabled and an
        # event happened before the clip starts.
        has_past_event = (actions_pool[:frame_idx, self._EVENT_COLUMN] == 1).sum() > 0
        if self.add_timestamp_embedding and self.causal_frame and has_past_event:
            valid_index[frame_idx:] = 0

        # Primary candidates: frames matching at least one clip frame.
        primary_mask = valid_index >= 1
        primary_mask[0] = False
        candidates = np.flatnonzero(primary_mask)

        # Backup candidates.
        # NOTE(review): valid_index is a count, so `>= 0` holds for every
        # frame; the backup pool is effectively "any frame except index 0".
        backup_mask = valid_index >= 0
        backup_mask[0] = False

        count = min(self.memory_condition_length, candidates.shape[0])
        selected = candidates[np.random.choice(candidates.shape[0], count, replace=True)]
        if count < self.memory_condition_length:
            backup = np.flatnonzero(backup_mask)
            missing = self.memory_condition_length - count
            extra = backup[np.random.choice(backup.shape[0], missing, replace=True)]
            selected = np.concatenate([selected, extra])

        # Prioritize event-trigger frames if applicable.
        if self.sample_more_event and random.uniform(0, 1) > 0.3:
            event_idx = torch.nonzero(actions_pool[:frame_idx, self._EVENT_COLUMN] == 1)[:, 0]
            # Keep at most half of the memory slots for events. The former
            # `-n // 2` parsed as `(-n) // 2`, which kept one event too many
            # for odd n and produced a shape error for n == 1.
            max_events = self.memory_condition_length // 2
            if len(event_idx) > max_events:
                event_idx = event_idx[len(event_idx) - max_events:]
            if len(event_idx) > 0:
                # Use the frame 4 steps after each event
                # (presumably so its effect is visible -- TODO confirm).
                selected[-len(event_idx):] = (event_idx + 4).numpy()
        return selected

    def load_data(self, idx):
        """
        Load one sample: a video clip with its actions, poses and timestamps.

        Args:
            idx (int): Sample index before remapping through ``self.idx_remap``.
        Returns:
            Tuple[torch.Tensor, np.ndarray, np.ndarray, np.ndarray]:
            ``video`` as float (T, C, H, W) scaled by 1/255, ``actions``,
            ``poses`` and ``timestamp``. The first three are sub-sampled with
            ``self.frame_skip``.
            NOTE(review): ``timestamp`` is returned without frame skipping, so
            its length differs from the others when ``frame_skip > 1``.
        Raises:
            AssertionError: If the height (pose column 1) varies by 2 or more
                over the recording, or the decoded clip is too short.
        """
        # === 1. Remap index and skip first few frames ===
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        frame_idx += self._WARMUP_FRAMES  # initial few frames are low quality

        # === 2. Load paths and data arrays ===
        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")
        # Context manager closes the archive's file handle once both arrays
        # have been read into memory.
        with np.load(action_path) as data:
            actions_pool = convert_action_space(data["actions"])
            poses_pool = data["poses"]

        # Fix corrupted height (maybe) in the first frame
        poses_pool[0, 1] = poses_pool[1, 1]
        # np.ptp instead of ndarray.ptp(): the method was removed in NumPy 2.0.
        height_range = np.ptp(poses_pool[:, 1])
        assert height_range < 2, f"Height variation too large: {height_range} - {video_path}"

        # Pad poses (one zero row at the front) if shorter than actions
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        # === 3. Load video clip ===
        fps = self._FPS
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
        try:
            clip = video_raw.get_clip(
                start_sec=frame_idx / fps,
                end_sec=(frame_idx + self.n_frames) / fps,
            )["video"]
            video = clip.permute(1, 2, 3, 0).numpy()

            actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
            poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])

            # === 4. Normalize poses relative to current segment ===
            # Order matters: the pool is normalized against the still
            # un-normalized clip poses before the clip poses themselves.
            poses_pool = self._normalize_pose(poses_pool, poses)
            poses = self._normalize_pose(poses, poses)

            assert len(video) >= self.n_frames, f"{video_path}"

            # === 5. Sample memory frames for training ===
            if self.split == "training" and self.memory_condition_length > 0:
                selected = self._sample_memory_indices(poses, poses_pool, actions_pool, frame_idx)

                # === 6. Retrieve video frames for selected memory indices ===
                video_pool = [
                    video_raw.get_clip(start_sec=si / fps, end_sec=(si + 1) / fps)["video"][:, 0].permute(1, 2, 0)
                    for si in selected
                ]

                video = np.concatenate([video, np.stack(video_pool)], axis=0)
                actions = np.concatenate([actions, actions_pool[selected]], axis=0)
                poses = np.concatenate([poses, poses_pool[selected]], axis=0)
                timestamp = np.concatenate([np.arange(frame_idx, frame_idx + self.n_frames), selected])
            else:
                timestamp = np.arange(self.n_frames)
        finally:
            # Release the decoder/container; it was previously leaked.
            video_raw.close()

        # === 7. Convert video to torch format ===
        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()  # (T, H, W, C) -> (T, C, H, W)

        # === 8. Return all items ===
        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestamp,
        )