Instructions to use StrongRoboticsLab/pi05-so100-diverse with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use StrongRoboticsLab/pi05-so100-diverse with LeRobot:
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python | |
| # Copyright 2025 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tests for dataset tools utilities.""" | |
| from unittest.mock import patch | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from lerobot.datasets.dataset_tools import ( | |
| add_features, | |
| delete_episodes, | |
| merge_datasets, | |
| modify_features, | |
| modify_tasks, | |
| remove_feature, | |
| split_dataset, | |
| ) | |
| from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset | |
| def sample_dataset(tmp_path, empty_lerobot_dataset_factory): | |
| """Create a sample dataset for testing.""" | |
| features = { | |
| "action": {"dtype": "float32", "shape": (6,), "names": None}, | |
| "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, | |
| "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, | |
| } | |
| dataset = empty_lerobot_dataset_factory( | |
| root=tmp_path / "test_dataset", | |
| features=features, | |
| ) | |
| for ep_idx in range(5): | |
| for _ in range(10): | |
| frame = { | |
| "action": np.random.randn(6).astype(np.float32), | |
| "observation.state": np.random.randn(4).astype(np.float32), | |
| "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), | |
| "task": f"task_{ep_idx % 2}", | |
| } | |
| dataset.add_frame(frame) | |
| dataset.save_episode() | |
| dataset.finalize() | |
| return dataset | |
| def test_delete_single_episode(sample_dataset, tmp_path): | |
| """Test deleting a single episode.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[2], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.total_episodes == 4 | |
| assert new_dataset.meta.total_frames == 40 | |
| episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} | |
| assert episode_indices == {0, 1, 2, 3} | |
| assert len(new_dataset) == 40 | |
| def test_delete_multiple_episodes(sample_dataset, tmp_path): | |
| """Test deleting multiple episodes.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[1, 3], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.total_episodes == 3 | |
| assert new_dataset.meta.total_frames == 30 | |
| episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} | |
| assert episode_indices == {0, 1, 2} | |
| def test_delete_invalid_episodes(sample_dataset, tmp_path): | |
| """Test error handling for invalid episode indices.""" | |
| with pytest.raises(ValueError, match="Invalid episode indices"): | |
| delete_episodes( | |
| sample_dataset, | |
| episode_indices=[10, 20], | |
| output_dir=tmp_path / "filtered", | |
| ) | |
| def test_delete_all_episodes(sample_dataset, tmp_path): | |
| """Test error when trying to delete all episodes.""" | |
| with pytest.raises(ValueError, match="Cannot delete all episodes"): | |
| delete_episodes( | |
| sample_dataset, | |
| episode_indices=list(range(5)), | |
| output_dir=tmp_path / "filtered", | |
| ) | |
| def test_delete_empty_list(sample_dataset, tmp_path): | |
| """Test error when no episodes specified.""" | |
| with pytest.raises(ValueError, match="No episodes to delete"): | |
| delete_episodes( | |
| sample_dataset, | |
| episode_indices=[], | |
| output_dir=tmp_path / "filtered", | |
| ) | |
| def test_split_by_episodes(sample_dataset, tmp_path): | |
| """Test splitting dataset by specific episode indices.""" | |
| splits = { | |
| "train": [0, 1, 2], | |
| "val": [3, 4], | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| if "train" in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_train") | |
| elif "val" in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_val") | |
| return str(kwargs.get("local_dir", tmp_path)) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| result = split_dataset( | |
| sample_dataset, | |
| splits=splits, | |
| output_dir=tmp_path, | |
| ) | |
| assert set(result.keys()) == {"train", "val"} | |
| assert result["train"].meta.total_episodes == 3 | |
| assert result["train"].meta.total_frames == 30 | |
| assert result["val"].meta.total_episodes == 2 | |
| assert result["val"].meta.total_frames == 20 | |
| train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]} | |
| assert train_episodes == {0, 1, 2} | |
| val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]} | |
| assert val_episodes == {0, 1} | |
| def test_split_by_fractions(sample_dataset, tmp_path): | |
| """Test splitting dataset by fractions.""" | |
| splits = { | |
| "train": 0.6, | |
| "val": 0.4, | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| for split_name in splits: | |
| if split_name in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") | |
| return str(kwargs.get("local_dir", tmp_path)) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| result = split_dataset( | |
| sample_dataset, | |
| splits=splits, | |
| output_dir=tmp_path, | |
| ) | |
| assert result["train"].meta.total_episodes == 3 | |
| assert result["val"].meta.total_episodes == 2 | |
| def test_split_overlapping_episodes(sample_dataset, tmp_path): | |
| """Test error when episodes appear in multiple splits.""" | |
| splits = { | |
| "train": [0, 1, 2], | |
| "val": [2, 3, 4], | |
| } | |
| with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"): | |
| split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) | |
| def test_split_invalid_fractions(sample_dataset, tmp_path): | |
| """Test error when fractions sum to more than 1.""" | |
| splits = { | |
| "train": 0.7, | |
| "val": 0.5, | |
| } | |
| with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"): | |
| split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) | |
| def test_split_empty(sample_dataset, tmp_path): | |
| """Test error with empty splits.""" | |
| with pytest.raises(ValueError, match="No splits provided"): | |
| split_dataset(sample_dataset, splits={}, output_dir=tmp_path) | |
| def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): | |
| """Test merging two datasets.""" | |
| features = { | |
| "action": {"dtype": "float32", "shape": (6,), "names": None}, | |
| "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, | |
| "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, | |
| } | |
| dataset2 = empty_lerobot_dataset_factory( | |
| root=tmp_path / "test_dataset2", | |
| features=features, | |
| ) | |
| for ep_idx in range(3): | |
| for _ in range(10): | |
| frame = { | |
| "action": np.random.randn(6).astype(np.float32), | |
| "observation.state": np.random.randn(4).astype(np.float32), | |
| "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), | |
| "task": f"task_{ep_idx % 2}", | |
| } | |
| dataset2.add_frame(frame) | |
| dataset2.save_episode() | |
| dataset2.finalize() | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") | |
| merged = merge_datasets( | |
| [sample_dataset, dataset2], | |
| output_repo_id="merged_dataset", | |
| output_dir=tmp_path / "merged_dataset", | |
| ) | |
| assert merged.meta.total_episodes == 8 # 5 + 3 | |
| assert merged.meta.total_frames == 80 # 50 + 30 | |
| episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]}) | |
| assert episode_indices == list(range(8)) | |
| def test_merge_empty_list(tmp_path): | |
| """Test error when merging empty list.""" | |
| with pytest.raises(ValueError, match="No datasets to merge"): | |
| merge_datasets([], output_repo_id="merged", output_dir=tmp_path) | |
| def test_add_features_with_values(sample_dataset, tmp_path): | |
| """Test adding a feature with pre-computed values.""" | |
| num_frames = sample_dataset.meta.total_frames | |
| reward_values = np.random.randn(num_frames, 1).astype(np.float32) | |
| feature_info = { | |
| "dtype": "float32", | |
| "shape": (1,), | |
| "names": None, | |
| } | |
| features = { | |
| "reward": (reward_values, feature_info), | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "with_reward") | |
| new_dataset = add_features( | |
| dataset=sample_dataset, | |
| features=features, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| assert "reward" in new_dataset.meta.features | |
| assert new_dataset.meta.features["reward"] == feature_info | |
| assert len(new_dataset) == num_frames | |
| sample_item = new_dataset[0] | |
| assert "reward" in sample_item | |
| assert isinstance(sample_item["reward"], torch.Tensor) | |
| def test_add_features_with_callable(sample_dataset, tmp_path): | |
| """Test adding a feature with a callable.""" | |
| def compute_reward(frame_dict, episode_idx, frame_idx): | |
| return float(episode_idx * 10 + frame_idx) | |
| feature_info = { | |
| "dtype": "float32", | |
| "shape": (1,), | |
| "names": None, | |
| } | |
| features = { | |
| "reward": (compute_reward, feature_info), | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "with_reward") | |
| new_dataset = add_features( | |
| dataset=sample_dataset, | |
| features=features, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| assert "reward" in new_dataset.meta.features | |
| items = [new_dataset[i] for i in range(10)] | |
| first_episode_items = [item for item in items if item["episode_index"] == 0] | |
| assert len(first_episode_items) == 10 | |
| first_frame = first_episode_items[0] | |
| assert first_frame["frame_index"] == 0 | |
| assert float(first_frame["reward"]) == 0.0 | |
| def test_add_existing_feature(sample_dataset, tmp_path): | |
| """Test error when adding an existing feature.""" | |
| feature_info = {"dtype": "float32", "shape": (1,)} | |
| features = { | |
| "action": (np.zeros(50), feature_info), | |
| } | |
| with pytest.raises(ValueError, match="Feature 'action' already exists"): | |
| add_features( | |
| dataset=sample_dataset, | |
| features=features, | |
| output_dir=tmp_path / "modified", | |
| ) | |
| def test_add_feature_invalid_info(sample_dataset, tmp_path): | |
| """Test error with invalid feature info.""" | |
| with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"): | |
| add_features( | |
| dataset=sample_dataset, | |
| features={ | |
| "reward": (np.zeros(50), {"dtype": "float32"}), | |
| }, | |
| output_dir=tmp_path / "modified", | |
| ) | |
| def test_modify_features_add_and_remove(sample_dataset, tmp_path): | |
| """Test modifying features by adding and removing simultaneously.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "modified") | |
| # First add a feature we'll later remove | |
| dataset_with_reward = add_features( | |
| sample_dataset, | |
| features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| # Now use modify_features to add "success" and remove "reward" in one pass | |
| modified_dataset = modify_features( | |
| dataset_with_reward, | |
| add_features={ | |
| "success": (np.random.randn(50, 1).astype(np.float32), feature_info), | |
| }, | |
| remove_features="reward", | |
| output_dir=tmp_path / "modified", | |
| ) | |
| assert "success" in modified_dataset.meta.features | |
| assert "reward" not in modified_dataset.meta.features | |
| assert len(modified_dataset) == 50 | |
| def test_modify_features_only_add(sample_dataset, tmp_path): | |
| """Test that modify_features works with only add_features.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "modified") | |
| modified_dataset = modify_features( | |
| sample_dataset, | |
| add_features={ | |
| "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), | |
| }, | |
| output_dir=tmp_path / "modified", | |
| ) | |
| assert "reward" in modified_dataset.meta.features | |
| assert len(modified_dataset) == 50 | |
| def test_modify_features_only_remove(sample_dataset, tmp_path): | |
| """Test that modify_features works with only remove_features.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) | |
| dataset_with_reward = add_features( | |
| sample_dataset, | |
| features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)}, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| modified_dataset = modify_features( | |
| dataset_with_reward, | |
| remove_features="reward", | |
| output_dir=tmp_path / "modified", | |
| ) | |
| assert "reward" not in modified_dataset.meta.features | |
| def test_modify_features_no_changes(sample_dataset, tmp_path): | |
| """Test error when modify_features is called with no changes.""" | |
| with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"): | |
| modify_features( | |
| sample_dataset, | |
| output_dir=tmp_path / "modified", | |
| ) | |
| def test_remove_single_feature(sample_dataset, tmp_path): | |
| """Test removing a single feature.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| features = { | |
| "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) | |
| dataset_with_reward = add_features( | |
| dataset=sample_dataset, | |
| features=features, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| dataset_without_reward = remove_feature( | |
| dataset_with_reward, | |
| feature_names="reward", | |
| output_dir=tmp_path / "without_reward", | |
| ) | |
| assert "reward" not in dataset_without_reward.meta.features | |
| sample_item = dataset_without_reward[0] | |
| assert "reward" not in sample_item | |
| def test_remove_multiple_features(sample_dataset, tmp_path): | |
| """Test removing multiple features at once.""" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) | |
| dataset = sample_dataset | |
| features = {} | |
| for feature_name in ["reward", "success"]: | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| features[feature_name] = ( | |
| np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), | |
| feature_info, | |
| ) | |
| dataset_with_features = add_features( | |
| dataset, features=features, output_dir=tmp_path / "with_features" | |
| ) | |
| dataset_clean = remove_feature( | |
| dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean" | |
| ) | |
| assert "reward" not in dataset_clean.meta.features | |
| assert "success" not in dataset_clean.meta.features | |
| def test_remove_nonexistent_feature(sample_dataset, tmp_path): | |
| """Test error when removing non-existent feature.""" | |
| with pytest.raises(ValueError, match="Feature 'nonexistent' not found"): | |
| remove_feature( | |
| sample_dataset, | |
| feature_names="nonexistent", | |
| output_dir=tmp_path / "modified", | |
| ) | |
| def test_remove_required_feature(sample_dataset, tmp_path): | |
| """Test error when trying to remove required features.""" | |
| with pytest.raises(ValueError, match="Cannot remove required features"): | |
| remove_feature( | |
| sample_dataset, | |
| feature_names="timestamp", | |
| output_dir=tmp_path / "modified", | |
| ) | |
| def test_remove_camera_feature(sample_dataset, tmp_path): | |
| """Test removing a camera feature.""" | |
| camera_keys = sample_dataset.meta.camera_keys | |
| if not camera_keys: | |
| pytest.skip("No camera keys in dataset") | |
| camera_to_remove = camera_keys[0] | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "without_camera") | |
| dataset_without_camera = remove_feature( | |
| sample_dataset, | |
| feature_names=camera_to_remove, | |
| output_dir=tmp_path / "without_camera", | |
| ) | |
| assert camera_to_remove not in dataset_without_camera.meta.features | |
| assert camera_to_remove not in dataset_without_camera.meta.camera_keys | |
| sample_item = dataset_without_camera[0] | |
| assert camera_to_remove not in sample_item | |
| def test_complex_workflow_integration(sample_dataset, tmp_path): | |
| """Test a complex workflow combining multiple operations.""" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) | |
| dataset = add_features( | |
| sample_dataset, | |
| features={ | |
| "reward": ( | |
| np.random.randn(50, 1).astype(np.float32), | |
| {"dtype": "float32", "shape": (1,), "names": None}, | |
| ) | |
| }, | |
| output_dir=tmp_path / "step1", | |
| ) | |
| dataset = delete_episodes( | |
| dataset, | |
| episode_indices=[2], | |
| output_dir=tmp_path / "step2", | |
| ) | |
| splits = split_dataset( | |
| dataset, | |
| splits={"train": 0.75, "val": 0.25}, | |
| output_dir=tmp_path / "step3", | |
| ) | |
| merged = merge_datasets( | |
| list(splits.values()), | |
| output_repo_id="final_dataset", | |
| output_dir=tmp_path / "step4", | |
| ) | |
| assert merged.meta.total_episodes == 4 | |
| assert merged.meta.total_frames == 40 | |
| assert "reward" in merged.meta.features | |
| assert len(merged) == 40 | |
| sample_item = merged[0] | |
| assert "reward" in sample_item | |
| def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): | |
| """Test that deleting episodes preserves statistics correctly.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[2], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.stats is not None | |
| for feature in ["action", "observation.state"]: | |
| assert feature in new_dataset.meta.stats | |
| assert "mean" in new_dataset.meta.stats[feature] | |
| assert "std" in new_dataset.meta.stats[feature] | |
| def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): | |
| """Test that tasks are preserved correctly after deletion.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[0], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.tasks is not None | |
| assert len(new_dataset.meta.tasks) == 2 | |
| tasks_in_dataset = {str(item["task"]) for item in new_dataset} | |
| assert len(tasks_in_dataset) > 0 | |
| def test_split_three_ways(sample_dataset, tmp_path): | |
| """Test splitting dataset into three splits.""" | |
| splits = { | |
| "train": 0.6, | |
| "val": 0.2, | |
| "test": 0.2, | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| for split_name in splits: | |
| if split_name in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") | |
| return str(kwargs.get("local_dir", tmp_path)) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| result = split_dataset( | |
| sample_dataset, | |
| splits=splits, | |
| output_dir=tmp_path, | |
| ) | |
| assert set(result.keys()) == {"train", "val", "test"} | |
| assert result["train"].meta.total_episodes == 3 | |
| assert result["val"].meta.total_episodes == 1 | |
| assert result["test"].meta.total_episodes == 1 | |
| total_frames = sum(ds.meta.total_frames for ds in result.values()) | |
| assert total_frames == sample_dataset.meta.total_frames | |
| def test_split_preserves_stats(sample_dataset, tmp_path): | |
| """Test that statistics are preserved when splitting.""" | |
| splits = {"train": [0, 1, 2], "val": [3, 4]} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| for split_name in splits: | |
| if split_name in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") | |
| return str(kwargs.get("local_dir", tmp_path)) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| result = split_dataset( | |
| sample_dataset, | |
| splits=splits, | |
| output_dir=tmp_path, | |
| ) | |
| for split_ds in result.values(): | |
| assert split_ds.meta.stats is not None | |
| for feature in ["action", "observation.state"]: | |
| assert feature in split_ds.meta.stats | |
| assert "mean" in split_ds.meta.stats[feature] | |
| assert "std" in split_ds.meta.stats[feature] | |
| def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): | |
| """Test merging three datasets.""" | |
| features = { | |
| "action": {"dtype": "float32", "shape": (6,), "names": None}, | |
| "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, | |
| "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, | |
| } | |
| datasets = [sample_dataset] | |
| for i in range(2): | |
| dataset = empty_lerobot_dataset_factory( | |
| root=tmp_path / f"test_dataset{i + 2}", | |
| features=features, | |
| ) | |
| for ep_idx in range(2): | |
| for _ in range(10): | |
| frame = { | |
| "action": np.random.randn(6).astype(np.float32), | |
| "observation.state": np.random.randn(4).astype(np.float32), | |
| "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), | |
| "task": f"task_{ep_idx}", | |
| } | |
| dataset.add_frame(frame) | |
| dataset.save_episode() | |
| dataset.finalize() | |
| datasets.append(dataset) | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") | |
| merged = merge_datasets( | |
| datasets, | |
| output_repo_id="merged_dataset", | |
| output_dir=tmp_path / "merged_dataset", | |
| ) | |
| assert merged.meta.total_episodes == 9 | |
| assert merged.meta.total_frames == 90 | |
| def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory): | |
| """Test that statistics are computed for merged datasets.""" | |
| features = { | |
| "action": {"dtype": "float32", "shape": (6,), "names": None}, | |
| "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, | |
| "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, | |
| } | |
| dataset2 = empty_lerobot_dataset_factory( | |
| root=tmp_path / "test_dataset2", | |
| features=features, | |
| ) | |
| for ep_idx in range(3): | |
| for _ in range(10): | |
| frame = { | |
| "action": np.random.randn(6).astype(np.float32), | |
| "observation.state": np.random.randn(4).astype(np.float32), | |
| "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), | |
| "task": f"task_{ep_idx % 2}", | |
| } | |
| dataset2.add_frame(frame) | |
| dataset2.save_episode() | |
| dataset2.finalize() | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") | |
| merged = merge_datasets( | |
| [sample_dataset, dataset2], | |
| output_repo_id="merged_dataset", | |
| output_dir=tmp_path / "merged_dataset", | |
| ) | |
| assert merged.meta.stats is not None | |
| for feature in ["action", "observation.state"]: | |
| assert feature in merged.meta.stats | |
| assert "mean" in merged.meta.stats[feature] | |
| assert "std" in merged.meta.stats[feature] | |
| def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): | |
| """Test that adding a feature preserves existing stats.""" | |
| num_frames = sample_dataset.meta.total_frames | |
| reward_values = np.random.randn(num_frames, 1).astype(np.float32) | |
| feature_info = { | |
| "dtype": "float32", | |
| "shape": (1,), | |
| "names": None, | |
| } | |
| features = { | |
| "reward": (reward_values, feature_info), | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(tmp_path / "with_reward") | |
| new_dataset = add_features( | |
| dataset=sample_dataset, | |
| features=features, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| assert new_dataset.meta.stats is not None | |
| for feature in ["action", "observation.state"]: | |
| assert feature in new_dataset.meta.stats | |
| assert "mean" in new_dataset.meta.stats[feature] | |
| assert "std" in new_dataset.meta.stats[feature] | |
| def test_remove_feature_updates_stats(sample_dataset, tmp_path): | |
| """Test that removing a feature removes it from stats.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) | |
| dataset_with_reward = add_features( | |
| sample_dataset, | |
| features={ | |
| "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), | |
| }, | |
| output_dir=tmp_path / "with_reward", | |
| ) | |
| dataset_without_reward = remove_feature( | |
| dataset_with_reward, | |
| feature_names="reward", | |
| output_dir=tmp_path / "without_reward", | |
| ) | |
| if dataset_without_reward.meta.stats: | |
| assert "reward" not in dataset_without_reward.meta.stats | |
| def test_delete_consecutive_episodes(sample_dataset, tmp_path): | |
| """Test deleting consecutive episodes.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[1, 2, 3], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.total_episodes == 2 | |
| assert new_dataset.meta.total_frames == 20 | |
| episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) | |
| assert episode_indices == [0, 1] | |
| def test_delete_first_and_last_episodes(sample_dataset, tmp_path): | |
| """Test deleting first and last episodes.""" | |
| output_dir = tmp_path / "filtered" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| new_dataset = delete_episodes( | |
| sample_dataset, | |
| episode_indices=[0, 4], | |
| output_dir=output_dir, | |
| ) | |
| assert new_dataset.meta.total_episodes == 3 | |
| assert new_dataset.meta.total_frames == 30 | |
| episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) | |
| assert episode_indices == [0, 1, 2] | |
| def test_split_all_episodes_assigned(sample_dataset, tmp_path): | |
| """Test that all episodes can be explicitly assigned to splits.""" | |
| splits = { | |
| "split1": [0, 1], | |
| "split2": [2, 3], | |
| "split3": [4], | |
| } | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| for split_name in splits: | |
| if split_name in repo_id: | |
| return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") | |
| return str(kwargs.get("local_dir", tmp_path)) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| result = split_dataset( | |
| sample_dataset, | |
| splits=splits, | |
| output_dir=tmp_path, | |
| ) | |
| total_episodes = sum(ds.meta.total_episodes for ds in result.values()) | |
| assert total_episodes == sample_dataset.meta.total_episodes | |
| def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): | |
| """Test that modifying features preserves chunk_idx and file_idx from source dataset.""" | |
| feature_info = {"dtype": "float32", "shape": (1,), "names": None} | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| def mock_snapshot(repo_id, **kwargs): | |
| return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1])) | |
| mock_snapshot_download.side_effect = mock_snapshot | |
| # First split the dataset to create a non-zero starting chunk/file structure | |
| splits = split_dataset( | |
| sample_dataset, | |
| splits={"train": [0, 1, 2], "val": [3, 4]}, | |
| output_dir=tmp_path / "splits", | |
| ) | |
| train_dataset = splits["train"] | |
| # Get original chunk/file indices from first episode | |
| if train_dataset.meta.episodes is None: | |
| from lerobot.datasets.io_utils import load_episodes | |
| train_dataset.meta.episodes = load_episodes(train_dataset.meta.root) | |
| original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes] | |
| original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes] | |
| # Now add a feature to the split dataset | |
| modified_dataset = add_features( | |
| train_dataset, | |
| features={ | |
| "reward": ( | |
| np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32), | |
| feature_info, | |
| ), | |
| }, | |
| output_dir=tmp_path / "modified", | |
| ) | |
| # Check that chunk/file indices are preserved | |
| if modified_dataset.meta.episodes is None: | |
| from lerobot.datasets.io_utils import load_episodes | |
| modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root) | |
| new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes] | |
| new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes] | |
| assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved" | |
| assert new_file_indices == original_file_indices, "File indices should be preserved" | |
| assert "reward" in modified_dataset.meta.features | |
| def test_modify_tasks_single_task_for_all(sample_dataset): | |
| """Test setting a single task for all episodes.""" | |
| new_task = "Pick up the cube and place it" | |
| modified_dataset = modify_tasks(sample_dataset, new_task=new_task) | |
| # Verify all episodes have the new task | |
| assert len(modified_dataset.meta.tasks) == 1 | |
| assert new_task in modified_dataset.meta.tasks.index | |
| # Verify task_index is 0 for all frames (only one task) | |
| for i in range(len(modified_dataset)): | |
| item = modified_dataset[i] | |
| assert item["task_index"].item() == 0 | |
| assert item["task"] == new_task | |
| def test_modify_tasks_episode_specific(sample_dataset): | |
| """Test setting different tasks for specific episodes.""" | |
| episode_tasks = { | |
| 0: "Task A", | |
| 1: "Task B", | |
| 2: "Task A", | |
| 3: "Task C", | |
| 4: "Task B", | |
| } | |
| modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) | |
| # Verify correct number of unique tasks | |
| unique_tasks = set(episode_tasks.values()) | |
| assert len(modified_dataset.meta.tasks) == len(unique_tasks) | |
| # Verify each episode has the correct task | |
| for ep_idx, expected_task in episode_tasks.items(): | |
| ep_data = modified_dataset.meta.episodes[ep_idx] | |
| assert ep_data["tasks"][0] == expected_task | |
| def test_modify_tasks_default_with_overrides(sample_dataset): | |
| """Test setting a default task with specific overrides.""" | |
| default_task = "Default task" | |
| override_task = "Special task" | |
| episode_tasks = {2: override_task, 4: override_task} | |
| modified_dataset = modify_tasks( | |
| sample_dataset, | |
| new_task=default_task, | |
| episode_tasks=episode_tasks, | |
| ) | |
| # Verify correct number of unique tasks | |
| assert len(modified_dataset.meta.tasks) == 2 | |
| assert default_task in modified_dataset.meta.tasks.index | |
| assert override_task in modified_dataset.meta.tasks.index | |
| # Verify episodes have correct tasks | |
| for ep_idx in range(5): | |
| ep_data = modified_dataset.meta.episodes[ep_idx] | |
| if ep_idx in episode_tasks: | |
| assert ep_data["tasks"][0] == override_task | |
| else: | |
| assert ep_data["tasks"][0] == default_task | |
| def test_modify_tasks_no_task_specified(sample_dataset): | |
| """Test error when no task is specified.""" | |
| with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"): | |
| modify_tasks(sample_dataset) | |
| def test_modify_tasks_invalid_episode_indices(sample_dataset): | |
| """Test error with invalid episode indices.""" | |
| with pytest.raises(ValueError, match="Invalid episode indices"): | |
| modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"}) | |
| def test_modify_tasks_updates_info_json(sample_dataset): | |
| """Test that total_tasks is updated in info.json.""" | |
| episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"} | |
| modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) | |
| # Verify total_tasks is updated | |
| assert modified_dataset.meta.total_tasks == 3 | |
| def test_modify_tasks_preserves_other_metadata(sample_dataset): | |
| """Test that modifying tasks preserves other metadata.""" | |
| original_frames = sample_dataset.meta.total_frames | |
| original_episodes = sample_dataset.meta.total_episodes | |
| original_fps = sample_dataset.meta.fps | |
| modified_dataset = modify_tasks(sample_dataset, new_task="New task") | |
| # Verify other metadata is preserved | |
| assert modified_dataset.meta.total_frames == original_frames | |
| assert modified_dataset.meta.total_episodes == original_episodes | |
| assert modified_dataset.meta.fps == original_fps | |
| def test_modify_tasks_task_index_correct(sample_dataset): | |
| """Test that task_index values are correct in data files.""" | |
| # Create tasks that will have predictable indices (sorted alphabetically) | |
| episode_tasks = { | |
| 0: "Alpha task", # Will be index 0 | |
| 1: "Beta task", # Will be index 1 | |
| 2: "Alpha task", # Will be index 0 | |
| 3: "Gamma task", # Will be index 2 | |
| 4: "Beta task", # Will be index 1 | |
| } | |
| modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) | |
| # Verify task indices are correct | |
| task_to_expected_idx = { | |
| "Alpha task": 0, | |
| "Beta task": 1, | |
| "Gamma task": 2, | |
| } | |
| for i in range(len(modified_dataset)): | |
| item = modified_dataset[i] | |
| ep_idx = item["episode_index"].item() | |
| expected_task = episode_tasks[ep_idx] | |
| expected_idx = task_to_expected_idx[expected_task] | |
| assert item["task_index"].item() == expected_idx | |
| assert item["task"] == expected_task | |
| def test_modify_tasks_in_place(sample_dataset): | |
| """Test that modify_tasks modifies the dataset in-place.""" | |
| original_root = sample_dataset.root | |
| modified_dataset = modify_tasks(sample_dataset, new_task="New task") | |
| # Verify same instance is returned and root is unchanged | |
| assert modified_dataset is sample_dataset | |
| assert modified_dataset.root == original_root | |
| def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): | |
| """Test that original tasks are kept when using episode_tasks without new_task.""" | |
| from lerobot.datasets.io_utils import load_episodes | |
| # Ensure episodes metadata is loaded | |
| if sample_dataset.meta.episodes is None: | |
| sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root) | |
| # Get original tasks for episodes not being overridden | |
| original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0] | |
| original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0] | |
| # Only override episodes 2, 3, 4 | |
| episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"} | |
| modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) | |
| # Verify original tasks are kept for episodes 0 and 1 | |
| assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0 | |
| assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1 | |
| # Verify new tasks for overridden episodes | |
| assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A" | |
| assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B" | |
| assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A" | |
| def test_convert_image_to_video_dataset(tmp_path): | |
| """Test converting lerobot/pusht_image dataset to video format.""" | |
| from lerobot.datasets.lerobot_dataset import LeRobotDataset | |
| # Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed) | |
| source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1]) | |
| output_dir = tmp_path / "pusht_video" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| # Verify source dataset has images, not videos | |
| assert len(source_dataset.meta.video_keys) == 0 | |
| assert "observation.image" in source_dataset.meta.features | |
| # Convert to video dataset (only first 2 episodes for speed) | |
| video_dataset = convert_image_to_video_dataset( | |
| dataset=source_dataset, | |
| output_dir=output_dir, | |
| repo_id="lerobot/pusht_video", | |
| vcodec="libsvtav1", | |
| pix_fmt="yuv420p", | |
| g=2, | |
| crf=30, | |
| episode_indices=[0, 1], | |
| num_workers=2, | |
| ) | |
| # Verify new dataset has videos | |
| assert len(video_dataset.meta.video_keys) > 0 | |
| assert "observation.image" in video_dataset.meta.video_keys | |
| # Verify correct number of episodes and frames (2 episodes) | |
| assert video_dataset.meta.total_episodes == 2 | |
| # Compare against the actual number of frames in the loaded episodes, not metadata total | |
| assert len(video_dataset) == len(source_dataset) | |
| # Verify video files exist | |
| for ep_idx in range(video_dataset.meta.total_episodes): | |
| for video_key in video_dataset.meta.video_keys: | |
| video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key) | |
| assert video_path.exists(), f"Video file should exist: {video_path}" | |
| # Verify we can load the dataset and access it | |
| assert len(video_dataset) == video_dataset.meta.total_frames | |
| # Test that we can actually get an item from the video dataset | |
| item = video_dataset[0] | |
| assert "observation.image" in item | |
| assert "action" in item | |
| # Cleanup | |
| import shutil | |
| if output_dir.exists(): | |
| shutil.rmtree(output_dir) | |
| def test_convert_image_to_video_dataset_subset_episodes(tmp_path): | |
| """Test converting only specific episodes from lerobot/pusht_image to video format.""" | |
| from lerobot.datasets.lerobot_dataset import LeRobotDataset | |
| # Load the actual lerobot/pusht_image dataset (only first 3 episodes) | |
| source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2]) | |
| output_dir = tmp_path / "pusht_video_subset" | |
| with ( | |
| patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, | |
| patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, | |
| ): | |
| mock_get_safe_version.return_value = "v3.0" | |
| mock_snapshot_download.return_value = str(output_dir) | |
| # Convert only episode 0 to video (subset of loaded episodes) | |
| episode_indices = [0] | |
| video_dataset = convert_image_to_video_dataset( | |
| dataset=source_dataset, | |
| output_dir=output_dir, | |
| repo_id="lerobot/pusht_video_subset", | |
| episode_indices=episode_indices, | |
| num_workers=2, | |
| ) | |
| # Verify correct number of episodes | |
| assert video_dataset.meta.total_episodes == len(episode_indices) | |
| # Verify video files exist for selected episodes | |
| assert len(video_dataset.meta.video_keys) > 0 | |
| assert "observation.image" in video_dataset.meta.video_keys | |
| # Cleanup | |
| import shutil | |
| if output_dir.exists(): | |
| shutil.rmtree(output_dir) | |