Spaces:
Sleeping
Sleeping
| """Tests for episode management.""" | |
| import pytest | |
| from app.core.episode import Episode, EpisodeStep, EpisodeStatus, EpisodeManager | |
class TestEpisode:
    """Tests for ``Episode``: lifecycle transitions, step bookkeeping and reporting."""

    def test_episode_creation(self) -> None:
        """A new episode keeps its constructor arguments, is PENDING and has no steps."""
        episode = Episode(
            episode_id="ep_001",
            task_id="task_001",
            max_steps=10,
        )

        assert episode.episode_id == "ep_001"
        assert episode.task_id == "task_001"
        assert episode.max_steps == 10
        assert episode.status == EpisodeStatus.PENDING
        assert len(episode.steps) == 0

    def test_episode_start(self) -> None:
        """Starting an episode moves it to RUNNING and sets ``started_at``."""
        episode = Episode(episode_id="ep_002", task_id="task_002")
        episode.start()

        assert episode.status == EpisodeStatus.RUNNING
        assert episode.started_at is not None

    def test_episode_add_step(self) -> None:
        """Adding one step records it as step 1 and counts its reward."""
        episode = Episode(episode_id="ep_003", task_id="task_003")
        episode.start()

        # The return value of add_step() is not needed here; the assertions
        # below inspect the episode's own state instead.
        episode.add_step(
            action_type="navigate",
            action_params={"target": "/login"},
            reward=0.5,
            reward_breakdown={"progress": 0.5},
            observation_summary={"url": "https://example.com"},
        )

        assert len(episode.steps) == 1
        assert episode.steps[0].step_number == 1
        # Compare floats approximately, like the other reward assertions here.
        assert episode.total_reward == pytest.approx(0.5)

    def test_episode_multiple_steps(self) -> None:
        """Several steps accumulate in ``steps``, ``total_reward`` and ``current_step``."""
        episode = Episode(episode_id="ep_004", task_id="task_004")
        episode.start()

        rewards = [0.1, 0.2, 0.3, 0.4]
        for i, reward in enumerate(rewards):
            episode.add_step(
                action_type="test",
                action_params={"step": i},
                reward=reward,
                reward_breakdown={"base": reward},
                observation_summary={"step": i},
            )

        assert len(episode.steps) == 4
        assert episode.total_reward == pytest.approx(1.0)
        assert episode.current_step == 4

    def test_episode_completion(self) -> None:
        """Completing a running episode sets COMPLETED and ``ended_at``."""
        episode = Episode(episode_id="ep_005", task_id="task_005")
        episode.start()
        episode.complete(success=True)

        assert episode.status == EpisodeStatus.COMPLETED
        assert episode.ended_at is not None

    def test_episode_failure(self) -> None:
        """Failing a running episode sets FAILED and stores the reason."""
        episode = Episode(episode_id="ep_006", task_id="task_006")
        episode.start()
        episode.fail(reason="Test failure")

        assert episode.status == EpisodeStatus.FAILED
        assert episode.failure_reason == "Test failure"

    def test_episode_truncation(self) -> None:
        """Truncating after ``max_steps`` steps sets TRUNCATED."""
        max_steps = 5
        episode = Episode(
            episode_id="ep_007",
            task_id="task_007",
            max_steps=max_steps,
        )
        episode.start()

        # Add steps up to the configured maximum before truncating.
        for _ in range(max_steps):
            episode.add_step(
                action_type="test",
                action_params={},
                reward=0.1,
                reward_breakdown={"base": 0.1},
                observation_summary={},
            )
        episode.truncate()

        assert episode.status == EpisodeStatus.TRUNCATED

    def test_episode_is_terminal(self) -> None:
        """``is_terminal`` is false when pending or running, true once completed."""
        episode = Episode(episode_id="ep_008", task_id="task_008")
        assert not episode.is_terminal

        episode.start()
        assert not episode.is_terminal

        episode.complete(success=True)
        assert episode.is_terminal

    def test_episode_duration(self) -> None:
        """A completed episode reports a non-negative ``duration_seconds``."""
        import time

        episode = Episode(episode_id="ep_009", task_id="task_009")
        episode.start()

        # Let a little wall-clock time pass between start and completion.
        time.sleep(0.01)
        episode.complete(success=True)

        assert episode.duration_seconds is not None
        assert episode.duration_seconds >= 0

    def test_episode_average_reward(self) -> None:
        """``average_reward`` is the mean of the recorded step rewards."""
        episode = Episode(episode_id="ep_010", task_id="task_010")
        episode.start()

        for reward in [0.2, 0.4, 0.6]:
            episode.add_step(
                action_type="test",
                action_params={},
                reward=reward,
                reward_breakdown={"base": reward},
                observation_summary={},
            )

        assert episode.average_reward == pytest.approx(0.4)

    def test_episode_summary(self) -> None:
        """``get_summary()`` has the ids plus ``status`` and ``steps`` keys."""
        episode = Episode(episode_id="ep_011", task_id="task_011")
        episode.start()

        summary = episode.get_summary()

        assert summary["episode_id"] == "ep_011"
        assert summary["task_id"] == "task_011"
        assert "status" in summary
        assert "steps" in summary

    def test_episode_cancel(self) -> None:
        """Cancelling a running episode sets CANCELLED, which is terminal."""
        episode = Episode(episode_id="ep_012", task_id="task_012")
        episode.start()
        episode.cancel()

        assert episode.status == EpisodeStatus.CANCELLED
        assert episode.is_terminal

    def test_episode_get_action_sequence(self) -> None:
        """``get_action_sequence()`` lists action types in insertion order."""
        episode = Episode(episode_id="ep_013", task_id="task_013")
        episode.start()

        episode.add_step("navigate", {}, 0.1, {}, {})
        episode.add_step("click", {}, 0.2, {}, {})
        episode.add_step("extract", {}, 0.3, {}, {})

        actions = episode.get_action_sequence()
        assert actions == ["navigate", "click", "extract"]

    def test_episode_get_reward_history(self) -> None:
        """``get_reward_history()`` lists step rewards in insertion order."""
        episode = Episode(episode_id="ep_014", task_id="task_014")
        episode.start()

        episode.add_step("a", {}, 0.1, {}, {})
        episode.add_step("b", {}, 0.2, {}, {})
        episode.add_step("c", {}, 0.3, {}, {})

        rewards = episode.get_reward_history()
        # Approximate comparison keeps the check robust to float representation.
        assert rewards == pytest.approx([0.1, 0.2, 0.3])
class TestEpisodeStep:
    """Checks that ``EpisodeStep`` exposes the values it was constructed with."""

    @staticmethod
    def _new_step(**fields):
        """Build step number 1 stamped with the current UTC time, plus *fields*.

        The leading underscore keeps pytest from collecting this helper as a test.
        """
        from datetime import datetime, timezone

        stamp = datetime.now(timezone.utc).isoformat()
        return EpisodeStep(step_number=1, timestamp=stamp, **fields)

    def test_step_creation(self) -> None:
        """The basic fields are readable back from a freshly built step."""
        built = self._new_step(
            action_type="click",
            action_params={"selector": "#btn"},
            reward=0.75,
            reward_breakdown={"progress": 0.75},
            observation_summary={"url": "https://example.com", "title": "Test"},
        )

        assert built.step_number == 1
        assert built.action_type == "click"
        assert built.action_params["selector"] == "#btn"
        assert built.reward == 0.75

    def test_step_with_error(self) -> None:
        """``error`` and ``duration_ms`` passed at construction are stored."""
        built = self._new_step(
            action_type="click",
            action_params={},
            reward=-0.5,
            reward_breakdown={"error": -0.5},
            observation_summary={},
            error="Element not found",
            duration_ms=150.0,
        )

        assert built.error == "Element not found"
        assert built.duration_ms == 150.0

    def test_step_with_reasoning(self) -> None:
        """``action_reasoning`` passed at construction is stored."""
        built = self._new_step(
            action_type="extract",
            action_params={"field": "price"},
            action_reasoning="Extracting price from product page",
            reward=0.5,
            reward_breakdown={"extraction": 0.5},
            observation_summary={},
        )

        assert built.action_reasoning == "Extracting price from product page"
class TestEpisodeManager:
    """Tests for ``EpisodeManager``: creating, fetching, removing and listing episodes."""

    def test_manager_create_episode(self) -> None:
        """``create_episode`` returns an episode carrying the given ids."""
        manager = EpisodeManager()

        episode = manager.create_episode("ep_100", "task_100")

        assert episode.episode_id == "ep_100"
        assert episode.task_id == "task_100"

    def test_manager_get_episode(self) -> None:
        """A created episode can be fetched back by its id."""
        manager = EpisodeManager()
        manager.create_episode("ep_101", "task_101")

        episode = manager.get_episode("ep_101")

        assert episode is not None
        assert episode.episode_id == "ep_101"

    def test_manager_get_nonexistent(self) -> None:
        """Fetching an unknown id yields ``None``."""
        manager = EpisodeManager()

        episode = manager.get_episode("nonexistent")

        assert episode is None

    def test_manager_remove_episode(self) -> None:
        """Removing an existing episode returns ``True`` and makes it unfetchable."""
        manager = EpisodeManager()
        manager.create_episode("ep_102", "task_102")

        removed = manager.remove_episode("ep_102")

        assert removed is True
        assert manager.get_episode("ep_102") is None

    def test_manager_list_episodes(self) -> None:
        """``list_episodes()`` without filters returns every created episode."""
        manager = EpisodeManager()
        manager.create_episode("ep_103", "task_103")
        manager.create_episode("ep_104", "task_104")
        manager.create_episode("ep_105", "task_105")

        episodes = manager.list_episodes()

        assert len(episodes) == 3

    def test_manager_list_episodes_by_status(self) -> None:
        """``list_episodes(status=...)`` returns only episodes in that status."""
        manager = EpisodeManager()
        running_episode = manager.create_episode("ep_106", "task_106")
        completed_episode = manager.create_episode("ep_107", "task_107")
        # A third episode is created but never started, so each filter below
        # has to exclude at least one episode; its handle is not needed.
        manager.create_episode("ep_108", "task_108")

        running_episode.start()
        completed_episode.start()
        completed_episode.complete(success=True)

        running = manager.list_episodes(status=EpisodeStatus.RUNNING)
        assert len(running) == 1
        assert running[0].episode_id == "ep_106"

        completed = manager.list_episodes(status=EpisodeStatus.COMPLETED)
        assert len(completed) == 1
        assert completed[0].episode_id == "ep_107"

    def test_manager_list_episodes_by_task(self) -> None:
        """``list_episodes(task_id=...)`` returns only episodes of that task."""
        manager = EpisodeManager()
        manager.create_episode("ep_109", "task_A")
        manager.create_episode("ep_110", "task_A")
        manager.create_episode("ep_111", "task_B")

        task_a_episodes = manager.list_episodes(task_id="task_A")
        assert len(task_a_episodes) == 2

        task_b_episodes = manager.list_episodes(task_id="task_B")
        assert len(task_b_episodes) == 1