import os
import time
from collections import OrderedDict
from typing import Tuple, Dict, Any, List, Optional

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Import configurations from data_config.py.
# Falls back to a flat import when the package layout is not on sys.path
# (e.g. when this file is executed directly from its own directory).
try:
    from wm.dataset.data_config import DatasetConfig, get_config_by_name
except ImportError:
    from data_config import DatasetConfig, get_config_by_name


# --- Dataset Implementation ---

class BaseRoboticsDataset(Dataset):
    """Unified Dataset for robotics data.

    Handles MP4 loading with fixed-length window sampling and an LRU cache of
    decoded videos.  Each item is a dict with:

      * ``"obs"``:    float tensor of shape (seq_len, C, H, W) scaled to [0, 1]
      * ``"action"``: tensor of shape (seq_len, action_dim)
      * ``"task_id"`` (optional): passed through from the metadata entry
    """

    def __init__(self, config: DatasetConfig):
        self.config = config
        self.metadata_path = os.path.join(config.root_dir, "metadata.pt")
        self.metadata_lite_path = os.path.join(config.root_dir, "metadata_lite.pt")

        # Load lite metadata for initialization if it exists, otherwise use full.
        # Only per-trajectory lengths are needed here, so the lite file is cheaper.
        if os.path.exists(self.metadata_lite_path):
            print(f"[{config.name}] Initializing from LITE metadata...")
            self.init_metadata = torch.load(self.metadata_lite_path, weights_only=False)
        else:
            print(f"[{config.name}] Initializing from FULL metadata (lite not found)...")
            self.init_metadata = torch.load(self.metadata_path, weights_only=False)

        # Build a flat index of every valid (trajectory, start_frame) window.
        self.indices = []
        for i, entry in enumerate(self.init_metadata):
            t_len = entry['length'] if 'length' in entry else self._get_traj_len(entry)
            if t_len >= config.seq_len:
                # Add all valid starting positions for a window of seq_len frames.
                for start_f in range(t_len - config.seq_len + 1):
                    self.indices.append((i, start_f))

        # Free up init_metadata memory; full metadata is lazy-loaded on first
        # __getitem__ (see the full_metadata property below).
        self.init_metadata = None
        self._full_metadata = None
        print(f"[{config.name}] Initialized: {len(self.indices)} windows.")

        # LRU cache of decoded videos keyed by relative path.
        # NOTE(review): with DataLoader num_workers > 0 each worker process
        # holds its own independent copy of this cache.
        self.cache = OrderedDict()

    @property
    def full_metadata(self):
        """Lazily load the full metadata (memory-mapped) on first access."""
        if self._full_metadata is None:
            print(f"[{self.config.name}] Lazy-loading FULL metadata...")
            self._full_metadata = torch.load(self.metadata_path, weights_only=False, mmap=True)
        return self._full_metadata

    def _get_traj_len(self, entry: Dict[str, Any]) -> int:
        """Best-effort trajectory length from a metadata entry; 0 if unknown."""
        if 'actions' in entry:
            return entry['actions'].shape[0]
        if 'length' in entry:
            return entry['length']
        if 'commands' in entry:
            # RECON-style entries store commands as a dict of per-axis tensors.
            if isinstance(entry['commands'], dict):
                return entry['commands']['linear_velocity'].shape[0]
            return entry['commands'].shape[0]
        return 0

    def _load_video(self, video_rel_path: str) -> torch.Tensor:
        """Decode a whole video into a (T, C, H, W) uint8 RGB tensor.

        Results are kept in an LRU cache of size config.cache_size.  On decode
        failure an empty (0, 3, H, W) tensor is returned and NOT cached, so a
        later retry is possible.
        """
        if video_rel_path in self.cache:
            self.cache.move_to_end(video_rel_path)  # mark as most recently used
            return self.cache[video_rel_path]

        video_path = os.path.join(self.config.root_dir, video_rel_path)
        cap = cv2.VideoCapture(video_path)
        frames = []
        target_h, target_w = self.config.obs_shape[1], self.config.obs_shape[2]
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame.shape[:2] != (target_h, target_w):
                    frame = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_AREA)
                # OpenCV decodes BGR; convert to RGB for the model.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            # Release the capture even if resize/cvtColor raises mid-decode.
            cap.release()

        if not frames:
            print(f"Warning: Could not read any frames from {video_path}")
            return torch.zeros((0, 3, target_h, target_w), dtype=torch.uint8)

        # (T, H, W, C) -> (T, C, H, W)
        video_tensor = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).contiguous()

        # Evict the least-recently-used entry before inserting the new one.
        if len(self.cache) >= self.config.cache_size:
            self.cache.popitem(last=False)
        self.cache[video_rel_path] = video_tensor
        return video_tensor

    def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
        """Extract raw action slice [start, end) without padding."""
        if self.config.name in ["language_table", "rt1", "dreamer4", "pusht", "franka", "lang_table_50k"]:
            return entry['actions'][start:end]
        elif self.config.name == "recon":
            # RECON commands are linear_velocity and angular_velocity
            cmds = entry['commands']
            lin = cmds['linear_velocity'][start:end]
            ang = cmds['angular_velocity'][start:end]
            return torch.stack([lin, ang], dim=-1)
        # Unknown dataset: fall back to a zero action window of the configured width.
        return torch.zeros((end - start, self.config.action_dim))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one seq_len window of observations and actions.

        Pads with the last available frame/action (or zeros) when the video
        turns out to be shorter than the metadata claimed.
        """
        traj_idx, start_f = self.indices[idx]
        entry = self.full_metadata[traj_idx]

        full_video = self._load_video(entry['video_path'])
        obs_window = full_video[start_f : start_f + self.config.seq_len]

        # Handle cases where video is shorter than metadata claims
        if obs_window.shape[0] < self.config.seq_len:
            if obs_window.shape[0] == 0:
                # Video is completely empty or start_f is out of bounds.
                # If full_video has some frames, use the last one. Otherwise use zeros.
                if full_video.shape[0] > 0:
                    last_frame = full_video[-1:]  # (1, C, H, W)
                    obs_window = last_frame.repeat(self.config.seq_len, 1, 1, 1)
                else:
                    # Zero fallback is already float, so normalization below is skipped.
                    obs_window = torch.zeros((self.config.seq_len, *self.config.obs_shape), dtype=torch.float32)
            else:
                # Pad with the last available frame in the window
                last_frame = obs_window[-1:]
                pad_len = self.config.seq_len - obs_window.shape[0]
                padding = last_frame.repeat(pad_len, 1, 1, 1)
                obs_window = torch.cat([obs_window, padding], dim=0)

        action_window = self._get_action_slice(entry, start_f, start_f + self.config.seq_len)
        # Ensure action_window is also correct length (should be, but just in case)
        if action_window.shape[0] < self.config.seq_len:
            pad_len = self.config.seq_len - action_window.shape[0]
            if action_window.shape[0] > 0:
                padding = action_window[-1:].repeat(pad_len, 1)
            else:
                padding = torch.zeros((pad_len, self.config.action_dim))
            action_window = torch.cat([action_window, padding], dim=0)

        res = {
            # uint8 decoded frames are normalized to [0, 1]; float fallbacks pass through.
            "obs": obs_window.float() / 255.0 if obs_window.dtype == torch.uint8 else obs_window,
            "action": action_window,
        }
        if 'task_id' in entry:
            res['task_id'] = entry['task_id']
        return res


class RoboticsDatasetWrapper:
    """Helper to instantiate datasets by name using pre-defined configs."""

    @staticmethod
    def get_dataset(name: str, **kwargs) -> BaseRoboticsDataset:
        """
        Instantiates a BaseRoboticsDataset by looking up the configuration by name.
        kwargs can be used to override default configuration parameters.
        """
        config = get_config_by_name(name, **kwargs)
        return BaseRoboticsDataset(config)


if __name__ == "__main__":
    # To run this script directly, we need to handle the relative import
    # This block allows running `python wm/dataset/dataset.py` from the project root
    import sys
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

    print("\n--- Testing Individual Datasets ---")
    for name in ["language_table", "dreamer4"]:
        print(f"\nTesting {name}...")
        try:
            dataset = RoboticsDatasetWrapper.get_dataset(name, seq_len=5, obs_shape=(3, 64, 64))
            loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
            start_time = time.time()
            for i, batch in enumerate(loader):
                if i == 0:
                    print(f" Obs Shape: {batch['obs'].shape}")
                    print(f" Action Shape: {batch['action'].shape}")
                if i >= 4:
                    break
            end_time = time.time()
            print(f" Load time for 5 batches: {end_time - start_time:.2f}s")
        except Exception as e:
            print(f" Failed to test {name}: {e}")