from collections.abc import Iterator, Sequence
import dataclasses
from dataclasses import dataclass
import logging
import multiprocessing
import os
import typing
from typing import Literal, Protocol, SupportsIndex, TypeVar

import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch

import openpi.models.model as _model
import openpi.training.config as _config
import openpi.transforms as _transforms
from various_speed.core import (
    SpeedTransformConfig,
    _speed_chunk_ratio,
    transform_episode,
)

T_co = TypeVar("T_co", covariant=True)


class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")


class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")


class TransformedDataset(Dataset[T_co]):
    """Dataset wrapper that applies a composed transform chain to every item."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        return self._transform(self._dataset[index])

    def __len__(self) -> int:
        return len(self._dataset)


@dataclass(frozen=True)
class _OnlineSlidingSample:
    """One DataLoader-addressable training entry: just (episode, speed).

    Segments and chunk_phase are NOT pre-bound here -- they're decided at
    __getitem__ time so each access gets a fresh random phase / random row.
    """

    episode_position: int
    episode_start: int  # global flat-dataset index where this episode begins
    episode_end: int  # global flat-dataset index where this episode ends (exclusive)
    speed_index: int
    speed: float


class OnlineSlidingChunkDataset(Dataset[dict]):
    """Online speed augmentation.

    Each access does:
      1. (episode, speed) is picked by DataLoader index (every pair once per epoch).
      2. speed == 1.0: identity fast path -- random offset, slice source directly.
      3. speed != 1.0:
         a. Random chunk_phase r in [0, q-1] (q = numerator of speed=q/p).
         b. transform_episode(full_episode, speed, phase=r) -> re-timed trajectory.
         c. Pick ONE row uniformly from {observation_mask == 1}.
         d. actions = transformed_action[row : row + H], end-clamped.

    Important:
      * Segments are an internal partition for cumulative integration -- they
        don't restrict where training samples can start. Any mask=1 row is fair.
      * Random phase per access is the SOLE source of resampling randomness; no
        caching (cache hit rate would be ~1/q anyway -- multi-worker DataLoader
        parallelizes the per-sample compute instead).
      * Norm stats: source 1.0x stats are reused for ALL speeds by design (see
        ``_reuse_source_one_x_norm_stats``). Different speeds yield slightly
        different post-norm action magnitudes; the model handles this via speed
        conditioning.
    """

    def __init__(
        self,
        dataset: Dataset,
        speeds: Sequence[float],
        action_horizon: int,
        *,
        speed_config: SpeedTransformConfig | None = None,
    ):
        """Index the source dataset and enumerate (episode, speed) samples.

        Args:
            dataset: Source dataset; must expose an ``hf_dataset`` attribute with
                action, state and ``episode_index`` columns.
            speeds: Speeds to fan every episode out over; must be non-empty.
            action_horizon: Number of action steps returned per item; must be positive.
            speed_config: Optional transform config. ``chunk_aligned_observation``
                is forced to True regardless of the value passed in.

        Raises:
            ValueError: On a non-positive horizon, empty speeds, a dataset without
                ``hf_dataset`` / required columns, mismatched column lengths, or
                when no (episode, speed) samples can be built.
        """
        if action_horizon <= 0:
            raise ValueError(f"action_horizon must be positive, got {action_horizon}")
        if not speeds:
            raise ValueError("online_sliding_speeds must contain at least one speed.")
        self._dataset = dataset
        self._speeds = tuple(float(speed) for speed in speeds)
        self._action_horizon = int(action_horizon)
        self._speed_config = speed_config or SpeedTransformConfig(chunk_aligned_observation=True)
        # Force chunk_aligned: we rely on source_step_index in the transformed
        # output to locate the chunk-start row, and that's only meaningful
        # when chunks are aligned to integer source positions.
        if not self._speed_config.chunk_aligned_observation:
            self._speed_config = dataclasses.replace(self._speed_config, chunk_aligned_observation=True)

        # Pull state/action/index columns from the LeRobot HF dataset into raw
        # numpy. This bypasses the heavy image-decoding path for the speed
        # transform; image/prompt are still fetched per-access via self._dataset.
        hf_dataset = getattr(dataset, "hf_dataset", None)
        if hf_dataset is None:
            raise ValueError("Online sliding chunks require a LeRobot-style dataset with hf_dataset.")
        raw = hf_dataset.with_format(None)
        columns = set(getattr(raw, "column_names", getattr(hf_dataset, "column_names", ())))
        self._action_key = _first_existing_column(columns, ("actions", "action"))
        self._state_key = _first_existing_column(columns, ("state", "observation.state"))
        if "episode_index" not in columns:
            raise ValueError("Online sliding chunks require an episode_index column.")

        self._actions = _stack_vector_column(raw[self._action_key], dtype=np.float32)
        self._states = _stack_vector_column(raw[self._state_key], dtype=np.float32)
        self._episode_indices = _scalar_column(raw["episode_index"], dtype=np.int64)
        if "frame_index" in columns:
            self._frame_indices = _scalar_column(raw["frame_index"], dtype=np.int64)
        else:
            # No explicit frame index: derive 0..len-1 within each contiguous episode.
            self._frame_indices = _episode_local_indices(self._episode_indices)
        if len(self._actions) != len(dataset) or len(self._states) != len(dataset):
            raise ValueError("Online sliding source arrays do not match dataset length.")

        self._episode_slices = _contiguous_episode_slices(self._episode_indices)
        self._samples = self._build_samples()
        if not self._samples:
            raise ValueError(
                "Online sliding chunks produced no (episode, speed) tuples. "
                "Check that the source dataset has episodes and online_sliding_speeds is non-empty."
            )
        # NOTE(review): never read or written anywhere in this module; kept only
        # so external code that may touch the attribute keeps working.
        self._cache_order: list[tuple[int, int, int]] = []

    def __len__(self) -> int:
        return len(self._samples)

    def _build_samples(self) -> list[_OnlineSlidingSample]:
        """Enumerate (episode, speed) once each. Length = num_episodes * num_speeds."""
        return [
            _OnlineSlidingSample(
                episode_position=episode_position,
                episode_start=episode_start,
                episode_end=episode_end,
                speed_index=speed_index,
                speed=speed,
            )
            for episode_position, (episode_start, episode_end) in enumerate(self._episode_slices)
            for speed_index, speed in enumerate(self._speeds)
        ]

    def _assemble_item(
        self,
        sample: _OnlineSlidingSample,
        *,
        global_source_step: int,
        source_step: int,
        state: np.ndarray,
        actions: np.ndarray,
        chunk_phase: int,
        chunk_phase_count: int,
    ) -> dict:
        """Build the output dict shared by the identity and variable-speed paths.

        Image/prompt (and every other field) come from the underlying dataset row
        at ``global_source_step``; state/actions are then overwritten and the
        speed metadata keys are added.
        """
        item = dict(self._dataset[global_source_step])
        item[self._state_key] = state.astype(np.float32)
        item["actions"] = actions.astype(np.float32)
        item[self._action_key] = item["actions"]
        item["speed"] = np.asarray([sample.speed], dtype=np.float32)
        item["speed_index"] = np.asarray(sample.speed_index, dtype=np.int64)
        item["speed_label"] = _speed_label(sample.speed)
        item["chunk_phase"] = np.asarray(chunk_phase, dtype=np.int64)
        item["chunk_phase_count"] = np.asarray(chunk_phase_count, dtype=np.int64)
        item["valid_mask"] = np.asarray(1, dtype=np.int8)
        item["observation_mask"] = np.asarray(1, dtype=np.int8)
        item["source_step_index"] = np.asarray(source_step, dtype=np.int64)
        return item

    def __getitem__(self, index: SupportsIndex) -> dict:
        """Return one randomly positioned training item for the indexed (episode, speed).

        Randomness comes from the global ``np.random`` state, so repeated access
        to the same index generally yields different items.

        Raises:
            RuntimeError: If the transformed episode has no row with
                ``observation_mask == 1`` for the drawn phase.
        """
        sample = self._samples[index.__index__()]
        episode_len = sample.episode_end - sample.episode_start

        # speed == 1.0 fast path: identity. Pick random source offset, slice
        # state + horizon directly; transform_episode is NOT called. Horizon
        # clamps to the last episode frame (BC pad-at-end convention).
        if sample.speed == 1.0:
            target_local = int(np.random.randint(0, episode_len))
            global_source_step = sample.episode_start + target_local
            local_horizon = np.minimum(
                target_local + np.arange(self._action_horizon),
                episode_len - 1,
            )
            global_horizon = sample.episode_start + local_horizon
            return self._assemble_item(
                sample,
                global_source_step=global_source_step,
                source_step=target_local,
                state=self._states[global_source_step],
                actions=self._actions[global_horizon],
                chunk_phase=0,
                chunk_phase_count=1,
            )

        # ---- Variable-speed path ----
        # 1. Random phase per access -> rotates which source frames become
        #    chunk-starts across epochs.
        q, _p = _speed_chunk_ratio(sample.speed)
        chunk_phase = int(np.random.randint(0, q))

        # 2. Transform the FULL episode. Output is one continuous trajectory
        #    (segments concatenated end-to-end) with mask=1 only at chunk-
        #    aligned starts (and -- for phase=0 -- trailing passthrough).
        config = dataclasses.replace(self._speed_config, chunk_phase=chunk_phase)
        episode = slice(sample.episode_start, sample.episode_end)
        transformed, _metrics = transform_episode(
            self._actions[episode],
            self._states[episode],
            self._frame_indices[episode],
            sample.speed,
            config,
        )

        # 3. Pick ANY valid row uniformly. Internal segment boundaries don't
        #    constrain where the training chunk-start can land.
        valid_indices = np.flatnonzero(transformed["observation_mask"] == 1)
        if valid_indices.size == 0:
            # All segments shorter than q under the chosen phase. Rare on
            # LIBERO (q <= 8 and segments are typically tens of frames).
            raise RuntimeError(
                f"Online sliding produced 0 valid rows. episode_position="
                f"{sample.episode_position}, speed={sample.speed}, "
                f"chunk_phase={chunk_phase}, transformed_len={len(transformed['action'])}."
            )
        row_index = int(valid_indices[np.random.randint(len(valid_indices))])

        # 4. Slice the action horizon, end-clamped at the trajectory tail
        #    (segment boundaries inside the horizon are crossed transparently).
        action_indices = np.minimum(
            row_index + np.arange(self._action_horizon),
            len(transformed["action"]) - 1,
        )

        # 5. Assemble: image/prompt come from the underlying LeRobot row at the
        #    source frame; state/actions are overwritten with the transformed
        #    values; speed metadata is exposed for prompt/modulation transforms.
        source_step = int(transformed["source_step_index"][row_index])
        return self._assemble_item(
            sample,
            global_source_step=sample.episode_start + source_step,
            source_step=source_step,
            state=transformed["state"][row_index],
            actions=transformed["action"][action_indices],
            chunk_phase=chunk_phase,
            chunk_phase_count=q,
        )


def _first_existing_column(columns: set[str], candidates: Sequence[str]) -> str:
    """Return the first candidate present in ``columns``; raise ValueError if none is."""
    for candidate in candidates:
        if candidate in columns:
            return candidate
    raise ValueError(f"None of the expected columns were found: {tuple(candidates)}")


def _stack_vector_column(values, *, dtype: np.dtype) -> np.ndarray:
    """Flatten every entry of ``values`` to 1-D and stack them into one array."""
    return np.asarray([np.asarray(value, dtype=dtype).reshape(-1) for value in values], dtype=dtype)


def _scalar_column(values, *, dtype: np.dtype) -> np.ndarray:
    """Take the first element of every (flattened) entry of ``values`` as a 1-D array."""
    return np.asarray([np.asarray(value).reshape(-1)[0] for value in values], dtype=dtype)


def _episode_local_indices(episode_indices: np.ndarray) -> np.ndarray:
    """Return, for each row, its 0-based position inside its contiguous episode run."""
    out = np.zeros(len(episode_indices), dtype=np.int64)
    for start, end in _contiguous_episode_slices(episode_indices):
        out[start:end] = np.arange(end - start, dtype=np.int64)
    return out


def _contiguous_episode_slices(episode_indices: np.ndarray) -> list[tuple[int, int]]:
    """Split ``episode_indices`` into maximal runs of equal consecutive values.

    Returns:
        ``(start, end)`` pairs (end exclusive, plain Python ints) in order of
        appearance; an empty list for empty input. A value that reappears after
        a different one starts a new run.
    """
    episode_indices = np.asarray(episode_indices)
    total = len(episode_indices)
    if total == 0:
        return []
    # A boundary is every position whose value differs from its predecessor.
    boundaries = (np.flatnonzero(episode_indices[1:] != episode_indices[:-1]) + 1).tolist()
    return list(zip([0, *boundaries], [*boundaries, total]))


def _speed_label(speed: float) -> str:
    """Format a speed as a label, e.g. 1.5 -> ``"1p5x"`` and 1.0 -> ``"1x"``."""
    text = f"{speed:g}".replace(".", "p")
    return f"{text}x"


class FakeDataset(Dataset):
    """Dataset of random data shaped after the model's input spec."""

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        # Seed from the index so a given index always produces the same item.
        rng = jax.random.key(index.__index__())

        def make_from_spec(spec: jax.ShapeDtypeStruct):
            nonlocal rng
            rng, data_rng = jax.random.split(rng)
            # Remove the batch dimension.
            shape = spec.shape[1:]
            if spec.dtype == jnp.float32:
                return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)
            return jnp.zeros(shape=shape, dtype=spec.dtype)

        observation = jax.tree.map(make_from_spec, self._observation_spec)
        action = jax.tree.map(make_from_spec, self._action_spec)

        return {
            **observation.to_dict(),
            "actions": action,
        }

    def __len__(self) -> int:
        return self._num_samples


def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training. Online-sliding-chunk path only.

    The wrapper requires ``data_config.online_sliding_chunks=True`` so the
    underlying source LeRobot dataset is fanned out by OnlineSlidingChunkDataset
    (variable speeds applied on the fly per access). The legacy offline path that
    read pre-expanded variable-speed datasets has been removed.

    Raises:
        ValueError: If the repo ID is unset, or online sliding chunks are disabled
            for a non-fake repo.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    if not data_config.online_sliding_chunks:
        raise ValueError(
            "Only the online-sliding-chunk dataset path is supported. "
            "Set data_config.online_sliding_chunks=True and provide online_sliding_speeds."
        )

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset = lerobot_dataset.LeRobotDataset(repo_id)
    dataset = OnlineSlidingChunkDataset(
        dataset,
        data_config.online_sliding_speeds,
        action_horizon,
    )

    if data_config.prompt_from_task:
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])

    return dataset


def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Transform the dataset by applying the data transforms.

    Normalization uses ``data_config.norm_stats`` unless the repo is ``"fake"``
    or ``skip_norm_stats`` is set, in which case empty stats are passed.

    Raises:
        ValueError: If normalization stats are required but missing.
    """
    norm_stats = {}
    if data_config.repo_id != "fake" and not skip_norm_stats:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )


def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info("data_config: %s", data_config)

    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )


def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration (used to build the fake dataset).
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: The framework to use ("jax" or "pytorch").
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()

    logging.info("local_batch_size: %s", local_batch_size)
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)


class TorchDataLoader:
    """Torch data loader implementation."""

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader.
            shuffle: Whether to shuffle the data.
            sampler: Optional sampler; when given, ``shuffle`` is ignored. If it has a
                ``set_epoch`` method, that is called before every pass over the dataset.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: "jax" gets a default data-parallel sharding when ``sharding`` is
                None; any other value leaves the sharding unset and yields torch tensors.

        Raises:
            NotImplementedError: If more than one JAX process is running.
            ValueError: If the dataset is smaller than one local batch.
        """
        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")

        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches
        # Kept so __iter__ can advance the sampler's epoch between passes.
        self._sampler = sampler

        mp_context = None
        if num_workers > 0:
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        return self._data_loader

    def __iter__(self):
        """Yield batches, looping over the dataset until ``num_batches`` is reached."""
        num_items = 0
        epoch = 0
        # DistributedSampler derives its shuffle order from the epoch number, so
        # without set_epoch every pass would replay the same order.
        set_epoch = getattr(self._sampler, "set_epoch", None)
        while True:
            if callable(set_epoch):
                set_epoch(epoch)
            epoch += 1
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)


def _collate_fn(items):
    """Collate the batch elements into batched numpy arrays."""
    # Make sure to convert to numpy arrays before stacking since some of the incoming elements
    # may be JAX arrays.
    return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)


def _worker_init_fn(worker_id: int) -> None:
    """Tell JAX inside the worker process not to preallocate the GPU memory."""
    # NOTE: This is called after jax is imported inside the worker process. This
    # means that this approach will not work for selecting the backend.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


class DataLoaderImpl(DataLoader):
    """Pairs a TorchDataLoader with its data config and yields (observation, actions)."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        return self._data_config

    def __iter__(self):
        for batch in self._data_loader:
            yield _model.Observation.from_dict(batch), batch["actions"]