| | import os |
| | import torch |
| | import cv2 |
| | import numpy as np |
| | from torch.utils.data import Dataset, DataLoader |
| | from typing import Tuple, Dict, Any, List, Optional |
| | from collections import OrderedDict |
| | import time |
| |
|
| | |
| | try: |
| | from wm.dataset.data_config import DatasetConfig, get_config_by_name |
| | except ImportError: |
| | from data_config import DatasetConfig, get_config_by_name |
| |
|
| | |
| |
|
class BaseRoboticsDataset(Dataset):
    """
    Unified Dataset for robotics data. Handles MP4 loading with window sampling and caching.

    Each item is a fixed-length window of ``config.seq_len`` frames from one
    trajectory, together with the matching action slice. Videos are decoded
    with OpenCV on demand and kept in a per-process LRU cache keyed by their
    relative path.
    """
    def __init__(self, config: DatasetConfig) -> None:
        # config must provide: root_dir, name, seq_len, obs_shape (C, H, W),
        # action_dim, cache_size.
        self.config = config
        self.metadata_path = os.path.join(config.root_dir, "metadata.pt")
        self.metadata_lite_path = os.path.join(config.root_dir, "metadata_lite.pt")

        # Prefer the "lite" metadata for startup so we avoid loading the full
        # per-trajectory arrays just to compute window counts; fall back to
        # the full metadata file when the lite one is absent.
        if os.path.exists(self.metadata_lite_path):
            print(f"[{config.name}] Initializing from LITE metadata...")
            self.init_metadata = torch.load(self.metadata_lite_path, weights_only=False)
        else:
            print(f"[{config.name}] Initializing from FULL metadata (lite not found)...")
            self.init_metadata = torch.load(self.metadata_path, weights_only=False)

        # Flat window index: one (traj_idx, start_frame) pair for every
        # position where a full seq_len window fits inside the trajectory.
        # Trajectories shorter than seq_len contribute no windows.
        self.indices = []
        for i, entry in enumerate(self.init_metadata):
            t_len = entry['length'] if 'length' in entry else self._get_traj_len(entry)
            if t_len >= config.seq_len:

                for start_f in range(t_len - config.seq_len + 1):
                    self.indices.append((i, start_f))

        # Release the metadata used for indexing; the full metadata is
        # lazily (mmap-)loaded on first access via the property below.
        self.init_metadata = None
        self._full_metadata = None

        print(f"[{config.name}] Initialized: {len(self.indices)} windows.")
        # LRU cache of decoded videos: rel_path -> (T, 3, H, W) uint8 tensor.
        self.cache = OrderedDict()

    @property
    def full_metadata(self):
        """Lazily load the full metadata file (memory-mapped) on first use."""
        if self._full_metadata is None:
            print(f"[{self.config.name}] Lazy-loading FULL metadata...")
            self._full_metadata = torch.load(self.metadata_path, weights_only=False, mmap=True)
        return self._full_metadata

    def _get_traj_len(self, entry: Dict[str, Any]) -> int:
        """Infer a trajectory's length from whichever field the entry carries.

        Checked in order: 'actions' (first dim), explicit 'length', then
        'commands' (dict form uses 'linear_velocity'). Returns 0 when none
        of the known fields is present.
        """
        if 'actions' in entry:
            return entry['actions'].shape[0]
        if 'length' in entry:
            return entry['length']
        if 'commands' in entry:
            if isinstance(entry['commands'], dict):
                return entry['commands']['linear_velocity'].shape[0]
            return entry['commands'].shape[0]
        return 0

    def _load_video(self, video_rel_path: str) -> torch.Tensor:
        """Decode an MP4 into a (T, 3, H, W) uint8 RGB tensor, with LRU caching.

        H/W are taken from config.obs_shape; frames are resized when needed
        and converted from OpenCV's BGR to RGB. An unreadable video returns
        an empty (0, 3, H, W) tensor and is NOT cached.
        """
        if video_rel_path in self.cache:
            # Cache hit: refresh LRU recency before returning.
            self.cache.move_to_end(video_rel_path)
            return self.cache[video_rel_path]

        video_path = os.path.join(self.config.root_dir, video_rel_path)
        cap = cv2.VideoCapture(video_path)
        frames = []
        target_h, target_w = self.config.obs_shape[1], self.config.obs_shape[2]

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Resize only when the decoded frame deviates from the target;
            # INTER_AREA is the usual choice for downscaling.
            if frame.shape[:2] != (target_h, target_w):
                frame = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_AREA)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        cap.release()

        if not frames:
            print(f"Warning: Could not read any frames from {video_path}")
            return torch.zeros((0, 3, target_h, target_w), dtype=torch.uint8)

        # (T, H, W, 3) -> (T, 3, H, W); keep uint8 so the cache stays small.
        video_tensor = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).contiguous()

        # Evict the least-recently-used entry before inserting when full.
        if len(self.cache) >= self.config.cache_size:
            self.cache.popitem(last=False)
        self.cache[video_rel_path] = video_tensor
        return video_tensor

    def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
        """Extract raw action slice without padding."""
        if self.config.name in ["language_table", "rt1", "dreamer4", "pusht", "franka", "lang_table_50k"]:
            return entry['actions'][start:end]
        elif self.config.name == "recon":
            # recon stores commands as separate velocity channels; stack them
            # into a 2-dim action. NOTE(review): assumes both arrays are 1-D
            # of shape (T,) so the result is (T, 2) — TODO confirm.
            cmds = entry['commands']
            lin = cmds['linear_velocity'][start:end]
            ang = cmds['angular_velocity'][start:end]
            return torch.stack([lin, ang], dim=-1)
        # Unknown dataset name: fall back to zero actions of the configured dim.
        return torch.zeros((end - start, self.config.action_dim))

    def __len__(self):
        # One item per precomputed (trajectory, start-frame) window.
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one window: {'obs': (S, C, H, W) float in [0,1], 'action': (S, A)}.

        Also includes 'task_id' when the metadata entry carries one.
        """
        traj_idx, start_f = self.indices[idx]
        entry = self.full_metadata[traj_idx]

        full_video = self._load_video(entry['video_path'])
        obs_window = full_video[start_f : start_f + self.config.seq_len]

        # The video on disk may be shorter than the metadata claimed; pad up
        # to seq_len so batch shapes stay fixed.
        if obs_window.shape[0] < self.config.seq_len:
            if obs_window.shape[0] == 0:
                # Window starts past the end of the decoded video: repeat the
                # final frame, or fall back to zeros if nothing decoded at all.
                # (The zero tensor is float32, so the uint8 normalization
                # below is skipped; zeros are already in [0, 1].)
                if full_video.shape[0] > 0:
                    last_frame = full_video[-1:]
                    obs_window = last_frame.repeat(self.config.seq_len, 1, 1, 1)
                else:
                    obs_window = torch.zeros((self.config.seq_len, *self.config.obs_shape), dtype=torch.float32)
            else:
                # Partial window: pad by repeating the last available frame.
                last_frame = obs_window[-1:]
                pad_len = self.config.seq_len - obs_window.shape[0]
                padding = last_frame.repeat(pad_len, 1, 1, 1)
                obs_window = torch.cat([obs_window, padding], dim=0)

        action_window = self._get_action_slice(entry, start_f, start_f + self.config.seq_len)

        # Pad actions the same way: repeat the last action, or zeros if the
        # slice came back empty. NOTE(review): repeat(pad_len, 1) assumes a
        # 2-D (T, action_dim) action tensor — TODO confirm for all datasets.
        if action_window.shape[0] < self.config.seq_len:
            pad_len = self.config.seq_len - action_window.shape[0]
            padding = action_window[-1:].repeat(pad_len, 1) if action_window.shape[0] > 0 else torch.zeros((pad_len, self.config.action_dim))
            action_window = torch.cat([action_window, padding], dim=0)

        res = {
            # Normalize uint8 frames to [0, 1]; the float32 zero-fallback
            # above passes through unchanged.
            "obs": obs_window.float() / 255.0 if obs_window.dtype == torch.uint8 else obs_window,
            "action": action_window
        }

        if 'task_id' in entry:
            res['task_id'] = entry['task_id']

        return res
| |
|
class RoboticsDatasetWrapper:
    """
    Convenience factory for building datasets from registered configurations.
    """
    @staticmethod
    def get_dataset(name: str, **kwargs) -> BaseRoboticsDataset:
        """
        Build a BaseRoboticsDataset for the named configuration.

        Keyword arguments are forwarded to the config lookup and override
        the defaults of that configuration.
        """
        return BaseRoboticsDataset(get_config_by_name(name, **kwargs))
| |
|
if __name__ == "__main__":
    # Make the repository root importable when running this file directly,
    # so the `wm.dataset...` absolute imports at the top resolve.
    import sys
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

    print("\n--- Testing Individual Datasets ---")

    for name in ["language_table", "dreamer4"]:
        print(f"\nTesting {name}...")
        try:
            dataset = RoboticsDatasetWrapper.get_dataset(name, seq_len=5, obs_shape=(3, 64, 64))
            loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

            # Time the first 5 batches. perf_counter is a monotonic clock
            # intended for measuring elapsed intervals; time.time() is wall
            # clock and can jump (e.g. NTP adjustments) mid-measurement.
            start_time = time.perf_counter()
            for i, batch in enumerate(loader):
                if i == 0:
                    print(f" Obs Shape: {batch['obs'].shape}")
                    print(f" Action Shape: {batch['action'].shape}")
                if i >= 4:
                    break

            end_time = time.perf_counter()
            print(f" Load time for 5 batches: {end_time - start_time:.2f}s")
        except Exception as e:
            # Best-effort smoke test: print the full traceback (the summary
            # line alone made failures undiagnosable), then continue with
            # the next dataset.
            import traceback
            traceback.print_exc()
            print(f" Failed to test {name}: {e}")
| |
|