| import numpy as np |
| import pytest |
|
|
| from various_speed.core import SpeedTransformConfig |
| from various_speed.core import clean_near_zero_actions |
| from various_speed.core import compute_replay_metrics |
| from various_speed.core import segment_actions |
| from various_speed.core import transform_episode |
|
|
|
|
def _episode(length: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a small synthetic episode shared by the speed-transform tests.

    Every step carries 1.0 in action column 0 and 0.1 in action column 3.
    Action column 6 (the gripper channel in these tests) is -1.0 for the first
    ``length // 2`` steps and 1.0 for the remaining steps. The episode
    therefore contains exactly one gripper switch whenever ``length >= 2``.

    Args:
        length: Number of steps in the episode.

    Returns:
        ``(actions, states, frames)``:
        - a ``(length, 7)`` float32 action array;
        - an all-zero ``(length, 8)`` float32 state array;
        - int64 frame indices ``0 .. length - 1``.
    """
    switch_step = length // 2
    step_ids = np.arange(length, dtype=np.int64)

    acts = np.zeros((length, 7), dtype=np.float32)
    acts[:, 0] = 1.0
    acts[:, 3] = 0.1
    # Gripper channel: -1.0 before the midpoint, 1.0 from the midpoint onward.
    acts[:, 6] = np.where(step_ids < switch_step, -1.0, 1.0)

    zero_states = np.zeros((length, 8), dtype=np.float32)
    return acts, zero_states, step_ids
|
|
|
|
def test_clean_near_zero_actions_does_not_touch_gripper():
    """Near-zero cleaning must leave gripper column 6 unchanged.

    The input has no translation/rotation motion at all, only gripper values.
    The returned mask is expected to be an all-True ``(2, 2)`` array for this
    input.
    """
    gripper_col = 6
    raw = np.zeros((2, 7), dtype=np.float32)
    raw[:, gripper_col] = [-1.0, 1.0]

    cleaned, mask = clean_near_zero_actions(raw, transl_eps=1e-3, rot_eps=1e-3)

    expected_mask = np.ones((2, 2), dtype=bool)
    np.testing.assert_array_equal(cleaned[:, gripper_col], raw[:, gripper_col])
    np.testing.assert_array_equal(mask, expected_mask)
|
|
|
|
def test_gripper_only_actions_do_not_create_motion_segments():
    """An episode that only toggles the gripper yields the single segment (0, 4, 0)."""
    gripper_only = np.zeros((4, 7), dtype=np.float32)
    gripper_only[:, 6] = [-1.0, -1.0, 1.0, 1.0]

    expected = [(0, 4, 0)]
    assert segment_actions(gripper_only, SpeedTransformConfig()) == expected
|
|
|
|
def test_slow_transform_pads_every_repeated_observation():
    """At 0.5x speed the episode doubles in length and every repeat is padded.

    The gripper-switch count must not change.
    """
    source_len = 6
    actions, states, frames = _episode(length=source_len)

    out, stats = transform_episode(actions, states, frames, 0.5, SpeedTransformConfig())

    # Each source step appears twice; the second copy is flagged as padding.
    assert len(out["action"]) == 2 * source_len
    assert int(out["is_padded"].sum()) == source_len
    assert float(stats["padded_ratio"]) == 0.5
    assert stats["gripper_switch_delta"] == 0
|
|
|
|
def test_fast_transform_preserves_integrated_motion_and_gripper_switches():
    """At 2x speed the episode shrinks while net motion and gripper switches are kept."""
    actions, states, frames = _episode(length=10)

    out, stats = transform_episode(actions, states, frames, 2.0, SpeedTransformConfig())

    tol = 1e-5
    assert len(out["action"]) <= 6
    assert stats["gripper_switch_delta"] == 0
    for error_key in ("integrated_translation_l2_error", "integrated_rotation_l2_error"):
        assert stats[error_key] < tol
|
|
|
|
def test_fractional_speed_transforms_cover_0p75_and_1p25():
    """Non-integer speed factors work in both directions.

    0.75x lengthens the episode with padding. 1.25x shortens it while keeping
    the integrated motion.
    """
    actions, states, frames = _episode(length=12)

    slow, slow_stats = transform_episode(actions, states, frames, 0.75, SpeedTransformConfig())
    fast, fast_stats = transform_episode(actions, states, frames, 1.25, SpeedTransformConfig())

    # 0.75x: 12 source steps stretch to 16 output steps, some of them padding.
    assert len(slow["action"]) == 16
    assert int(slow["is_padded"].sum()) > 0
    assert slow_stats["target_speed"] == 0.75

    # 1.25x: fewer output steps than source steps, same integrated motion.
    assert len(fast["action"]) < len(actions)
    assert fast_stats["target_speed"] == 1.25
    for error_key in ("integrated_translation_l2_error", "integrated_rotation_l2_error"):
        assert fast_stats[error_key] < 1e-5
|
|
|
|
def test_chunk_phase_one_valid_starts_are_shifted():
    """With chunk_phase=1 at 1.25x, valid observations start at source steps 1 and 6."""
    actions, states, frames = _episode(length=12)
    phase_one = SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=1)

    transformed, metrics = transform_episode(actions, states, frames, 1.25, phase_one)

    is_valid = transformed["observation_mask"] == 1
    valid_source_steps = transformed["source_step_index"][is_valid]
    expected_starts = np.asarray([1, 6], dtype=np.int64)
    np.testing.assert_array_equal(valid_source_steps, expected_starts)

    assert metrics["gripper_switch_delta"] == 0
    assert metrics["integrated_translation_l2_error"] < 1e-5
    assert metrics["integrated_rotation_l2_error"] < 1e-5
|
|
|
|
def test_all_chunk_phases_cover_full_chunk_starts_once():
    """Across every valid chunk_phase, each full-chunk start step is used exactly once.

    ``num_phases`` is the number of chunk phases exercised here (0..4).
    ``test_invalid_chunk_phase_raises`` shows that a phase of 5 is rejected.
    A source step counts as a full-chunk start when at least ``num_phases``
    source steps remain from it, i.e. ``step <= len(actions) - num_phases``.
    """
    actions, states, frames = _episode(length=12)
    num_phases = 5
    last_full_start = len(actions) - num_phases

    starts = []
    for phase in range(num_phases):
        transformed, _ = transform_episode(
            actions,
            states,
            frames,
            1.25,
            SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=phase),
        )
        valid_steps = transformed["source_step_index"][transformed["observation_mask"] == 1]
        starts += [int(step) for step in valid_steps if int(step) <= last_full_start]

    assert sorted(starts) == list(range(last_full_start + 1))
|
|
|
|
def test_invalid_chunk_phase_raises():
    """A chunk_phase of 5 must raise a ValueError whose message mentions chunk_phase."""
    ep_actions, ep_states, ep_frames = _episode(length=12)

    # The config is built inside the `raises` block on purpose. Either its
    # constructor or `transform_episode` may be what rejects the phase.
    with pytest.raises(ValueError, match="chunk_phase"):
        transform_episode(
            ep_actions,
            ep_states,
            ep_frames,
            1.25,
            SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=5),
        )
|
|
|
|
def test_replay_metrics_report_path_ratios_and_target_speed():
    """Replaying an episode against itself reports unit path ratios and the target speed.

    The path ratios are presumably floating-point quotients of accumulated
    path lengths computed inside ``compute_replay_metrics``. They are compared
    with ``pytest.approx`` instead of exact ``==``. That keeps this identity
    check from becoming flaky if the implementation changes summation order or
    accumulation precision. ``target_speed`` is passed straight in, so it is
    still compared exactly.
    """
    actions, _states, _frames = _episode(length=6)

    metrics = compute_replay_metrics(actions, actions, target_speed=1.0)

    assert metrics["target_speed"] == 1.0
    assert metrics["translation_path_ratio"] == pytest.approx(1.0)
    assert metrics["rotation_path_ratio"] == pytest.approx(1.0)
|
|