"""Tests for the speed-transform helpers in ``various_speed.core``."""

import numpy as np
import pytest

from various_speed.core import SpeedTransformConfig
from various_speed.core import clean_near_zero_actions
from various_speed.core import compute_replay_metrics
from various_speed.core import segment_actions
from various_speed.core import transform_episode

# Action-chunk length assumed by the chunk-phase tests below: phases 0..4 are
# exercised as valid and phase 5 is expected to raise.
# NOTE(review): presumably mirrors the default chunk size of
# SpeedTransformConfig — confirm if that default changes.
_CHUNK_SIZE = 5


def _episode(length: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a synthetic episode of ``length`` steps.

    Actions have 7 columns. Columns 0 and 3 hold the constants 1.0 and 0.1,
    which provide non-zero motion; presumably these are one translation and
    one rotation component. Column 6 (the gripper) is -1.0 for the first half
    and 1.0 for the second half, so the episode contains exactly one gripper
    switch. States are all zeros with 8 columns, and frames are
    ``0..length-1``.

    Returns:
        ``(actions, states, frames)`` with dtypes float32, float32 and int64.
    """
    actions = np.zeros((length, 7), dtype=np.float32)
    actions[:, 0] = 1.0
    actions[:, 3] = 0.1
    actions[: length // 2, 6] = -1.0
    actions[length // 2 :, 6] = 1.0
    states = np.zeros((length, 8), dtype=np.float32)
    frames = np.arange(length, dtype=np.int64)
    return actions, states, frames


def _valid_source_steps(transformed) -> np.ndarray:
    """Return the source step indices of observations whose mask equals 1."""
    return transformed["source_step_index"][transformed["observation_mask"] == 1]


def test_clean_near_zero_actions_does_not_touch_gripper():
    """Cleaning near-zero motion must leave the gripper column unchanged."""
    actions = np.zeros((2, 7), dtype=np.float32)
    actions[:, 6] = [-1.0, 1.0]
    cleaned, mask = clean_near_zero_actions(actions, transl_eps=1e-3, rot_eps=1e-3)
    np.testing.assert_array_equal(cleaned[:, 6], actions[:, 6])
    np.testing.assert_array_equal(mask, np.ones((2, 2), dtype=bool))


def test_gripper_only_actions_do_not_create_motion_segments():
    """A gripper switch with no arm motion yields a single segment."""
    actions = np.zeros((4, 7), dtype=np.float32)
    actions[:, 6] = [-1.0, -1.0, 1.0, 1.0]
    segments = segment_actions(actions, SpeedTransformConfig())
    assert segments == [(0, 4, 0)]


def test_slow_transform_pads_every_repeated_observation():
    """At half speed the episode doubles in length and half of it is padding."""
    actions, states, frames = _episode(length=6)
    transformed, metrics = transform_episode(actions, states, frames, 0.5, SpeedTransformConfig())
    assert len(transformed["action"]) == 12
    assert int(transformed["is_padded"].sum()) == 6
    assert float(metrics["padded_ratio"]) == 0.5
    assert metrics["gripper_switch_delta"] == 0


def test_fast_transform_preserves_integrated_motion_and_gripper_switches():
    """At double speed the episode shrinks while motion and gripper switches are kept."""
    actions, states, frames = _episode(length=10)
    transformed, metrics = transform_episode(actions, states, frames, 2.0, SpeedTransformConfig())
    assert len(transformed["action"]) <= 6
    assert metrics["gripper_switch_delta"] == 0
    assert metrics["integrated_translation_l2_error"] < 1e-5
    assert metrics["integrated_rotation_l2_error"] < 1e-5


def test_fractional_speed_transforms_cover_0p75_and_1p25():
    """Non-integer speed factors both slow down (0.75) and speed up (1.25)."""
    actions, states, frames = _episode(length=12)
    slow, slow_metrics = transform_episode(actions, states, frames, 0.75, SpeedTransformConfig())
    fast, fast_metrics = transform_episode(actions, states, frames, 1.25, SpeedTransformConfig())

    assert len(slow["action"]) == 16
    assert int(slow["is_padded"].sum()) > 0
    assert slow_metrics["target_speed"] == 0.75

    assert len(fast["action"]) < len(actions)
    assert fast_metrics["target_speed"] == 1.25
    assert fast_metrics["integrated_translation_l2_error"] < 1e-5
    assert fast_metrics["integrated_rotation_l2_error"] < 1e-5


def test_chunk_phase_one_valid_starts_are_shifted():
    """With chunk_phase=1, valid observations start at source step 1 and then every chunk."""
    actions, states, frames = _episode(length=12)
    transformed, metrics = transform_episode(
        actions,
        states,
        frames,
        1.25,
        SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=1),
    )
    np.testing.assert_array_equal(
        _valid_source_steps(transformed),
        np.asarray([1, 6], dtype=np.int64),
    )
    assert metrics["gripper_switch_delta"] == 0
    assert metrics["integrated_translation_l2_error"] < 1e-5
    assert metrics["integrated_rotation_l2_error"] < 1e-5


def test_all_chunk_phases_cover_full_chunk_starts_once():
    """Across all phases, every full-chunk start step appears exactly once."""
    actions, states, frames = _episode(length=12)
    # The last source step that still has a full chunk of actions after it.
    last_full_start = len(actions) - _CHUNK_SIZE
    starts = []
    for phase in range(_CHUNK_SIZE):
        transformed, _metrics = transform_episode(
            actions,
            states,
            frames,
            1.25,
            SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=phase),
        )
        starts.extend(
            int(step) for step in _valid_source_steps(transformed) if int(step) <= last_full_start
        )
    assert sorted(starts) == list(range(last_full_start + 1))


def test_invalid_chunk_phase_raises():
    """A chunk_phase equal to the chunk length is out of range and must raise."""
    actions, states, frames = _episode(length=12)
    with pytest.raises(ValueError, match="chunk_phase"):
        transform_episode(
            actions,
            states,
            frames,
            1.25,
            SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=_CHUNK_SIZE),
        )


def test_replay_metrics_report_path_ratios_and_target_speed():
    """Replaying an episode against itself gives unit path ratios."""
    actions, _states, _frames = _episode(length=6)
    metrics = compute_replay_metrics(actions, actions, target_speed=1.0)
    assert metrics["target_speed"] == 1.0
    assert metrics["translation_path_ratio"] == 1.0
    assert metrics["rotation_path_ratio"] == 1.0