| from collections.abc import Iterator, Sequence |
| import dataclasses |
| from dataclasses import dataclass |
| import logging |
| import multiprocessing |
| import os |
| import typing |
| from typing import Literal, Protocol, SupportsIndex, TypeVar |
|
|
| import jax |
| import jax.numpy as jnp |
| import lerobot.common.datasets.lerobot_dataset as lerobot_dataset |
| import numpy as np |
| import torch |
|
|
| import openpi.models.model as _model |
| import openpi.training.config as _config |
| import openpi.transforms as _transforms |
| from various_speed.core import ( |
| SpeedTransformConfig, |
| _speed_chunk_ratio, |
| transform_episode, |
| ) |
|
|
# Covariant type variable for the element type produced by the Dataset and
# DataLoader protocols declared below.
T_co = TypeVar("T_co", covariant=True)
|
|
|
|
class Dataset(Protocol[T_co]):
    """Structural interface for a dataset that is indexable and has a length."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the element at ``index``; concrete datasets must override this."""
        message = "Subclasses of Dataset should implement __getitem__."
        raise NotImplementedError(message)

    def __len__(self) -> int:
        """Return the number of elements; concrete datasets must override this."""
        message = "Subclasses of Dataset should implement __len__."
        raise NotImplementedError(message)
|
|
|
|
class DataLoader(Protocol[T_co]):
    """Structural interface for an iterable data loader that exposes its data config."""

    def data_config(self) -> _config.DataConfig:
        """Return the data config this loader was built from; must be overridden."""
        message = "Subclasses of DataLoader should implement data_config."
        raise NotImplementedError(message)

    def __iter__(self) -> Iterator[T_co]:
        """Iterate over the loader's elements; must be overridden."""
        message = "Subclasses of DataLoader should implement __iter__."
        raise NotImplementedError(message)
|
|
|
|
class TransformedDataset(Dataset[T_co]):
    """Dataset view that passes every fetched element through a composed transform."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        """Wrap ``dataset`` and fold ``transforms`` into a single callable.

        Args:
            dataset: The underlying dataset to read elements from.
            transforms: Transform functions, combined via ``_transforms.compose``.
        """
        self._dataset = dataset
        composed = _transforms.compose(transforms)
        self._transform = composed

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Fetch the element at ``index`` from the wrapped dataset and transform it."""
        raw_element = self._dataset[index]
        return self._transform(raw_element)

    def __len__(self) -> int:
        """Return the length of the wrapped dataset (transforms do not change it)."""
        wrapped = self._dataset
        return len(wrapped)
|
|
|
|
| @dataclass(frozen=True) |
| class _OnlineSlidingSample: |
| """One DataLoader-addressable training entry: just (episode, speed). |
| |
| Segments and chunk_phase are NOT pre-bound here -- they're decided at |
| __getitem__ time so each access gets a fresh random phase / random row. |
| """ |
|
|
| episode_position: int |
| episode_start: int |
| episode_end: int |
| speed_index: int |
| speed: float |
|
|
|
|
class OnlineSlidingChunkDataset(Dataset[dict]):
    """Online speed augmentation. Each access does:

    1. (episode, speed) is picked by DataLoader index (every pair once per epoch).
    2. speed == 1.0: identity fast path -- random offset, slice source directly.
    3. speed != 1.0:
       a. Random chunk_phase r in [0, q-1] (q = numerator of speed=q/p).
       b. transform_episode(full_episode, speed, phase=r) -> re-timed trajectory.
       c. Pick ONE row uniformly from {observation_mask == 1}.
       d. actions = transformed_action[row : row + H], end-clamped.

    Important:
      * Segments are an internal partition for cumulative integration -- they
        don't restrict where training samples can start. Any mask=1 row is fair.
      * Random phase per access is the SOLE source of resampling randomness;
        no caching (cache hit rate would be ~1/q anyway -- multi-worker
        DataLoader parallelizes the per-sample compute instead).
      * Norm stats: source 1.0x stats are reused for ALL speeds by design (see
        ``_reuse_source_one_x_norm_stats``). Different speeds yield slightly
        different post-norm action magnitudes; the model handles this via
        speed conditioning.
    """

    def __init__(
        self,
        dataset: Dataset,
        speeds: Sequence[float],
        action_horizon: int,
        *,
        speed_config: SpeedTransformConfig | None = None,
    ):
        """Load the source columns into memory and enumerate (episode, speed) pairs.

        Args:
            dataset: Source dataset; must expose an ``hf_dataset`` attribute with
                action, state and ``episode_index`` columns.
            speeds: Playback speeds; each one is paired with every episode.
            action_horizon: Number of action rows returned per access.
            speed_config: Optional transform config. ``chunk_aligned_observation``
                is always forced to ``True`` regardless of the value passed in.

        Raises:
            ValueError: If ``action_horizon`` is not positive, ``speeds`` is empty,
                required columns are missing, the column lengths disagree with
                ``len(dataset)``, or no (episode, speed) pair can be built.
        """
        if action_horizon <= 0:
            raise ValueError(f"action_horizon must be positive, got {action_horizon}")
        # Materialize before validating so None, empty sequences, numpy arrays and
        # one-shot iterables are all checked the same way.
        speed_values = tuple(float(speed) for speed in speeds) if speeds is not None else ()
        if not speed_values:
            raise ValueError("online_sliding_speeds must contain at least one speed.")

        self._dataset = dataset
        self._speeds = speed_values
        self._action_horizon = int(action_horizon)

        # Explicit None check (not truthiness), then force chunk-aligned observations.
        if speed_config is None:
            speed_config = SpeedTransformConfig(chunk_aligned_observation=True)
        elif not speed_config.chunk_aligned_observation:
            speed_config = dataclasses.replace(speed_config, chunk_aligned_observation=True)
        self._speed_config = speed_config

        hf_dataset = getattr(dataset, "hf_dataset", None)
        if hf_dataset is None:
            raise ValueError("Online sliding chunks require a LeRobot-style dataset with hf_dataset.")
        raw = hf_dataset.with_format(None)
        columns = set(getattr(raw, "column_names", getattr(hf_dataset, "column_names", ())))
        self._action_key = _first_existing_column(columns, ("actions", "action"))
        self._state_key = _first_existing_column(columns, ("state", "observation.state"))
        if "episode_index" not in columns:
            raise ValueError("Online sliding chunks require an episode_index column.")

        # Whole columns are held as dense arrays so per-access slicing is cheap.
        self._actions = _stack_vector_column(raw[self._action_key], dtype=np.float32)
        self._states = _stack_vector_column(raw[self._state_key], dtype=np.float32)
        self._episode_indices = _scalar_column(raw["episode_index"], dtype=np.int64)
        if "frame_index" in columns:
            self._frame_indices = _scalar_column(raw["frame_index"], dtype=np.int64)
        else:
            # No stored frame index: number frames 0..len-1 within each episode run.
            self._frame_indices = _episode_local_indices(self._episode_indices)

        if len(self._actions) != len(dataset) or len(self._states) != len(dataset):
            raise ValueError("Online sliding source arrays do not match dataset length.")

        self._episode_slices = _contiguous_episode_slices(self._episode_indices)
        self._samples = self._build_samples()
        if not self._samples:
            raise ValueError(
                "Online sliding chunks produced no (episode, speed) tuples. "
                "Check that the source dataset has episodes and online_sliding_speeds is non-empty."
            )

    def __len__(self) -> int:
        """Return the number of (episode, speed) pairs."""
        return len(self._samples)

    def _build_samples(self) -> list[_OnlineSlidingSample]:
        """Enumerate (episode, speed) once each. Length = num_episodes * num_speeds."""
        return [
            _OnlineSlidingSample(
                episode_position=episode_position,
                episode_start=episode_start,
                episode_end=episode_end,
                speed_index=speed_index,
                speed=speed,
            )
            for episode_position, (episode_start, episode_end) in enumerate(self._episode_slices)
            for speed_index, speed in enumerate(self._speeds)
        ]

    def __getitem__(self, index: SupportsIndex) -> dict:
        """Return one randomly placed training item for the indexed (episode, speed) pair.

        Randomness comes from the global ``np.random`` state, so repeated access
        to the same index generally returns different items.

        Raises:
            RuntimeError: If the re-timed episode has no row with
                ``observation_mask == 1``.
        """
        sample = self._samples[index.__index__()]
        if sample.speed == 1.0:
            return self._identity_item(sample)
        return self._retimed_item(sample)

    def _identity_item(self, sample: _OnlineSlidingSample) -> dict:
        """Build an item at speed 1.0 by slicing the source arrays directly."""
        episode_len = sample.episode_end - sample.episode_start
        target_local = int(np.random.randint(0, episode_len))
        # Clamp to the last frame so the horizon never leaves the episode.
        local_horizon = np.minimum(
            target_local + np.arange(self._action_horizon),
            episode_len - 1,
        )
        return self._assemble_item(
            sample,
            source_step=target_local,
            state=self._states[sample.episode_start + target_local],
            actions=self._actions[sample.episode_start + local_horizon],
            chunk_phase=0,
            chunk_phase_count=1,
        )

    def _retimed_item(self, sample: _OnlineSlidingSample) -> dict:
        """Build an item by re-timing the whole episode with a random chunk phase."""
        # Per the class docstring, q is the numerator of speed = q/p.
        q, _p = _speed_chunk_ratio(sample.speed)
        chunk_phase = int(np.random.randint(0, q))

        config = dataclasses.replace(self._speed_config, chunk_phase=chunk_phase)
        episode = slice(sample.episode_start, sample.episode_end)
        transformed, _metrics = transform_episode(
            self._actions[episode],
            self._states[episode],
            self._frame_indices[episode],
            sample.speed,
            config,
        )

        valid_indices = np.flatnonzero(transformed["observation_mask"] == 1)
        if valid_indices.size == 0:
            raise RuntimeError(
                f"Online sliding produced 0 valid rows. episode_position="
                f"{sample.episode_position}, speed={sample.speed}, "
                f"chunk_phase={chunk_phase}, transformed_len={len(transformed['action'])}."
            )
        row_index = int(valid_indices[np.random.randint(len(valid_indices))])

        # Clamp to the final transformed row so the horizon is always full length.
        action_indices = np.minimum(
            row_index + np.arange(self._action_horizon),
            len(transformed["action"]) - 1,
        )

        return self._assemble_item(
            sample,
            source_step=int(transformed["source_step_index"][row_index]),
            state=transformed["state"][row_index],
            actions=transformed["action"][action_indices],
            chunk_phase=chunk_phase,
            chunk_phase_count=q,
        )

    def _assemble_item(
        self,
        sample: _OnlineSlidingSample,
        *,
        source_step: int,
        state: np.ndarray,
        actions: np.ndarray,
        chunk_phase: int,
        chunk_phase_count: int,
    ) -> dict:
        """Fetch the source frame and overlay state, actions and speed metadata.

        ``source_step`` is episode-local; the remaining fields of the item (those
        not overwritten here) come from the wrapped dataset at that frame.
        """
        item = dict(self._dataset[sample.episode_start + source_step])
        item[self._state_key] = state.astype(np.float32)
        item["actions"] = actions.astype(np.float32)
        # Mirror under the source column name so either key resolves to the chunk.
        item[self._action_key] = item["actions"]
        item["speed"] = np.asarray([sample.speed], dtype=np.float32)
        item["speed_index"] = np.asarray(sample.speed_index, dtype=np.int64)
        item["speed_label"] = _speed_label(sample.speed)
        item["chunk_phase"] = np.asarray(chunk_phase, dtype=np.int64)
        item["chunk_phase_count"] = np.asarray(chunk_phase_count, dtype=np.int64)
        item["valid_mask"] = np.asarray(1, dtype=np.int8)
        item["observation_mask"] = np.asarray(1, dtype=np.int8)
        item["source_step_index"] = np.asarray(source_step, dtype=np.int64)
        return item
|
|
|
|
| def _first_existing_column(columns: set[str], candidates: Sequence[str]) -> str: |
| for candidate in candidates: |
| if candidate in columns: |
| return candidate |
| raise ValueError(f"None of the expected columns were found: {tuple(candidates)}") |
|
|
|
|
| def _stack_vector_column(values, *, dtype: np.dtype) -> np.ndarray: |
| return np.asarray([np.asarray(value, dtype=dtype).reshape(-1) for value in values], dtype=dtype) |
|
|
|
|
| def _scalar_column(values, *, dtype: np.dtype) -> np.ndarray: |
| return np.asarray([np.asarray(value).reshape(-1)[0] for value in values], dtype=dtype) |
|
|
|
|
def _episode_local_indices(episode_indices: np.ndarray) -> np.ndarray:
    """Number the frames 0, 1, 2, ... within each contiguous episode run.

    Returns an int64 array of the same length as ``episode_indices``.
    """
    # Start from global positions, then subtract each run's start offset; the
    # runs returned by _contiguous_episode_slices cover every position once.
    local = np.arange(len(episode_indices), dtype=np.int64)
    for run_start, run_end in _contiguous_episode_slices(episode_indices):
        local[run_start:run_end] -= run_start
    return local
|
|
|
|
| def _contiguous_episode_slices(episode_indices: np.ndarray) -> list[tuple[int, int]]: |
| if len(episode_indices) == 0: |
| return [] |
| slices: list[tuple[int, int]] = [] |
| start = 0 |
| for idx in range(1, len(episode_indices)): |
| if episode_indices[idx] != episode_indices[idx - 1]: |
| slices.append((start, idx)) |
| start = idx |
| slices.append((start, len(episode_indices))) |
| return slices |
|
|
|
|
| def _speed_label(speed: float) -> str: |
| text = f"{speed:g}".replace(".", "p") |
| return f"{text}x" |
|
|
|
|
class FakeDataset(Dataset):
    """Synthetic dataset that generates random data matching a model's input specs."""

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        """Record the sample count and the observation/action specs of ``model_config``."""
        self._num_samples = num_samples
        specs = model_config.inputs_spec()
        self._observation_spec, self._action_spec = specs

    def __getitem__(self, index: SupportsIndex) -> dict:
        """Generate the sample for ``index``; the PRNG key is derived from the index."""
        key = jax.random.key(index.__index__())

        def sample_leaf(spec: jax.ShapeDtypeStruct):
            # Each leaf consumes a fresh subkey split off the running key.
            nonlocal key
            key, leaf_key = jax.random.split(key)

            # The leading dimension of the spec is dropped for a single sample.
            leaf_shape = spec.shape[1:]
            leaf_dtype = spec.dtype
            if leaf_dtype == jnp.float32:
                return jax.random.uniform(leaf_key, shape=leaf_shape, minval=-1.0, maxval=1.0)
            if leaf_dtype == jnp.int32:
                return jax.random.randint(leaf_key, shape=leaf_shape, minval=0, maxval=2048)
            return jnp.zeros(shape=leaf_shape, dtype=leaf_dtype)

        # Observation leaves are drawn before action leaves.
        observation = jax.tree.map(sample_leaf, self._observation_spec)
        action = jax.tree.map(sample_leaf, self._action_spec)

        sample = dict(observation.to_dict())
        sample["actions"] = action
        return sample

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self._num_samples
|
|
|
|
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training. Online-sliding-chunk path only.

    The wrapper requires ``data_config.online_sliding_chunks=True`` so the
    underlying source LeRobot dataset is fanned out by OnlineSlidingChunkDataset
    (variable speeds applied on the fly per access). The legacy offline path
    that read pre-expanded variable-speed datasets has been removed.

    Args:
        data_config: Data configuration; ``repo_id``, ``online_sliding_chunks``,
            ``online_sliding_speeds`` and ``prompt_from_task`` are read.
        action_horizon: Number of action rows per sample.
        model_config: Model configuration; only used for the ``"fake"`` repo.

    Returns:
        A ``FakeDataset`` when ``repo_id == "fake"``; otherwise the online
        sliding-chunk dataset, wrapped with a prompt transform when
        ``prompt_from_task`` is set.

    Raises:
        ValueError: If ``repo_id`` is unset or the online sliding-chunk path is
            not enabled.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)
    if not data_config.online_sliding_chunks:
        raise ValueError(
            "Only the online-sliding-chunk dataset path is supported. "
            "Set data_config.online_sliding_chunks=True and provide online_sliding_speeds."
        )

    dataset = OnlineSlidingChunkDataset(
        lerobot_dataset.LeRobotDataset(repo_id),
        data_config.online_sliding_speeds,
        action_horizon,
    )

    if data_config.prompt_from_task:
        # The metadata is only needed for the task table, so load it on demand
        # instead of unconditionally.
        dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])

    return dataset
|
|
|
|
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Wrap ``dataset`` with the repack, data, normalization and model input transforms.

    Normalization uses empty stats for the ``"fake"`` repo or when
    ``skip_norm_stats`` is set; otherwise ``data_config.norm_stats`` is required.

    Raises:
        ValueError: If normalization stats are required but not available.
    """
    stats_required = not (data_config.repo_id == "fake" or skip_norm_stats)

    norm_stats = {}
    if stats_required:
        norm_stats = data_config.norm_stats
        if norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )

    # Order matters: repack -> data transforms -> normalize -> model transforms.
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
|
|
|
|
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Build the training data loader described by a training config.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    model_config = config.model
    data_config = config.data.create(config.assets_dirs, model_config)
    logging.info(f"data_config: {data_config}")

    # Everything besides the data config is forwarded straight from ``config``
    # or from this function's own arguments.
    loader_options = {
        "model_config": model_config,
        "action_horizon": model_config.action_horizon,
        "batch_size": config.batch_size,
        "sharding": sharding,
        "shuffle": shuffle,
        "num_batches": num_batches,
        "num_workers": config.num_workers,
        "seed": config.seed,
        "skip_norm_stats": skip_norm_stats,
        "framework": framework,
    }
    return create_torch_data_loader(data_config, **loader_options)
|
|
|
|
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration, forwarded to ``create_torch_dataset``.
        action_horizon: The action horizon.
        batch_size: The global batch size. It is divided by the torch distributed
            world size (PyTorch) or by ``jax.process_count()`` (otherwise) to get
            the per-process batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding. Ignored when ``framework == "pytorch"``.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data. Also seeds the
            ``DistributedSampler`` when torch distributed is initialized.
        framework: ``"pytorch"`` enables the torch-distributed path; any other value
            takes the JAX path.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                # Without this the sampler always shuffles with its default seed (0)
                # and the configured seed has no effect in distributed runs.
                seed=seed,
                drop_last=True,
            )
        else:
            world_size = 1
    else:
        world_size = jax.process_count()

    if batch_size % world_size != 0:
        # Floor division below silently shrinks the effective global batch.
        logging.warning(
            "batch_size (%d) is not divisible by the number of processes (%d); "
            "the effective global batch size will be %d.",
            batch_size,
            world_size,
            (batch_size // world_size) * world_size,
        )
    local_batch_size = batch_size // world_size

    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        # A sampler and shuffle=True are mutually exclusive in torch's DataLoader.
        shuffle=(sampler is None and shuffle),
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)
|
|
|
|
class TorchDataLoader:
    """Torch data loader implementation."""

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader. When None and
                ``framework == "jax"``, a 1-D mesh over all devices is used.
            shuffle: Whether to shuffle the data. Ignored when ``sampler`` is given.
            sampler: Optional sampler. If it exposes ``set_epoch`` it is advanced
                before every pass over the dataset.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: ``"jax"`` yields sharded JAX arrays; with no sharding
                (e.g. ``"pytorch"``) batches are yielded as torch tensors.

        Raises:
            NotImplementedError: If more than one JAX process is running.
            ValueError: If the dataset is smaller than ``local_batch_size`` or the
                loader would produce zero batches per pass.
        """
        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")

        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Default: shard the batch axis across all visible devices.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches
        self._sampler = sampler
        # Counts passes over the dataset across all __iter__ calls.
        self._epoch = 0

        mp_context = None
        if num_workers > 0:
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
        )

        # With drop_last=True a sampler can expose fewer samples than one batch
        # (e.g. a per-rank shard); __iter__ would then loop forever without
        # yielding anything, so fail fast instead.
        try:
            batches_per_pass = len(self._data_loader)
        except TypeError:
            batches_per_pass = None  # sampler does not define __len__
        if batches_per_pass == 0:
            raise ValueError(
                f"Data loader yields no batches: local batch size ({local_batch_size}) exceeds "
                "the number of samples available to this process."
            )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        """The underlying ``torch.utils.data.DataLoader``."""
        return self._data_loader

    def __iter__(self):
        """Yield batches, restarting the underlying loader whenever it is exhausted.

        Stops after ``num_batches`` batches when that was given; otherwise runs
        indefinitely. Leaves are converted to sharded JAX arrays when a sharding
        is set, and to torch tensors otherwise.
        """
        num_items = 0
        while True:
            # Distributed samplers derive their shuffle order from the epoch
            # number; without this call every pass would replay the same order.
            set_epoch = getattr(self._sampler, "set_epoch", None)
            if callable(set_epoch):
                set_epoch(self._epoch)
            self._epoch += 1

            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # pass finished; start the next one
                num_items += 1

                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)
|
|
|
|
| def _collate_fn(items): |
| """Collate the batch elements into batched numpy arrays.""" |
| |
| |
| return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) |
|
|
|
|
| def _worker_init_fn(worker_id: int) -> None: |
| """Tell JAX inside the worker process not to preallocate the GPU memory.""" |
| |
| |
| os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" |
| os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" |
|
|
|
|
class DataLoaderImpl(DataLoader):
    """Data loader that pairs a data config with a ``TorchDataLoader``."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader):
        """Store the data config and the batch-producing loader."""
        self._data_loader = data_loader
        self._data_config = data_config

    def data_config(self) -> _config.DataConfig:
        """Return the data config supplied at construction."""
        return self._data_config

    def __iter__(self):
        """Yield ``(observation, actions)`` pairs built from each batch dict."""
        for batch in self._data_loader:
            observation = _model.Observation.from_dict(batch)
            actions = batch["actions"]
            yield observation, actions
|
|