# VLAwithVariousSpeed / tests/test_core.py
# Alan0928: "Upload folder using huggingface_hub" (commit 08ff31f, verified), 4.76 kB
import numpy as np
import pytest
from various_speed.core import SpeedTransformConfig
from various_speed.core import clean_near_zero_actions
from various_speed.core import compute_replay_metrics
from various_speed.core import segment_actions
from various_speed.core import transform_episode
def _episode(length: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a synthetic ``(actions, states, frames)`` episode of ``length`` steps.

    * ``actions`` is a ``(length, 7)`` float32 array. Column 0 is a constant
      1.0 and column 3 a constant 0.1 (presumably one translation and one
      rotation axis -- confirm against ``various_speed.core``). Column 6, the
      gripper channel, is -1.0 for the first ``length // 2`` steps and 1.0
      for the remaining ones. All other columns stay zero.
    * ``states`` is a ``(length, 8)`` float32 array of zeros.
    * ``frames`` holds the int64 indices ``0 .. length - 1``.
    """
    midpoint = length // 2

    actions = np.zeros((length, 7), dtype=np.float32)
    # Fill the two constant motion columns in one broadcast assignment.
    actions[:, [0, 3]] = (1.0, 0.1)

    # Basic slicing yields a view, so these writes land in ``actions``.
    gripper = actions[:, 6]
    gripper[:midpoint] = -1.0
    gripper[midpoint:] = 1.0

    return (
        actions,
        np.zeros((length, 8), dtype=np.float32),
        np.arange(length, dtype=np.int64),
    )
def test_clean_near_zero_actions_does_not_touch_gripper():
    """Cleaning must return the gripper column (index 6) unchanged.

    The input has zero motion in every other column, so only the gripper
    values (-1.0 and 1.0) are non-zero.
    """
    gripper_only = np.zeros((2, 7), dtype=np.float32)
    gripper_only[:, 6] = [-1.0, 1.0]

    cleaned, mask = clean_near_zero_actions(
        gripper_only,
        transl_eps=1e-3,
        rot_eps=1e-3,
    )

    np.testing.assert_array_equal(cleaned[:, 6], gripper_only[:, 6])
    # The returned mask is expected to be a (2, 2) boolean array, all True.
    expected_mask = np.ones((2, 2), dtype=bool)
    np.testing.assert_array_equal(mask, expected_mask)
def test_gripper_only_actions_do_not_create_motion_segments():
    """A gripper sign flip alone must not split the episode into segments.

    Only column 6 is non-zero; it switches from -1.0 to 1.0 halfway through.
    Segmentation is expected to return the single tuple ``(0, 4, 0)``
    (presumably ``(start, stop, label)`` -- confirm against ``segment_actions``).
    """
    gripper_only = np.zeros((4, 7), dtype=np.float32)
    gripper_only[:, 6] = [-1.0, -1.0, 1.0, 1.0]

    config = SpeedTransformConfig()
    found = segment_actions(gripper_only, config)

    assert found == [(0, 4, 0)]
def test_slow_transform_pads_every_repeated_observation():
    """At speed 0.5 a 6-step episode becomes 12 steps, 6 of them padded.

    Also checks the reported ``padded_ratio`` (0.5) and that the number of
    gripper switches is unchanged (``gripper_switch_delta == 0``).
    """
    episode_actions, episode_states, episode_frames = _episode(length=6)
    config = SpeedTransformConfig()

    slowed, report = transform_episode(
        episode_actions,
        episode_states,
        episode_frames,
        0.5,
        config,
    )

    assert len(slowed["action"]) == 12
    padded_steps = int(slowed["is_padded"].sum())
    assert padded_steps == 6
    assert float(report["padded_ratio"]) == 0.5
    assert report["gripper_switch_delta"] == 0
def test_fast_transform_preserves_integrated_motion_and_gripper_switches():
    """At speed 2.0 a 10-step episode shrinks to at most 6 steps.

    The gripper switch count must be unchanged, and the integrated
    translation and rotation errors must both stay below 1e-5.
    """
    episode_actions, episode_states, episode_frames = _episode(length=10)
    config = SpeedTransformConfig()

    sped_up, report = transform_episode(
        episode_actions,
        episode_states,
        episode_frames,
        2.0,
        config,
    )

    assert len(sped_up["action"]) <= 6
    assert report["gripper_switch_delta"] == 0
    # Translation is checked before rotation, as in the original assertions.
    for error_key in (
        "integrated_translation_l2_error",
        "integrated_rotation_l2_error",
    ):
        assert report[error_key] < 1e-5
def test_fractional_speed_transforms_cover_0p75_and_1p25():
    """Exercise the non-integer speeds 0.75 and 1.25 on a 12-step episode.

    Slow (0.75): output has 16 steps, at least one padded, and the metrics
    echo the target speed. Fast (1.25): output is shorter than the input,
    the metrics echo the target speed, and both integrated motion errors
    stay below 1e-5.
    """
    episode_actions, episode_states, episode_frames = _episode(length=12)

    # Both transforms run before any assertion; each gets its own config.
    slowed, slowed_report = transform_episode(
        episode_actions,
        episode_states,
        episode_frames,
        0.75,
        SpeedTransformConfig(),
    )
    sped_up, sped_up_report = transform_episode(
        episode_actions,
        episode_states,
        episode_frames,
        1.25,
        SpeedTransformConfig(),
    )

    assert len(slowed["action"]) == 16
    padded_steps = int(slowed["is_padded"].sum())
    assert padded_steps > 0
    assert slowed_report["target_speed"] == 0.75

    assert len(sped_up["action"]) < len(episode_actions)
    assert sped_up_report["target_speed"] == 1.25
    for error_key in (
        "integrated_translation_l2_error",
        "integrated_rotation_l2_error",
    ):
        assert sped_up_report[error_key] < 1e-5
def test_chunk_phase_one_valid_starts_are_shifted():
    """With ``chunk_phase=1`` the valid observations come from steps 1 and 6.

    Runs a 12-step episode at speed 1.25 with chunk-aligned observations and
    checks the source indices of the rows whose ``observation_mask`` is 1.
    The gripper switch count and the integrated motion errors are checked
    as well.
    """
    episode_actions, episode_states, episode_frames = _episode(length=12)
    config = SpeedTransformConfig(chunk_aligned_observation=True, chunk_phase=1)

    result, report = transform_episode(
        episode_actions,
        episode_states,
        episode_frames,
        1.25,
        config,
    )

    source_steps = result["source_step_index"]
    observed = source_steps[result["observation_mask"] == 1]
    expected_starts = np.array([1, 6], dtype=np.int64)
    np.testing.assert_array_equal(observed, expected_starts)

    assert report["gripper_switch_delta"] == 0
    for error_key in (
        "integrated_translation_l2_error",
        "integrated_rotation_l2_error",
    ):
        assert report[error_key] < 1e-5
def test_all_chunk_phases_cover_full_chunk_starts_once():
    """Across phases 0..4, every source step 0..7 is a valid start exactly once.

    For each phase the 12-step episode is transformed at speed 1.25 with
    chunk-aligned observations. Source indices of valid observations are
    collected, keeping only those no greater than ``len(actions) - 5``
    (presumably the last start that still leaves a full chunk -- confirm the
    chunk length against ``various_speed.core``).
    """
    episode_actions, episode_states, episode_frames = _episode(length=12)
    last_kept_start = len(episode_actions) - 5

    collected = []
    for phase in range(5):
        config = SpeedTransformConfig(
            chunk_aligned_observation=True,
            chunk_phase=phase,
        )
        result, _ = transform_episode(
            episode_actions,
            episode_states,
            episode_frames,
            1.25,
            config,
        )
        source_steps = result["source_step_index"]
        observed = source_steps[result["observation_mask"] == 1]
        for raw_step in observed:
            step = int(raw_step)
            if step <= last_kept_start:
                collected.append(step)

    # Sorting makes the check independent of the order phases contribute in;
    # a duplicate or a missing start would break the equality.
    assert sorted(collected) == list(range(8))
def test_invalid_chunk_phase_raises():
    """``chunk_phase=5`` must be rejected with a ValueError naming ``chunk_phase``."""
    episode_actions, episode_states, episode_frames = _episode(length=12)

    with pytest.raises(ValueError, match="chunk_phase"):
        # The config is built inside the context on purpose: the test passes
        # whether the error comes from the constructor or from the transform.
        bad_config = SpeedTransformConfig(
            chunk_aligned_observation=True,
            chunk_phase=5,
        )
        transform_episode(
            episode_actions,
            episode_states,
            episode_frames,
            1.25,
            bad_config,
        )
def test_replay_metrics_report_path_ratios_and_target_speed():
    """Comparing an action sequence with itself must report ratios of 1.0.

    The same 6-step action array is passed as both arguments with
    ``target_speed=1.0``. The target speed and both path ratios are then
    expected to equal 1.0 exactly.
    """
    episode_actions = _episode(length=6)[0]

    report = compute_replay_metrics(
        episode_actions,
        episode_actions,
        target_speed=1.0,
    )

    # Checked in the same order as the original individual assertions.
    for metric_key in (
        "target_speed",
        "translation_path_ratio",
        "rotation_path_ratio",
    ):
        assert report[metric_key] == 1.0