#!/usr/bin/env python3 """ Custom PyTorch Dataset that reads directly from community_dataset_v3 v2.1 files on disk. No merging, no conversion, no copying. Just reads parquets + decodes video frames. Returns raw (unnormalized) data in the format LeRobotDataset returns — the existing Pi0.5 preprocessor handles normalization, padding, tokenization, and device placement. Provides a .meta adapter so lerobot_train.py can use it as a drop-in replacement. """ import json from pathlib import Path import numpy as np import pandas as pd import torch from torch.utils.data import Dataset class _DatasetMeta: """ Lightweight adapter that provides the .meta interface lerobot_train.py expects. Wraps our filtered index + precomputed stats. """ def __init__(self, index: dict, stats: dict, data_root: Path): self.repo_id = "SO100Dataset/local" self.root = data_root # Stats: training script expects dict[str, dict[str, torch.Tensor]] self.stats = {} for key, s in stats.items(): self.stats[key] = { "mean": torch.tensor(s["mean"], dtype=torch.float32), "std": torch.tensor(s["std"], dtype=torch.float32), # Preprocessor may also look for min/max/quantiles. # Approximate them from mean/std for MEAN_STD normalization. "min": torch.tensor(s["mean"], dtype=torch.float32) - 3 * torch.tensor(s["std"], dtype=torch.float32), "max": torch.tensor(s["mean"], dtype=torch.float32) + 3 * torch.tensor(s["std"], dtype=torch.float32), } # Tasks self.tasks = pd.DataFrame( {"task_index": range(len(index["tasks"]))}, index=index["tasks"], ) # Features self._features = { "observation.images.image": { "dtype": "video", "shape": [3, 480, 640], "names": ["channels", "height", "width"], }, "observation.images.image2": { "dtype": "video", "shape": [3, 480, 640], "names": ["channels", "height", "width"], }, "observation.state": { "dtype": "float32", "shape": [6], }, "action": { "dtype": "float32", "shape": [6], }, "timestamp": {"dtype": "float32", "shape": []}, "frame_index": {"dtype": "int64", "shape": []}, "episode_index": {"dtype": "int64", "shape": []}, "index": {"dtype": "int64", "shape": []}, "task_index": {"dtype": "int64", "shape": []}, } self.info = { "fps": 30, "robot_type": "so100", "total_episodes": index["summary"]["episodes"], "total_frames": index["summary"]["total_frames"], } @property def fps(self): return 30 @property def features(self): return self._features @property def camera_keys(self): return ["observation.images.image", "observation.images.image2"] @property def video_keys(self): return ["observation.images.image", "observation.images.image2"] @property def image_keys(self): return [] @property def total_episodes(self): return self.info["total_episodes"] @property def total_frames(self): return self.info["total_frames"] @property def robot_type(self): return "so100" class SO100Dataset(Dataset): """ Loads filtered SO-100/101 episodes from community_dataset_v3 on disk. Each sample is one frame with an action chunk of the next `chunk_size` steps. Returns raw unnormalized data — the Pi0.5 preprocessor handles normalization. Provides .meta property compatible with lerobot_train.py. """ def __init__( self, data_root: str | Path, index_path: str | Path, stats_path: str | Path | None = None, video_backend: str = "pyav", chunk_size: int = 50, image_transforms=None, ): self.data_root = Path(data_root) self.video_backend = video_backend self.chunk_size = chunk_size self.image_transforms = image_transforms self.fps = 30 # Load index with open(index_path) as f: self._index = json.load(f) self.tasks = self._index["tasks"] # Load stats raw_stats = {} if stats_path and Path(stats_path).exists(): with open(stats_path) as f: raw_stats = json.load(f) # Create meta adapter self.meta = _DatasetMeta(self._index, raw_stats, self.data_root) # Build flat frame-level index self._frame_index = [] self._episode_offsets = [] for ep in self._index["episodes"]: dataset_path = self.data_root / ep["dataset"] ep_idx = ep["episode_index"] task = ep["task"] task_idx = ep["task_index"] num_frames = ep["num_frames"] # Only include frames where a full action chunk fits valid_frames = max(0, num_frames - self.chunk_size) if valid_frames == 0: continue start = len(self._frame_index) self._episode_offsets.append(start) for frame_idx in range(valid_frames): self._frame_index.append(( dataset_path, ep_idx, frame_idx, num_frames, task, task_idx, )) # Parquet cache self._parquet_cache = {} self._cache_max = 200 def __len__(self): return len(self._frame_index) @property def num_episodes(self): return len(self._episode_offsets) @property def num_frames(self): return len(self._frame_index) @property def episodes(self): return None # Use all episodes (no further filtering) @property def features(self): return self.meta.features @property def video(self): return True @property def camera_keys(self): return self.meta.camera_keys @property def video_frame_keys(self): return self.meta.camera_keys def _load_parquet(self, dataset_path: Path, episode_index: int) -> pd.DataFrame: """Load and cache a parquet file.""" key = (str(dataset_path), episode_index) if key in self._parquet_cache: return self._parquet_cache[key] parquet_path = dataset_path / f"data/chunk-000/episode_{episode_index:06d}.parquet" df = pd.read_parquet(parquet_path) if len(self._parquet_cache) >= self._cache_max: oldest_key = next(iter(self._parquet_cache)) del self._parquet_cache[oldest_key] self._parquet_cache[key] = df return df def _decode_video_frame(self, video_path: Path, timestamp: float) -> torch.Tensor: """Decode a single frame from an MP4 at the given timestamp. Returns (C, H, W) float32 [0,1].""" if self.video_backend == "torchcodec": from torchcodec.decoders import VideoDecoder decoder = VideoDecoder(str(video_path)) frame = decoder.get_frame_played_at(timestamp) return frame.data.float() / 255.0 else: import av container = av.open(str(video_path)) stream = container.streams.video[0] target_pts = int(timestamp / float(stream.time_base)) container.seek(target_pts, stream=stream) for frame in container.decode(video=0): arr = frame.to_ndarray(format="rgb24") tensor = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0 container.close() return tensor container.close() raise RuntimeError(f"Could not decode frame at t={timestamp} from {video_path}") def __getitem__(self, idx: int) -> dict: # Retry with a different sample if this one has corrupt/mismatched video for _attempt in range(5): try: return self._get_sample(idx) except (IndexError, RuntimeError, OSError) as e: # Video duration doesn't match parquet timestamps, or file is corrupt. # Pick a random different index and try again. import random idx = random.randint(0, len(self._frame_index) - 1) # If all retries fail, raise return self._get_sample(idx) def _get_sample(self, idx: int) -> dict: dataset_path, ep_idx, frame_idx, num_frames, task, task_idx = self._frame_index[idx] df = self._load_parquet(dataset_path, ep_idx) # Current frame row = df.iloc[frame_idx] state = torch.tensor(row["observation.state"], dtype=torch.float32) timestamp = float(row["timestamp"]) # Action chunk: next chunk_size actions starting from current frame action_end = min(frame_idx + self.chunk_size, len(df)) action_rows = df.iloc[frame_idx:action_end] actions = torch.tensor( np.stack(action_rows["action"].values), dtype=torch.float32, ) # Pad with last action if near episode end if actions.shape[0] < self.chunk_size: pad = actions[-1:].expand(self.chunk_size - actions.shape[0], -1) actions = torch.cat([actions, pad], dim=0) # Decode video frames video_dir = dataset_path / "videos" / "chunk-000" ep_str = f"episode_{ep_idx:06d}.mp4" image1 = self._decode_video_frame( video_dir / "observation.images.image" / ep_str, timestamp ) image2 = self._decode_video_frame( video_dir / "observation.images.image2" / ep_str, timestamp ) if self.image_transforms is not None: image1 = self.image_transforms(image1) image2 = self.image_transforms(image2) return { "observation.images.image": image1, # (3, 480, 640) float32 [0,1] "observation.images.image2": image2, # (3, 480, 640) float32 [0,1] "observation.state": state, # (6,) float32, raw values "action": actions, # (50, 6) float32, raw values "task": task, # str "task_index": torch.tensor(task_idx), "timestamp": torch.tensor(timestamp), "frame_index": torch.tensor(frame_idx), "episode_index": torch.tensor(ep_idx), "index": torch.tensor(idx), } def __repr__(self): return ( f"SO100Dataset(\n" f" data_root='{self.data_root}',\n" f" episodes={self.num_episodes},\n" f" frames={self.num_frames:,},\n" f" tasks={len(self.tasks)},\n" f" video_backend='{self.video_backend}',\n" f")" )