# world_model/wm/dataset/dataset.py
# (Uploaded via huggingface_hub by t1an; commit ad9aba4, verified.)
import os
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Any, List, Optional
from collections import OrderedDict
import time
# Import configurations from data_config.py
try:
from wm.dataset.data_config import DatasetConfig, get_config_by_name
except ImportError:
from data_config import DatasetConfig, get_config_by_name
# --- Dataset Implementation ---
class BaseRoboticsDataset(Dataset):
    """
    Unified Dataset for robotics data. Handles MP4 loading with window sampling and caching.

    Each sample is a fixed-length window of ``config.seq_len`` frames from a
    single trajectory, returned as a dict with:
      - "obs":     float tensor (seq_len, C, H, W), scaled to [0, 1]
      - "action":  tensor (seq_len, action_dim)
      - "task_id": passed through from the metadata entry when present
    Decoded videos are held in a bounded LRU cache keyed by relative path.
    """

    def __init__(self, config: "DatasetConfig"):
        """
        Args:
            config: dataset configuration (root_dir, name, seq_len, obs_shape,
                action_dim, cache_size). See data_config.py.
        """
        self.config = config
        self.metadata_path = os.path.join(config.root_dir, "metadata.pt")
        self.metadata_lite_path = os.path.join(config.root_dir, "metadata_lite.pt")
        # Load lite metadata for initialization if it exists, otherwise use full.
        if os.path.exists(self.metadata_lite_path):
            print(f"[{config.name}] Initializing from LITE metadata...")
            self.init_metadata = torch.load(self.metadata_lite_path, weights_only=False)
        else:
            print(f"[{config.name}] Initializing from FULL metadata (lite not found)...")
            self.init_metadata = torch.load(self.metadata_path, weights_only=False)
        # Precompute every valid (trajectory_index, start_frame) window so
        # __getitem__ is a cheap O(1) lookup.
        self.indices = []
        for i, entry in enumerate(self.init_metadata):
            t_len = entry['length'] if 'length' in entry else self._get_traj_len(entry)
            if t_len >= config.seq_len:
                # Add all valid starting positions.
                for start_f in range(t_len - config.seq_len + 1):
                    self.indices.append((i, start_f))
        # Free up init_metadata memory; the full metadata is lazy-loaded via
        # the `full_metadata` property on first use.
        self.init_metadata = None
        self._full_metadata = None
        print(f"[{config.name}] Initialized: {len(self.indices)} windows.")
        # LRU cache: video relative path -> uint8 tensor (T, C, H, W).
        self.cache = OrderedDict()

    @property
    def full_metadata(self):
        """Full metadata, memory-mapped and loaded lazily on first access."""
        if self._full_metadata is None:
            print(f"[{self.config.name}] Lazy-loading FULL metadata...")
            self._full_metadata = torch.load(self.metadata_path, weights_only=False, mmap=True)
        return self._full_metadata

    def _get_traj_len(self, entry: Dict[str, Any]) -> int:
        """Best-effort trajectory length from whichever field the entry carries.

        Returns 0 when no length-bearing field is present.
        """
        if 'actions' in entry:
            return entry['actions'].shape[0]
        if 'length' in entry:
            return entry['length']
        if 'commands' in entry:
            # RECON-style entries store commands as a dict of per-axis arrays.
            if isinstance(entry['commands'], dict):
                return entry['commands']['linear_velocity'].shape[0]
            return entry['commands'].shape[0]
        return 0

    def _load_video(self, video_rel_path: str) -> torch.Tensor:
        """
        Decode an MP4 into a uint8 tensor of shape (T, C, H, W), resizing
        frames to config.obs_shape when needed. Results are LRU-cached by
        relative path. Returns an empty (0, 3, H, W) tensor when the file
        yields no frames.
        """
        if video_rel_path in self.cache:
            self.cache.move_to_end(video_rel_path)
            return self.cache[video_rel_path]
        video_path = os.path.join(self.config.root_dir, video_rel_path)
        cap = cv2.VideoCapture(video_path)
        frames = []
        target_h, target_w = self.config.obs_shape[1], self.config.obs_shape[2]
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame.shape[:2] != (target_h, target_w):
                frame = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_AREA)
            # OpenCV decodes BGR; convert to RGB for downstream consumers.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        cap.release()
        if not frames:
            print(f"Warning: Could not read any frames from {video_path}")
            return torch.zeros((0, 3, target_h, target_w), dtype=torch.uint8)
        # (T, H, W, C) -> (T, C, H, W)
        video_tensor = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).contiguous()
        # BUGFIX: only touch the cache when caching is enabled. The previous
        # unconditional `popitem(last=False)` raised KeyError on an empty
        # OrderedDict when cache_size == 0.
        if self.config.cache_size > 0:
            while self.cache and len(self.cache) >= self.config.cache_size:
                self.cache.popitem(last=False)  # evict least-recently-used
            self.cache[video_rel_path] = video_tensor
        return video_tensor

    def _get_action_slice(self, entry: Dict[str, Any], start: int, end: int) -> torch.Tensor:
        """Extract raw action slice [start, end) without padding.

        Falls back to a zero tensor of shape (end - start, action_dim) for
        datasets with no recognized action format.
        """
        if self.config.name in ["language_table", "rt1", "dreamer4", "pusht", "franka", "lang_table_50k"]:
            return entry['actions'][start:end]
        elif self.config.name == "recon":
            # RECON commands are linear_velocity and angular_velocity
            cmds = entry['commands']
            lin = cmds['linear_velocity'][start:end]
            ang = cmds['angular_velocity'][start:end]
            return torch.stack([lin, ang], dim=-1)
        return torch.zeros((end - start, self.config.action_dim))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the window at `idx` as {"obs", "action"[, "task_id"]}.

        Windows shorter than seq_len (metadata claiming more frames than the
        video actually has) are padded by repeating the last available frame;
        a fully unreadable video yields zeros.
        """
        traj_idx, start_f = self.indices[idx]
        entry = self.full_metadata[traj_idx]
        full_video = self._load_video(entry['video_path'])
        obs_window = full_video[start_f : start_f + self.config.seq_len]
        # Handle cases where video is shorter than metadata claims.
        if obs_window.shape[0] < self.config.seq_len:
            if obs_window.shape[0] == 0:
                # Video is completely empty or start_f is out of bounds.
                # If full_video has some frames, use the last one. Otherwise use zeros.
                if full_video.shape[0] > 0:
                    last_frame = full_video[-1:]  # (1, C, H, W)
                    obs_window = last_frame.repeat(self.config.seq_len, 1, 1, 1)
                else:
                    obs_window = torch.zeros((self.config.seq_len, *self.config.obs_shape), dtype=torch.float32)
            else:
                # Pad with the last available frame in the window.
                last_frame = obs_window[-1:]
                pad_len = self.config.seq_len - obs_window.shape[0]
                padding = last_frame.repeat(pad_len, 1, 1, 1)
                obs_window = torch.cat([obs_window, padding], dim=0)
        action_window = self._get_action_slice(entry, start_f, start_f + self.config.seq_len)
        # Ensure action_window is also correct length (should be, but just in case).
        if action_window.shape[0] < self.config.seq_len:
            pad_len = self.config.seq_len - action_window.shape[0]
            padding = action_window[-1:].repeat(pad_len, 1) if action_window.shape[0] > 0 else torch.zeros((pad_len, self.config.action_dim))
            action_window = torch.cat([action_window, padding], dim=0)
        res = {
            # uint8 frames are normalized to [0, 1]; the zero-fallback path is
            # already float32 and passes through unchanged.
            "obs": obs_window.float() / 255.0 if obs_window.dtype == torch.uint8 else obs_window,
            "action": action_window
        }
        if 'task_id' in entry:
            res['task_id'] = entry['task_id']
        return res
class RoboticsDatasetWrapper:
    """
    Helper to instantiate datasets by name using pre-defined configs.
    """

    @staticmethod
    def get_dataset(name: str, **kwargs) -> BaseRoboticsDataset:
        """
        Look up the named configuration (with any keyword overrides applied)
        and construct a BaseRoboticsDataset from it.
        """
        return BaseRoboticsDataset(get_config_by_name(name, **kwargs))
if __name__ == "__main__":
    # Running this file directly (`python wm/dataset/dataset.py`) needs the
    # project root on sys.path so the package imports above resolve.
    import sys

    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    sys.path.append(project_root)

    print("\n--- Testing Individual Datasets ---")
    for name in ["language_table", "dreamer4"]:
        print(f"\nTesting {name}...")
        try:
            dataset = RoboticsDatasetWrapper.get_dataset(
                name, seq_len=5, obs_shape=(3, 64, 64)
            )
            loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
            start_time = time.time()
            # Pull five batches; print shapes from the first one only.
            for i, batch in enumerate(loader):
                if i == 0:
                    print(f" Obs Shape: {batch['obs'].shape}")
                    print(f" Action Shape: {batch['action'].shape}")
                if i >= 4:
                    break
            end_time = time.time()
            print(f" Load time for 5 batches: {end_time - start_time:.2f}s")
        except Exception as e:
            # Best-effort smoke test: a missing dataset on disk should not
            # abort the remaining names.
            print(f" Failed to test {name}: {e}")