Spaces:
Running
Running
| """Unit tests for movement generator.""" | |
| import pytest | |
| import numpy as np | |
| from reachy_mini_danceml.movement_tools import KeyFrame | |
| from reachy_mini_danceml.movement_generator import GeneratedMove | |
class TestKeyFrame:
    """Unit tests covering how ``KeyFrame.from_dict`` parses input dicts."""

    @staticmethod
    def _assert_default_pose(kf):
        """Check the fallback values expected when head/antennas are omitted."""
        assert kf.head == {}
        assert kf.antennas == (0, 0)

    def test_from_dict_full(self):
        """Every field supplied in the dict is carried onto the KeyFrame."""
        expected_head = {"roll": 10, "pitch": -5, "yaw": 30}
        kf = KeyFrame.from_dict(
            {
                "t": 1.5,
                "head": {"roll": 10, "pitch": -5, "yaw": 30},
                "antennas": [15, -15],
            }
        )

        assert kf.t == 1.5
        # Check each supplied axis individually (extra keys are tolerated).
        for axis, angle in expected_head.items():
            assert kf.head[axis] == angle
        # The input list is expected back as a tuple: a list would not
        # compare equal to (15, -15).
        assert kf.antennas == (15, -15)

    def test_from_dict_minimal(self):
        """With only ``t`` given, head and antennas take their defaults."""
        kf = KeyFrame.from_dict({"t": 0.5})
        assert kf.t == 0.5
        self._assert_default_pose(kf)

    def test_from_dict_defaults(self):
        """An empty dict yields t == 0 plus the default head/antenna pose."""
        kf = KeyFrame.from_dict({})
        assert kf.t == 0
        self._assert_default_pose(kf)
class TestGeneratedMove:
    """Tests for GeneratedMove class."""

    def test_requires_minimum_keyframes(self):
        """Test that at least 2 keyframes are required."""
        with pytest.raises(ValueError, match="at least 2 keyframes"):
            GeneratedMove([KeyFrame(t=0, head={})])

    def test_duration(self):
        """Test that duration equals max keyframe time."""
        keyframes = [
            KeyFrame(t=0.0, head={"yaw": 0}),
            KeyFrame(t=1.5, head={"yaw": 30}),
            KeyFrame(t=3.0, head={"yaw": 0}),
        ]
        move = GeneratedMove(keyframes)
        assert move.duration == 3.0

    def test_evaluate_returns_correct_types(self):
        """Test that evaluate returns correct data types."""
        keyframes = [
            KeyFrame(t=0.0, head={"roll": 0, "pitch": 0, "yaw": 0}, antennas=(0, 0)),
            KeyFrame(t=1.0, head={"roll": 10, "pitch": 10, "yaw": 30}, antennas=(20, -20)),
        ]
        move = GeneratedMove(keyframes)
        head, antennas, body_yaw = move.evaluate(0.5)
        # Head should be 4x4 matrix
        assert head.shape == (4, 4)
        # Antennas should be array of 2
        assert len(antennas) == 2
        assert isinstance(antennas, np.ndarray)
        # Body yaw should be 0
        assert body_yaw == 0.0

    def test_evaluate_at_boundaries(self):
        """Test evaluation at start and end times yields valid, distinct poses."""
        keyframes = [
            KeyFrame(t=0.0, head={"yaw": 0}),
            KeyFrame(t=1.0, head={"yaw": 30}),
        ]
        move = GeneratedMove(keyframes)
        # Should not raise at boundaries
        head_start, _, _ = move.evaluate(0.0)
        head_end, _, _ = move.evaluate(1.0)
        assert head_start.shape == (4, 4)
        assert head_end.shape == (4, 4)
        # Boundary evaluation must not produce NaN/inf (e.g. from a spline
        # evaluated exactly at its knots).
        assert np.all(np.isfinite(head_start))
        assert np.all(np.isfinite(head_end))
        # Yaw goes from 0 to 30 degrees, so the two endpoint poses must differ.
        assert not np.allclose(head_start, head_end)

    def test_evaluate_clamps_time(self):
        """Test that evaluation clamps time to valid range."""
        keyframes = [
            KeyFrame(t=0.0, head={"yaw": 0}),
            KeyFrame(t=1.0, head={"yaw": 30}),
        ]
        move = GeneratedMove(keyframes)
        # Should not raise for out-of-range times
        head_before, antennas_before, _ = move.evaluate(-1.0)
        head_after, antennas_after, _ = move.evaluate(5.0)
        head_start, antennas_start, _ = move.evaluate(0.0)
        head_end, antennas_end, _ = move.evaluate(1.0)
        # Clamping means out-of-range times behave exactly like the nearest
        # boundary; a bare "is not None" check would also pass for
        # extrapolated (or NaN) output.
        np.testing.assert_allclose(head_before, head_start)
        np.testing.assert_allclose(head_after, head_end)
        np.testing.assert_allclose(antennas_before, antennas_start)
        np.testing.assert_allclose(antennas_after, antennas_end)

    def test_interpolation_midpoint(self):
        """Test that interpolation produces reasonable midpoint values."""
        keyframes = [
            KeyFrame(t=0.0, head={"yaw": 0}, antennas=(0, 0)),
            KeyFrame(t=1.0, head={"yaw": 30}, antennas=(30, -30)),
        ]
        move = GeneratedMove(keyframes)
        # At midpoint, values should be roughly halfway
        # (Cubic spline may vary slightly from linear midpoint)
        _, antennas, _ = move.evaluate(0.5)
        # Convert back to degrees for comparison
        left_deg = np.rad2deg(antennas[0])
        right_deg = np.rad2deg(antennas[1])
        # Should be roughly 15 and -15 (within tolerance for cubic spline)
        assert 10 < left_deg < 20
        assert -20 < right_deg < -10
class TestMovementToolSchemas:
    """Tests for movement tool schemas."""

    def test_all_tools_have_required_fields(self):
        """Test that all tool schemas have required fields."""
        from reachy_mini_danceml.movement_tools import ALL_TOOLS

        # Guard against a vacuous pass: with an empty ALL_TOOLS the loop
        # below would check nothing and the test would still succeed.
        assert ALL_TOOLS, "ALL_TOOLS should not be empty"
        for tool in ALL_TOOLS:
            assert "type" in tool, f"tool schema missing 'type': {tool!r}"
            assert tool["type"] == "function", f"unexpected tool type: {tool!r}"
            assert "name" in tool, f"tool schema missing 'name': {tool!r}"

    def test_goto_pose_schema(self):
        """Test goto_pose tool schema structure."""
        from reachy_mini_danceml.movement_tools import GOTO_POSE_TOOL

        assert GOTO_POSE_TOOL["name"] == "goto_pose"
        params = GOTO_POSE_TOOL["parameters"]["properties"]
        assert "roll" in params
        assert "pitch" in params
        assert "yaw" in params
        assert "duration" in params

    def test_create_sequence_schema(self):
        """Test create_sequence tool schema structure."""
        from reachy_mini_danceml.movement_tools import CREATE_SEQUENCE_TOOL

        assert CREATE_SEQUENCE_TOOL["name"] == "create_sequence"
        params = CREATE_SEQUENCE_TOOL["parameters"]["properties"]
        assert "keyframes" in params
        assert params["keyframes"]["type"] == "array"