import os
import random

import numpy as np
import torch
from omegaconf import DictConfig
from pytorchvideo.data.encoded_video import EncodedVideo

from .base_video_dataset import BaseVideoDataset


ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

def convert_action_space(actions):
    """Convert raw recorded actions to the 25-dim ACTION_KEYS vector.

    The raw column layout is inferred from the indexing below: column 0 encodes
    forward/back, column 1 left/right, columns 3-4 quantized camera pitch/yaw
    (11 = negative step, 13 = positive step), column 5 a key code for "drop",
    and column 7 the hotbar.1 key.
    """
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:, 0] == 1, 11] = 1    # forward
    vec_25[actions[:, 0] == 2, 12] = 1    # back
    vec_25[actions[:, 4] == 11, 16] = -1  # cameraX, negative step
    vec_25[actions[:, 4] == 13, 16] = 1   # cameraX, positive step
    vec_25[actions[:, 3] == 11, 15] = -1  # cameraY, negative step
    vec_25[actions[:, 3] == 13, 15] = 1   # cameraY, positive step
    vec_25[actions[:, 5] == 6, 24] = 1    # drop
    vec_25[actions[:, 5] == 1, 24] = 1    # drop (alternate code)
    vec_25[actions[:, 1] == 1, 13] = 1    # left
    vec_25[actions[:, 1] == 2, 14] = 1    # right
    vec_25[actions[:, 7] == 1, 2] = 1     # hotbar.1
    return vec_25
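# Example (hypothetical raw row, using the inferred layout above): a row with
# col 0 == 1, col 3 == 13, col 4 == 11 and col 7 == 1 yields forward = 1,
# cameraY = +1, cameraX = -1 and hotbar.1 = 1; all other entries stay 0.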

# Dataset class
class MinecraftVideoDataset(BaseVideoDataset):
    """
    Minecraft video dataset for training and validation.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split ("training" or "validation").
    """
    def __init__(self, cfg: DictConfig, split: str = "training"):
        self.wo_updown = getattr(cfg, "wo_updown", False)  # set before super().__init__() since get_data_paths() depends on it
        super().__init__(cfg, split)
        self.n_frames = (
            cfg.n_frames_valid
            if split in ("validation", "test") and hasattr(cfg, "n_frames_valid")
            else cfg.n_frames
        )
        self.memory_condition_length = getattr(cfg, "memory_condition_length", 8)
        self.customized_validation = cfg.customized_validation
        if split == "training":
            self.angle_range = cfg.angle_range
            self.pos_range = cfg.pos_range
        self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
        self.training_dropout = 0.1  # probability of skipping pose-matched memory sampling
        self.sample_more_event = getattr(cfg, "sample_more_event", False)
        self.causal_frame = getattr(cfg, "causal_frame", False)
        
    def get_data_paths(self, split: str):
        """
        Retrieve all video file paths for the given split.

        Args:
            split (str): Dataset split ("training" or "validation").

        Returns:
            List[Path]: List of video file paths.
        """
        data_dir = self.save_dir / split
        paths = sorted(data_dir.glob("**/*.mp4"), key=lambda x: x.name)

        if self.wo_updown:
            # Drop clips that contain up/down camera movement
            paths = [p for p in paths if "w_updown" not in str(p)]
        elif split in ("validation", "test"):
            # Evaluate only on clips with up/down camera movement
            paths = [p for p in paths if "w_updown" in str(p)]

        if not paths:
            # Fall back to scanning subdirectories
            for sub_dir in os.listdir(data_dir):
                sub_path = data_dir / sub_dir
                paths += sorted(sub_path.glob("**/*.mp4"), key=lambda x: x.name)
        return paths
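
    # Expected on-disk layout, inferred from the glob and suffix handling in
    # this class (file names are illustrative):
    #   save_dir/training/**/clip.mp4            video clip
    #   save_dir/training/**/clip.npz            "actions" and "poses" arrays (see load_data)
    #   save_dir/validation/**/clip_w_updown.mp4 / .npz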

    def download_dataset(self):
        pass
    
    def __getitem__(self, idx: int):
        """
        Retrieve a single data sample by index.

        Args:
            idx (int): Index of the data sample.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timestamps.
        """
        max_retries = 1000
        for _ in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"Retrying due to error: {e}")
                idx = (idx + 1) % len(self)
        # Avoid silently returning None once the retry budget is exhausted
        raise RuntimeError(f"Failed to load a sample after {max_retries} retries")

    def load_data(self, idx):
        # === 1. Remap index and skip first few frames ===
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        frame_idx += 100  # skip the first 100 frames, which are low quality

        # === 2. Load paths and data arrays ===
        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")
        data = np.load(action_path)
        actions_pool = convert_action_space(data["actions"])
        poses_pool = data["poses"]

        # Fix possibly corrupted height in the first frame
        poses_pool[0, 1] = poses_pool[1, 1]
        assert poses_pool[:, 1].ptp() < 2, f"Height variation too large: {poses_pool[:, 1].ptp()} - {video_path}"

        # Prepend a zero row if poses lag actions (handles an off-by-one)
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        # === 3. Load video clip ===
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
        fps = 10  # clips are assumed to be encoded at 10 fps
        clip = video_raw.get_clip(
            start_sec=frame_idx / fps,
            end_sec=(frame_idx + self.n_frames) / fps
        )["video"]
        video = clip.permute(1, 2, 3, 0).numpy()  # (C, T, H, W) -> (T, H, W, C)

        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])

        # === 4. Normalize poses relative to current segment ===
        # Both arrays are modified in place; the reference is the segment's first frame
        def normalize_pose(pose, ref_pose):
            pose[:, :3] -= ref_pose[:1, :3]  # positions relative to the first segment frame
            pose[:, -1] = -pose[:, -1]       # negate the last angle column
            pose[:, 3:] %= 360               # wrap angles into [0, 360)
            return pose

        # Normalize the pool first, while `poses` still holds the raw reference frame
        poses_pool = normalize_pose(poses_pool, poses)
        poses = normalize_pose(poses, poses)
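        # Illustrative numbers: position (10, 64, -3) with segment start (8, 64, -5)
        # becomes (2, 0, 2); a last-column angle of 30 is negated to -30, then
        # wrapped to 330 by the modulo.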

        assert len(video) >= self.n_frames, f"{video_path}"

        # === 5. Sample memory frames for training ===
        if self.split == "training" and self.memory_condition_length > 0:
            use_memory = random.uniform(0, 1) > self.training_dropout

            if use_memory:
                # Per-axis distance between each segment pose and each pool pose
                dis = np.abs(poses[:, None] - poses_pool[None, :])
                # Wrap-around angular distance, e.g. |350 - 10| = 340 -> 360 - 340 = 20
                dis[..., 3:][dis[..., 3:] > 180] = 360 - dis[..., 3:][dis[..., 3:] > 180]

                spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3    # X, Y, Z all within range
                angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2  # pitch and yaw both within range
                not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)  # at least one axis differs, i.e. not the identical pose

                # For each pool frame, count how many segment poses it matches
                valid_index = (spatial_match & angular_match & not_exact_match).sum(0)
                valid_index[:100] = 0  # skip unstable early frames

                # With timestamps and causal frames enabled, if a drop event
                # (channel 24) occurred before this segment, exclude future frames
                if self.add_timestamp_embedding and self.causal_frame and (actions_pool[:frame_idx, 24] == 1).sum() > 0:
                    valid_index[frame_idx:] = 0

                # Frames that match at least one segment pose
                mask = valid_index >= 1
                mask[0] = False
                candidate_indices = np.argwhere(mask)

                # Fallback pool: valid_index >= 0 always holds, so this keeps
                # every frame except frame 0
                mask2 = valid_index >= 0
                mask2[0] = False

                count = min(self.memory_condition_length, candidate_indices.shape[0])
                selected = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:, 0]

                if count < self.memory_condition_length:
                    extra = np.argwhere(mask2)
                    extra = extra[np.random.choice(extra.shape[0], self.memory_condition_length - count, replace=True)][:, 0]
                    selected = np.concatenate([selected, extra])

                # Prioritize event-trigger frames (drop, channel 24) 70% of the time
                if self.sample_more_event and random.uniform(0, 1) > 0.3:
                    event_idx = torch.nonzero(actions_pool[:frame_idx, 24] == 1)[:, 0]
                    if len(event_idx) > self.memory_condition_length // 2:
                        event_idx = event_idx[-self.memory_condition_length // 2:]
                    if len(event_idx) > 0:
                        # +4 shifts the memory frame a few steps past the trigger
                        selected[-len(event_idx):] = event_idx + 4

            else:
                # Memory dropout: repeat a single random past frame instead
                selected = np.full(self.memory_condition_length, random.randint(0, frame_idx))

            # === 6. Retrieve video frames for selected memory indices ===
            video_pool = []
            for si in selected:
                # Decode a single frame at index si: (C, T, H, W) -> (H, W, C)
                frame = video_raw.get_clip(start_sec=si / fps, end_sec=(si + 1) / fps)["video"][:, 0].permute(1, 2, 0)
                video_pool.append(frame)

            video = np.concatenate([video, np.stack(video_pool)], axis=0)
            actions = np.concatenate([actions, actions_pool[selected]], axis=0)
            poses = np.concatenate([poses, poses_pool[selected]], axis=0)
            timestamp = np.concatenate([np.arange(frame_idx, frame_idx + self.n_frames), selected])
        else:
            timestamp = np.arange(self.n_frames)

        # === 7. Convert video to torch format ===
        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous() # (T, H, W, C) -> (T, C, H, W)

        # === 8. Return all items ===
        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestamp
        )
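

if __name__ == "__main__":
    # Minimal smoke-test sketch. The cfg fields below mirror the attributes read
    # in __init__; BaseVideoDataset will likely require more (e.g. save_dir,
    # frame_skip), so treat this as illustrative rather than a canonical config.
    cfg = DictConfig(
        {
            "n_frames": 16,
            "n_frames_valid": 16,
            "memory_condition_length": 8,
            "customized_validation": False,
            "angle_range": 30,
            "pos_range": 2.0,
            "add_timestamp_embedding": True,
            "sample_more_event": False,
            "causal_frame": False,
            "wo_updown": False,
        }
    )
    dataset = MinecraftVideoDataset(cfg, split="training")
    video, actions, poses, timestamp = dataset[0]
    print(video.shape, actions.shape, poses.shape, timestamp.shape)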