Instructions for using StrongRoboticsLab/pi05-so100-diverse with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
  - LeRobot — how to use StrongRoboticsLab/pi05-so100-diverse with LeRobot:
- Notebooks
  - Google Colab
  - Kaggle
| #!/usr/bin/env python | |
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from unittest.mock import patch | |
| import numpy as np | |
| import pytest | |
| from lerobot.datasets.compute_stats import ( | |
| RunningQuantileStats, | |
| _assert_type_and_shape, | |
| aggregate_feature_stats, | |
| aggregate_stats, | |
| compute_episode_stats, | |
| estimate_num_samples, | |
| get_feature_stats, | |
| sample_images, | |
| sample_indices, | |
| ) | |
| from lerobot.utils.constants import OBS_IMAGE, OBS_STATE | |
def mock_load_image_as_numpy(path, dtype, channel_first):
    """Stand-in for ``load_image_as_numpy`` that returns an all-ones 32x32, 3-channel image.

    Args:
        path: Ignored; present only so the signature matches the patched function.
        dtype: NumPy dtype of the returned array.
        channel_first: If true, the result has shape (3, 32, 32); otherwise (32, 32, 3).
    """
    shape = (3, 32, 32) if channel_first else (32, 32, 3)
    return np.ones(shape, dtype=dtype)
@pytest.fixture
def sample_array():
    """Provide the 3x3 integer matrix ``[[1, 2, 3], [4, 5, 6], [7, 8, 9]]``.

    Registered as a pytest fixture because several tests take ``sample_array`` as a
    parameter. Without the decorator, pytest reports "fixture 'sample_array' not found"
    and those tests error out before running.
    """
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def test_estimate_num_samples():
    """Check ``estimate_num_samples`` against known (dataset length, sample count) pairs."""
    cases = [
        (1, 1),
        (10, 10),
        (100, 100),
        (200, 100),
        (1000, 177),
        (2000, 299),
        (5000, 594),
        (10_000, 1000),
        (20_000, 1681),
        (50_000, 3343),
        (500_000, 10_000),
    ]
    for dataset_len, expected_samples in cases:
        assert estimate_num_samples(dataset_len) == expected_samples
def test_sample_indices():
    """Sampled indices are non-empty, run from the first to the last element, and match the estimate."""
    n = 10
    idx = sample_indices(n)
    assert len(idx) > 0
    first, last = idx[0], idx[-1]
    assert first == 0
    assert last == n - 1
    assert len(idx) == estimate_num_samples(n)
@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
    """Check that ``sample_images`` returns a uint8 array of shape (N, 3, 32, 32).

    Image loading is patched with ``mock_load_image_as_numpy``, so no files are read.
    The ``@patch`` decorator is what supplies the ``mock_load`` argument; without it,
    pytest looks for a fixture named ``mock_load`` and the test errors out.
    """
    image_paths = [f"image_{i}.jpg" for i in range(100)]
    images = sample_images(image_paths)
    assert isinstance(images, np.ndarray)
    assert images.shape[1:] == (3, 32, 32)
    assert images.dtype == np.uint8
    assert len(images) == estimate_num_samples(100)
def test_get_feature_stats_images():
    """Per-channel stats over a random (100, 3, 32, 32) batch share one shape and report count 100."""
    batch = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(batch, axis=(0, 2, 3), keepdims=True)
    for required in ("min", "max", "mean", "std", "count"):
        assert required in stats
    np.testing.assert_equal(stats["count"], np.array([100]))
    reference_shape = stats["min"].shape
    assert all(stats[k].shape == reference_shape for k in ("max", "mean", "std"))
def test_get_feature_stats_axis_0_keepdims(sample_array):
    """Column-wise stats (axis 0, keepdims) of the 3x3 fixture matrix."""
    result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
    column_std = 2.44948974  # population std of (1, 4, 7), i.e. sqrt(6)
    expected = {
        "min": np.array([[1, 2, 3]]),
        "max": np.array([[7, 8, 9]]),
        "mean": np.array([[4.0, 5.0, 6.0]]),
        "std": np.array([[column_std] * 3]),
        "count": np.array([3]),
    }
    for key, want in expected.items():
        np.testing.assert_allclose(result[key], want)
def test_get_feature_stats_axis_1(sample_array):
    """Row-wise stats (axis 1, no keepdims); extra keys such as quantiles are tolerated."""
    result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
    row_std = 0.81649658  # population std of (1, 2, 3), i.e. sqrt(2/3)
    expected = {
        "min": np.array([1, 4, 7]),
        "max": np.array([3, 6, 9]),
        "mean": np.array([2.0, 5.0, 8.0]),
        "std": np.array([row_std] * 3),
        "count": np.array([3]),
    }
    # The result may carry additional keys (e.g. quantiles); only the basic ones are checked.
    assert set(expected).issubset(result)
    for key, want in expected.items():
        np.testing.assert_allclose(result[key], want)
def test_get_feature_stats_no_axis(sample_array):
    """Stats over the whole 3x3 fixture (axis=None); extra keys such as quantiles are tolerated."""
    result = get_feature_stats(sample_array, axis=None, keepdims=False)
    expected = {
        "min": np.array(1),
        "max": np.array(9),
        "mean": np.array(5.0),
        "std": np.array(2.5819889),  # population std of 1..9, i.e. sqrt(20/3)
        "count": np.array([3]),
    }
    # The result may carry additional keys (e.g. quantiles); only the basic ones are checked.
    assert set(expected).issubset(result)
    for key, want in expected.items():
        np.testing.assert_allclose(result[key], want)
def test_get_feature_stats_empty_array():
    """An empty input array must be rejected with ``ValueError``."""
    empty = np.array([])
    with pytest.raises(ValueError):
        get_feature_stats(empty, axis=(0,), keepdims=True)
def test_get_feature_stats_single_value():
    """A 1x1 input gives min = max = mean = the value, zero std, and count 1."""
    result = get_feature_stats(np.array([[1337]]), axis=None, keepdims=True)
    checks = [
        ("min", np.array(1337)),
        ("max", np.array(1337)),
        ("mean", np.array(1337.0)),
        ("std", np.array(0.0)),
        ("count", np.array([1])),
    ]
    for key, want in checks:
        np.testing.assert_equal(result[key], want)
def test_compute_episode_stats():
    """Episode stats for one image feature (with a patched loader) and one numeric feature."""
    n_frames = 100
    episode_data = {
        OBS_IMAGE: [f"image_{i}.jpg" for i in range(n_frames)],
        OBS_STATE: np.random.rand(n_frames, 10),
    }
    features = {
        OBS_IMAGE: {"dtype": "image"},
        OBS_STATE: {"dtype": "numeric"},
    }
    loader_target = "lerobot.datasets.compute_stats.load_image_as_numpy"
    with patch(loader_target, side_effect=mock_load_image_as_numpy):
        stats = compute_episode_stats(episode_data, features)

    assert OBS_IMAGE in stats
    assert OBS_STATE in stats
    for key in (OBS_IMAGE, OBS_STATE):
        assert stats[key]["count"].item() == n_frames
    # Image stats are kept per channel.
    assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
def test_assert_type_and_shape_valid():
    """Well-formed stats (all numpy arrays, single-element count) pass validation without raising."""
    feature_stats = {
        "min": np.array([1.0]),
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    _assert_type_and_shape([{"feature1": feature_stats}])
def test_assert_type_and_shape_invalid_type():
    """A plain list where a numpy array is expected is rejected with ``ValueError``."""
    bad_feature = {
        "min": [1.0],  # deliberately a list, not an np.ndarray
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
        _assert_type_and_shape([{"feature1": bad_feature}])
def test_assert_type_and_shape_invalid_shape():
    """A ``count`` array holding more than one element is rejected with ``ValueError``."""
    episode_stats = [{"feature1": {"count": np.array([1, 2])}}]  # two elements: wrong shape
    with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
        _assert_type_and_shape(episode_stats)
def test_aggregate_feature_stats():
    """Merge two single-sample stats dicts into pooled min/max/mean/std/count."""
    first = {
        "min": np.array([1.0]),
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    second = {
        "min": np.array([2.0]),
        "max": np.array([12.0]),
        "mean": np.array([6.0]),
        "std": np.array([2.5]),
        "count": np.array([1]),
    }
    merged = aggregate_feature_stats([first, second])

    np.testing.assert_allclose(merged["min"], np.array([1.0]))
    np.testing.assert_allclose(merged["max"], np.array([12.0]))
    np.testing.assert_allclose(merged["mean"], np.array([5.5]))
    # Pooled variance around the combined mean 5.5:
    # ((2.0**2 + 0.5**2) + (2.5**2 + 0.5**2)) / 2 = 5.375, and sqrt(5.375) ~= 2.318405.
    np.testing.assert_allclose(merged["std"], np.array([2.318405]), atol=1e-6)
    np.testing.assert_allclose(merged["count"], np.array([2]))
def test_aggregate_stats():
    """Aggregate two episodes' stats for image, state, and episode-specific features.

    ``OBS_IMAGE`` (per channel) and ``OBS_STATE`` appear in both episodes and are
    merged into pooled min/max/mean/std/count. ``extra_key_0`` and ``extra_key_1``
    each appear in only one episode and are expected to come back with that
    episode's stats unchanged.
    """

    def _to_numpy_stats(fkey, stats):
        """Convert one feature's stats dict to numpy arrays in place.

        ``count`` becomes int64 and every other stat float32. Non-count image stats
        are reshaped to (3, 1, 1) for per-channel normalization; everything else
        becomes shape (1,). Used for both the inputs and the expected values, which
        previously repeated this loop verbatim.
        """
        for k in stats:
            stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
            if fkey == OBS_IMAGE and k != "count":
                stats[k] = stats[k].reshape(3, 1, 1)  # for normalization on image channels
            else:
                stats[k] = stats[k].reshape(1)

    all_stats = [
        {
            OBS_IMAGE: {
                "min": [1, 2, 3],
                "max": [10, 20, 30],
                "mean": [5.5, 10.5, 15.5],
                "std": [2.87, 5.87, 8.87],
                "count": 10,
            },
            OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
            "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
        },
        {
            OBS_IMAGE: {
                "min": [2, 1, 0],
                "max": [15, 10, 5],
                "mean": [8.5, 5.5, 2.5],
                "std": [3.42, 2.42, 1.42],
                "count": 15,
            },
            OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
            "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
        },
    ]

    expected_agg_stats = {
        OBS_IMAGE: {
            "min": [1, 1, 0],
            "max": [15, 20, 30],
            "mean": [7.3, 7.5, 7.7],
            "std": [3.5317, 4.8267, 8.5581],
            "count": 25,
        },
        OBS_STATE: {
            "min": 1,
            "max": 15,
            "mean": 7.3,
            "std": 3.5317,
            "count": 25,
        },
        "extra_key_0": {
            "min": 5,
            "max": 25,
            "mean": 15.0,
            "std": 6.0,
            "count": 6,
        },
        "extra_key_1": {
            "min": 0,
            "max": 20,
            "mean": 10.0,
            "std": 5.0,
            "count": 5,
        },
    }

    for ep_stats in all_stats:
        for fkey, stats in ep_stats.items():
            _to_numpy_stats(fkey, stats)
    for fkey, stats in expected_agg_stats.items():
        _to_numpy_stats(fkey, stats)

    results = aggregate_stats(all_stats)

    for fkey in expected_agg_stats:
        np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
        np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
        np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
        # Expected std values are rounded to 4 decimals, hence the 1e-4 tolerances.
        np.testing.assert_allclose(
            results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
        )
        np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
def test_running_quantile_stats_initialization():
    """A new ``RunningQuantileStats`` is empty and uses 5000 quantile bins unless told otherwise."""
    default_tracker = RunningQuantileStats()
    assert default_tracker._count == 0
    assert default_tracker._mean is None
    assert default_tracker._num_quantile_bins == 5000

    # The bin count can be overridden through the constructor.
    custom_tracker = RunningQuantileStats(num_quantile_bins=1000)
    assert custom_tracker._num_quantile_bins == 1000
def test_running_quantile_stats_single_batch_update():
    """One ``update`` with 100x3 samples sets the count and mean, plus one histogram per dimension."""
    np.random.seed(42)
    batch = np.random.normal(0, 1, (100, 3))
    tracker = RunningQuantileStats()
    tracker.update(batch)

    n_rows, n_dims = batch.shape
    assert tracker._count == n_rows
    assert tracker._mean.shape == (n_dims,)
    assert len(tracker._histograms) == n_dims
    assert len(tracker._bin_edges) == n_dims
    # The running mean must agree with numpy's column mean.
    np.testing.assert_allclose(tracker._mean, batch.mean(axis=0), atol=1e-10)
def test_running_quantile_stats_multiple_batch_updates():
    """Two successive updates accumulate both the count and the running mean across all rows."""
    np.random.seed(42)
    first = np.random.normal(0, 1, (100, 2))
    second = np.random.normal(1, 1, (150, 2))
    tracker = RunningQuantileStats()
    for chunk in (first, second):
        tracker.update(chunk)

    assert tracker._count == len(first) + len(second)
    pooled_mean = np.concatenate([first, second]).mean(axis=0)
    np.testing.assert_allclose(tracker._mean, pooled_mean, atol=1e-10)
def test_running_quantile_stats_get_statistics_basic():
    """The basic stats from ``get_statistics`` match numpy's mean and population std on the same data."""
    np.random.seed(42)
    samples = np.random.normal(0, 1, (100, 2))
    tracker = RunningQuantileStats()
    tracker.update(samples)
    stats = tracker.get_statistics()

    assert {"min", "max", "mean", "std", "count"}.issubset(stats)
    np.testing.assert_allclose(stats["mean"], samples.mean(axis=0), atol=1e-10)
    np.testing.assert_allclose(stats["std"], samples.std(axis=0), atol=1e-6)
    np.testing.assert_equal(stats["count"], np.array([len(samples)]))
def test_running_quantile_stats_get_statistics_with_quantiles():
    """``get_statistics`` reports every default quantile, per dimension, in non-decreasing order."""
    np.random.seed(42)
    samples = np.random.normal(0, 1, (1000, 2))
    tracker = RunningQuantileStats()
    tracker.update(samples)
    stats = tracker.get_statistics()

    required = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert required.issubset(stats)

    from lerobot.datasets.compute_stats import DEFAULT_QUANTILES

    # Expected key for each quantile, e.g. 0.01 -> "q01", 0.5 -> "q50".
    q_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
    previous = None
    for q_key in q_keys:
        assert q_key in stats
        assert stats[q_key].shape == (2,)
        # Each quantile must be >= the previous one, dimension by dimension.
        if previous is not None:
            assert np.all(stats[previous] <= stats[q_key])
        previous = q_key
def test_running_quantile_stats_histogram_adjustment():
    """Bin edges are rebuilt to cover a wider range when later data exceeds the earlier min/max."""
    tracker = RunningQuantileStats()

    # First batch: narrow ranges, [0.0, 0.2] in dim 0 and [1.0, 1.2] in dim 1.
    tracker.update(np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]]))
    edges_before = [tracker._bin_edges[dim].copy() for dim in (0, 1)]

    # Second batch: pushes dim 0 up to 11.0 and dim 1 down to -11.0.
    tracker.update(np.array([[10.0, -10.0], [11.0, -11.0]]))

    for dim in (0, 1):
        assert not np.array_equal(edges_before[dim], tracker._bin_edges[dim])

    # The new edges must span the combined range of both batches in each dimension.
    for dim, (low, high) in enumerate([(0.0, 11.0), (-11.0, 1.2)]):
        edges = tracker._bin_edges[dim]
        assert edges[0] <= low
        assert edges[-1] >= high
def test_running_quantile_stats_insufficient_data_error():
    """``get_statistics`` raises ``ValueError`` until at least two vectors have been seen."""
    tracker = RunningQuantileStats()
    message = "Cannot compute statistics for less than 2 vectors"

    # With no data at all.
    with pytest.raises(ValueError, match=message):
        tracker.get_statistics()

    # A single vector is still not enough.
    tracker.update(np.array([[1.0]]))
    with pytest.raises(ValueError, match=message):
        tracker.get_statistics()
def test_running_quantile_stats_vector_length_consistency():
    """Updating with vectors of a different length than before raises ``ValueError``."""
    tracker = RunningQuantileStats()
    tracker.update(np.array([[1.0, 2.0], [3.0, 4.0]]))  # establishes vector length 2
    wider = np.array([[1.0, 2.0, 3.0]])  # vector length 3
    with pytest.raises(ValueError, match="The length of new vectors does not match"):
        tracker.update(wider)
def test_running_quantile_stats_reshape_handling():
    """3D input is treated as rows of its last dimension; an (N, 1) column is accepted as-is."""
    image_tracker = RunningQuantileStats()
    image_tracker.update(np.random.normal(0, 1, (10, 32, 32)))
    # Leading axes are merged: 10 * 32 rows of 32 values each.
    assert image_tracker._count == 10 * 32
    assert image_tracker._mean.shape == (32,)

    column_tracker = RunningQuantileStats()
    column_tracker.update(np.arange(1, 6).reshape(-1, 1))
    assert column_tracker._count == 5
    assert column_tracker._mean.shape == (1,)
def test_get_feature_stats_quantiles_enabled_by_default():
    """With no extra arguments, ``get_feature_stats`` returns the basic stats plus five quantiles."""
    samples = np.random.normal(0, 1, (100, 5))
    stats = get_feature_stats(samples, axis=0, keepdims=False)
    basic = {"min", "max", "mean", "std", "count"}
    quantiles = {"q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == basic | quantiles
def test_get_feature_stats_quantiles_with_vector_data():
    """Quantiles on 100x5 vector data have one value per dimension, with q01 < q99 everywhere."""
    np.random.seed(42)
    stats = get_feature_stats(np.random.normal(0, 1, (100, 5)), axis=0, keepdims=False)

    all_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == all_keys

    low, high = stats["q01"], stats["q99"]
    assert low.shape == (5,)
    assert high.shape == (5,)
    assert np.all(low < high)
def test_get_feature_stats_quantiles_with_image_data():
    """Per-channel quantiles for (batch, channel, height, width) data keep the (1, 3, 1, 1) layout."""
    np.random.seed(42)
    images = np.random.normal(0, 1, (50, 3, 32, 32))
    stats = get_feature_stats(images, axis=(0, 2, 3), keepdims=True)

    all_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    assert set(stats.keys()) == all_keys

    # Reducing over batch, height and width with keepdims leaves only the channel axis wider than 1.
    for q_key in ("q01", "q50", "q99"):
        assert stats[q_key].shape == (1, 3, 1, 1)
def test_get_feature_stats_fixed_quantiles():
    """The five fixed quantile keys are always present in ``get_feature_stats`` output."""
    stats = get_feature_stats(np.random.normal(0, 1, (200, 3)), axis=0, keepdims=False)
    assert {"q01", "q10", "q50", "q90", "q99"}.issubset(stats)
def test_get_feature_stats_unsupported_axis_error():
    """An axis tuple outside the supported configurations raises ``ValueError``."""
    samples = np.random.normal(0, 1, (10, 5))
    bad_axis = (1, 2)  # not a supported axis configuration
    with pytest.raises(ValueError, match="Unsupported axis configuration"):
        get_feature_stats(samples, axis=bad_axis, keepdims=False)
def test_compute_episode_stats_backward_compatibility():
    """Numeric float32 features get exactly the basic stats plus the five default quantiles."""
    feature_dims = {"action": 7, "observation.state": 10}
    episode_data = {name: np.random.normal(0, 1, (100, dim)) for name, dim in feature_dims.items()}
    features = {name: {"dtype": "float32", "shape": (dim,)} for name, dim in feature_dims.items()}

    stats = compute_episode_stats(episode_data, features)

    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    for name in feature_dims:
        assert set(stats[name].keys()) == expected_keys
def test_compute_episode_stats_with_custom_quantiles():
    """Check default quantile keys and per-dimension shapes for two numeric features.

    Despite the test name, no custom quantile values are passed: ``compute_episode_stats``
    is called with only ``episode_data`` and ``features``. The test therefore checks the
    standard q01/q10/q50/q90/q99 output for features of width 7 and 10, drawn around
    means 0 and 2.
    """
    np.random.seed(42)
    episode_data = {
        "action": np.random.normal(0, 1, (100, 7)),
        "observation.state": np.random.normal(2, 1, (100, 10)),
    }
    features = {
        "action": {"dtype": "float32", "shape": (7,)},
        "observation.state": {"dtype": "float32", "shape": (10,)},
    }

    stats = compute_episode_stats(episode_data, features)

    # Hoisted out of the loop: the expected key set is the same for every feature.
    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
    for key, spec in features.items():
        assert set(stats[key].keys()) == expected_keys
        n_dims = spec["shape"][0]
        assert stats[key]["q01"].shape == (n_dims,)
        assert stats[key]["q99"].shape == (n_dims,)
def test_compute_episode_stats_with_image_data():
    """Image quantiles (patched loader) are per-channel (3, 1, 1); numeric ones are per-dimension."""
    n_frames = 50
    episode_data = {
        "observation.image": [f"image_{i}.jpg" for i in range(n_frames)],
        "action": np.random.normal(0, 1, (n_frames, 5)),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "action": {"dtype": "float32", "shape": (5,)},
    }
    loader_target = "lerobot.datasets.compute_stats.load_image_as_numpy"
    with patch(loader_target, side_effect=mock_load_image_as_numpy):
        stats = compute_episode_stats(episode_data, features)

    checked = ("q01", "q50", "q99")
    image_stats = stats["observation.image"]
    for q_key in checked:
        assert q_key in image_stats
    for q_key in checked:
        assert image_stats[q_key].shape == (3, 1, 1)
    for q_key in checked:
        assert stats["action"][q_key].shape == (5,)
def test_compute_episode_stats_string_features_skipped():
    """Features declared with dtype "string" get no stats; numeric features still do."""
    episode_data = {
        "task": ["pick_apple"] * 100,  # string-valued feature
        "action": np.random.normal(0, 1, (100, 5)),
    }
    features = {
        "task": {"dtype": "string"},
        "action": {"dtype": "float32", "shape": (5,)},
    }
    stats = compute_episode_stats(episode_data, features)

    assert "task" not in stats
    assert "action" in stats
    assert "q01" in stats["action"]
def test_aggregate_feature_stats_with_quantiles():
    """Quantile keys survive aggregation and are combined as a count-weighted mean."""

    def make_stats(low, high, mean, std, count, q01, q99):
        return {
            "min": np.array([low]),
            "max": np.array([high]),
            "mean": np.array([mean]),
            "std": np.array([std]),
            "count": np.array([count]),
            "q01": np.array([q01]),
            "q99": np.array([q99]),
        }

    episodes = [
        make_stats(1.0, 10.0, 5.0, 2.0, 100, 1.5, 9.5),
        make_stats(2.0, 12.0, 6.0, 2.5, 150, 2.5, 11.5),
    ]
    result = aggregate_feature_stats(episodes)

    assert "q01" in result
    assert "q99" in result

    total = 100 + 150
    expected_q01 = (1.5 * 100 + 2.5 * 150) / total  # = 2.1
    expected_q99 = (9.5 * 100 + 11.5 * 150) / total  # = 10.7
    np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
    np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)
def test_aggregate_stats_mixed_quantiles():
    """Each feature keeps exactly the quantile keys it was given: feature1 has them, feature2 does not."""
    with_quantiles = {
        "feature1": {
            "min": np.array([1.0]),
            "max": np.array([10.0]),
            "mean": np.array([5.0]),
            "std": np.array([2.0]),
            "count": np.array([100]),
            "q01": np.array([1.5]),
            "q99": np.array([9.5]),
        }
    }
    without_quantiles = {
        "feature2": {
            "min": np.array([0.0]),
            "max": np.array([5.0]),
            "mean": np.array([2.5]),
            "std": np.array([1.5]),
            "count": np.array([50]),
        }
    }
    result = aggregate_stats([with_quantiles, without_quantiles])

    quantile_keys = ("q01", "q99")
    for q_key in quantile_keys:
        assert q_key in result["feature1"]
    for q_key in quantile_keys:
        assert q_key not in result["feature2"]
def test_assert_type_and_shape_with_quantiles():
    """Per-channel (3, 1, 1) image quantiles validate; a mis-shaped quantile raises ``ValueError``."""

    def per_channel(value):
        return np.full((3, 1, 1), value)

    good_image_stats = {
        "min": per_channel(0.0),
        "max": per_channel(1.0),
        "mean": per_channel(0.5),
        "std": per_channel(0.2),
        "count": np.array([100]),
        "q01": per_channel(0.1),
        "q99": per_channel(0.9),
    }
    _assert_type_and_shape([{"observation.image": good_image_stats}])  # must not raise

    bad_image_stats = {
        "count": np.array([100]),
        "q01": np.array([0.1, 0.2]),  # shape (2,) instead of the per-channel (3, 1, 1)
    }
    with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
        _assert_type_and_shape([{"observation.image": bad_image_stats}])
def test_quantile_integration_single_value_quantiles():
    """When every sample equals 1.0, the q01/q50/q99 estimates are 1.0 to within 1e-6."""
    tracker = RunningQuantileStats()
    tracker.update(np.ones((100, 3)))
    stats = tracker.get_statistics()

    ones = np.ones(3)
    for q_key in ("q01", "q50", "q99"):
        np.testing.assert_allclose(stats[q_key], ones, atol=1e-6)
def test_quantile_integration_fixed_quantiles():
    """``get_feature_stats`` on 1000x2 normal samples includes every fixed quantile key."""
    np.random.seed(42)
    stats = get_feature_stats(np.random.normal(0, 1, (1000, 2)), axis=0, keepdims=False)
    missing = [key for key in ("q01", "q10", "q50", "q90", "q99") if key not in stats]
    assert not missing
def test_quantile_integration_large_dataset_quantiles():
    """A 10_000x5 batch with 1000 bins yields the full count and one q01 value per dimension."""
    np.random.seed(42)
    n_rows, n_dims = 10_000, 5
    samples = np.random.normal(0, 1, (n_rows, n_dims))
    tracker = RunningQuantileStats(num_quantile_bins=1000)  # fewer bins keep the test fast
    tracker.update(samples)
    stats = tracker.get_statistics()

    assert stats["count"][0] == n_rows
    assert len(stats["q01"]) == n_dims
def test_fixed_quantiles_always_computed():
    """The fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] appear for vector, image and episode stats."""
    np.random.seed(42)
    quantile_keys = ["q01", "q10", "q50", "q90", "q99"]

    # Vector data: one quantile value per dimension.
    vector_stats = get_feature_stats(np.random.normal(0, 1, (100, 5)), axis=0, keepdims=False)
    for q_key in quantile_keys:
        assert q_key in vector_stats
        assert vector_stats[q_key].shape == (5,)

    # uint8 image data reduced per channel with keepdims.
    pixels = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
    image_stats = get_feature_stats(pixels, axis=(0, 2, 3), keepdims=True)
    for q_key in quantile_keys:
        assert q_key in image_stats
        assert image_stats[q_key].shape == (1, 3, 1, 1)

    # Episode-level stats for two numeric features of different widths.
    dims = {"action": 7, "observation.state": 10}
    episode_data = {name: np.random.normal(0, 1, (100, width)) for name, width in dims.items()}
    features = {name: {"dtype": "float32", "shape": (width,)} for name, width in dims.items()}
    episode_stats = compute_episode_stats(episode_data, features)
    for name, width in dims.items():
        for q_key in quantile_keys:
            assert q_key in episode_stats[name]
            assert episode_stats[name][q_key].shape == (width,)