# pi05-so100-diverse / so100_dataset.py
# Uploaded via huggingface_hub (commit cd604b4, verified)
#!/usr/bin/env python3
"""
Custom PyTorch Dataset that reads directly from community_dataset_v3 v2.1 files on disk.
No merging, no conversion, no copying. Just reads parquets + decodes video frames.
Returns raw (unnormalized) data in the format LeRobotDataset returns — the existing
Pi0.5 preprocessor handles normalization, padding, tokenization, and device placement.
Provides a .meta adapter so lerobot_train.py can use it as a drop-in replacement.
"""
import json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class _DatasetMeta:
"""
Lightweight adapter that provides the .meta interface lerobot_train.py expects.
Wraps our filtered index + precomputed stats.
"""
def __init__(self, index: dict, stats: dict, data_root: Path):
self.repo_id = "SO100Dataset/local"
self.root = data_root
# Stats: training script expects dict[str, dict[str, torch.Tensor]]
self.stats = {}
for key, s in stats.items():
self.stats[key] = {
"mean": torch.tensor(s["mean"], dtype=torch.float32),
"std": torch.tensor(s["std"], dtype=torch.float32),
# Preprocessor may also look for min/max/quantiles.
# Approximate them from mean/std for MEAN_STD normalization.
"min": torch.tensor(s["mean"], dtype=torch.float32) - 3 * torch.tensor(s["std"], dtype=torch.float32),
"max": torch.tensor(s["mean"], dtype=torch.float32) + 3 * torch.tensor(s["std"], dtype=torch.float32),
}
# Tasks
self.tasks = pd.DataFrame(
{"task_index": range(len(index["tasks"]))},
index=index["tasks"],
)
# Features
self._features = {
"observation.images.image": {
"dtype": "video",
"shape": [3, 480, 640],
"names": ["channels", "height", "width"],
},
"observation.images.image2": {
"dtype": "video",
"shape": [3, 480, 640],
"names": ["channels", "height", "width"],
},
"observation.state": {
"dtype": "float32",
"shape": [6],
},
"action": {
"dtype": "float32",
"shape": [6],
},
"timestamp": {"dtype": "float32", "shape": []},
"frame_index": {"dtype": "int64", "shape": []},
"episode_index": {"dtype": "int64", "shape": []},
"index": {"dtype": "int64", "shape": []},
"task_index": {"dtype": "int64", "shape": []},
}
self.info = {
"fps": 30,
"robot_type": "so100",
"total_episodes": index["summary"]["episodes"],
"total_frames": index["summary"]["total_frames"],
}
@property
def fps(self):
return 30
@property
def features(self):
return self._features
@property
def camera_keys(self):
return ["observation.images.image", "observation.images.image2"]
@property
def video_keys(self):
return ["observation.images.image", "observation.images.image2"]
@property
def image_keys(self):
return []
@property
def total_episodes(self):
return self.info["total_episodes"]
@property
def total_frames(self):
return self.info["total_frames"]
@property
def robot_type(self):
return "so100"
class SO100Dataset(Dataset):
"""
Loads filtered SO-100/101 episodes from community_dataset_v3 on disk.
Each sample is one frame with an action chunk of the next `chunk_size` steps.
Returns raw unnormalized data — the Pi0.5 preprocessor handles normalization.
Provides .meta property compatible with lerobot_train.py.
"""
def __init__(
self,
data_root: str | Path,
index_path: str | Path,
stats_path: str | Path | None = None,
video_backend: str = "pyav",
chunk_size: int = 50,
image_transforms=None,
):
self.data_root = Path(data_root)
self.video_backend = video_backend
self.chunk_size = chunk_size
self.image_transforms = image_transforms
self.fps = 30
# Load index
with open(index_path) as f:
self._index = json.load(f)
self.tasks = self._index["tasks"]
# Load stats
raw_stats = {}
if stats_path and Path(stats_path).exists():
with open(stats_path) as f:
raw_stats = json.load(f)
# Create meta adapter
self.meta = _DatasetMeta(self._index, raw_stats, self.data_root)
# Build flat frame-level index
self._frame_index = []
self._episode_offsets = []
for ep in self._index["episodes"]:
dataset_path = self.data_root / ep["dataset"]
ep_idx = ep["episode_index"]
task = ep["task"]
task_idx = ep["task_index"]
num_frames = ep["num_frames"]
# Only include frames where a full action chunk fits
valid_frames = max(0, num_frames - self.chunk_size)
if valid_frames == 0:
continue
start = len(self._frame_index)
self._episode_offsets.append(start)
for frame_idx in range(valid_frames):
self._frame_index.append((
dataset_path, ep_idx, frame_idx,
num_frames, task, task_idx,
))
# Parquet cache
self._parquet_cache = {}
self._cache_max = 200
def __len__(self):
return len(self._frame_index)
@property
def num_episodes(self):
return len(self._episode_offsets)
@property
def num_frames(self):
return len(self._frame_index)
@property
def episodes(self):
return None # Use all episodes (no further filtering)
@property
def features(self):
return self.meta.features
@property
def video(self):
return True
@property
def camera_keys(self):
return self.meta.camera_keys
@property
def video_frame_keys(self):
return self.meta.camera_keys
def _load_parquet(self, dataset_path: Path, episode_index: int) -> pd.DataFrame:
"""Load and cache a parquet file."""
key = (str(dataset_path), episode_index)
if key in self._parquet_cache:
return self._parquet_cache[key]
parquet_path = dataset_path / f"data/chunk-000/episode_{episode_index:06d}.parquet"
df = pd.read_parquet(parquet_path)
if len(self._parquet_cache) >= self._cache_max:
oldest_key = next(iter(self._parquet_cache))
del self._parquet_cache[oldest_key]
self._parquet_cache[key] = df
return df
def _decode_video_frame(self, video_path: Path, timestamp: float) -> torch.Tensor:
"""Decode a single frame from an MP4 at the given timestamp. Returns (C, H, W) float32 [0,1]."""
if self.video_backend == "torchcodec":
from torchcodec.decoders import VideoDecoder
decoder = VideoDecoder(str(video_path))
frame = decoder.get_frame_played_at(timestamp)
return frame.data.float() / 255.0
else:
import av
container = av.open(str(video_path))
stream = container.streams.video[0]
target_pts = int(timestamp / float(stream.time_base))
container.seek(target_pts, stream=stream)
for frame in container.decode(video=0):
arr = frame.to_ndarray(format="rgb24")
tensor = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
container.close()
return tensor
container.close()
raise RuntimeError(f"Could not decode frame at t={timestamp} from {video_path}")
def __getitem__(self, idx: int) -> dict:
dataset_path, ep_idx, frame_idx, num_frames, task, task_idx = self._frame_index[idx]
df = self._load_parquet(dataset_path, ep_idx)
# Current frame
row = df.iloc[frame_idx]
state = torch.tensor(row["observation.state"], dtype=torch.float32)
timestamp = float(row["timestamp"])
# Action chunk: next chunk_size actions starting from current frame
action_end = min(frame_idx + self.chunk_size, len(df))
action_rows = df.iloc[frame_idx:action_end]
actions = torch.tensor(
np.stack(action_rows["action"].values),
dtype=torch.float32,
)
# Pad with last action if near episode end
if actions.shape[0] < self.chunk_size:
pad = actions[-1:].expand(self.chunk_size - actions.shape[0], -1)
actions = torch.cat([actions, pad], dim=0)
# Decode video frames
video_dir = dataset_path / "videos" / "chunk-000"
ep_str = f"episode_{ep_idx:06d}.mp4"
image1 = self._decode_video_frame(
video_dir / "observation.images.image" / ep_str, timestamp
)
image2 = self._decode_video_frame(
video_dir / "observation.images.image2" / ep_str, timestamp
)
if self.image_transforms is not None:
image1 = self.image_transforms(image1)
image2 = self.image_transforms(image2)
return {
"observation.images.image": image1, # (3, 480, 640) float32 [0,1]
"observation.images.image2": image2, # (3, 480, 640) float32 [0,1]
"observation.state": state, # (6,) float32, raw values
"action": actions, # (50, 6) float32, raw values
"task": task, # str
"task_index": torch.tensor(task_idx),
"timestamp": torch.tensor(timestamp),
"frame_index": torch.tensor(frame_idx),
"episode_index": torch.tensor(ep_idx),
"index": torch.tensor(idx),
}
def __repr__(self):
return (
f"SO100Dataset(\n"
f" data_root='{self.data_root}',\n"
f" episodes={self.num_episodes},\n"
f" frames={self.num_frames:,},\n"
f" tasks={len(self.tasks)},\n"
f" video_backend='{self.video_backend}',\n"
f")"
)