# Source: scrapeRL/backend/tests/test_core/test_episode.py
# Author: NeerajCodz
# Commit e8d7c11: "test: add comprehensive API and core module tests"
"""Tests for episode management."""
import time
from datetime import datetime, timezone

import pytest

from app.core.episode import Episode, EpisodeStep, EpisodeStatus, EpisodeManager
class TestEpisode:
    """Test the ``Episode`` lifecycle, step recording, and reward statistics."""

    @staticmethod
    def _add_simple_step(episode: "Episode", action_type: str, reward: float) -> None:
        """Record a step with the given action and reward and empty payloads.

        Keyword arguments are used so the call does not depend on the
        positional order of ``Episode.add_step``'s parameters.
        """
        episode.add_step(
            action_type=action_type,
            action_params={},
            reward=reward,
            reward_breakdown={},
            observation_summary={},
        )

    def test_episode_creation(self) -> None:
        """A new episode keeps its identifiers and starts PENDING with no steps."""
        episode = Episode(
            episode_id="ep_001",
            task_id="task_001",
            max_steps=10,
        )
        assert episode.episode_id == "ep_001"
        assert episode.task_id == "task_001"
        assert episode.max_steps == 10
        assert episode.status == EpisodeStatus.PENDING
        assert len(episode.steps) == 0

    def test_episode_start(self) -> None:
        """Starting an episode moves it to RUNNING and stamps ``started_at``."""
        episode = Episode(episode_id="ep_002", task_id="task_002")
        episode.start()
        assert episode.status == EpisodeStatus.RUNNING
        assert episode.started_at is not None

    def test_episode_add_step(self) -> None:
        """Adding a step records it (1-based numbering) and accumulates its reward."""
        episode = Episode(episode_id="ep_003", task_id="task_003")
        episode.start()
        episode.add_step(
            action_type="navigate",
            action_params={"target": "/login"},
            reward=0.5,
            reward_breakdown={"progress": 0.5},
            observation_summary={"url": "https://example.com"},
        )
        assert len(episode.steps) == 1
        recorded = episode.steps[0]
        assert recorded.step_number == 1
        assert recorded.action_type == "navigate"
        assert recorded.action_params == {"target": "/login"}
        assert recorded.reward == pytest.approx(0.5)
        assert episode.total_reward == pytest.approx(0.5)

    def test_episode_multiple_steps(self) -> None:
        """Several steps are all recorded and their rewards summed."""
        episode = Episode(episode_id="ep_004", task_id="task_004")
        episode.start()
        rewards = [0.1, 0.2, 0.3, 0.4]
        for i, reward in enumerate(rewards):
            episode.add_step(
                action_type="test",
                action_params={"step": i},
                reward=reward,
                reward_breakdown={"base": reward},
                observation_summary={"step": i},
            )
        assert len(episode.steps) == 4
        # Float sums are compared approximately (0.1 + 0.2 + ... is inexact).
        assert episode.total_reward == pytest.approx(1.0)
        assert episode.current_step == 4

    def test_episode_completion(self) -> None:
        """Completing an episode marks it COMPLETED and stamps ``ended_at``."""
        episode = Episode(episode_id="ep_005", task_id="task_005")
        episode.start()
        episode.complete(success=True)
        assert episode.status == EpisodeStatus.COMPLETED
        assert episode.ended_at is not None

    def test_episode_failure(self) -> None:
        """Failing an episode marks it FAILED and keeps the reason."""
        episode = Episode(episode_id="ep_006", task_id="task_006")
        episode.start()
        episode.fail(reason="Test failure")
        assert episode.status == EpisodeStatus.FAILED
        assert episode.failure_reason == "Test failure"

    def test_episode_truncation(self) -> None:
        """An episode filled up to ``max_steps`` can be truncated."""
        episode = Episode(episode_id="ep_007", task_id="task_007", max_steps=5)
        episode.start()
        # Fill the episode exactly up to its step budget before truncating.
        for _ in range(episode.max_steps):
            episode.add_step(
                action_type="test",
                action_params={},
                reward=0.1,
                reward_breakdown={"base": 0.1},
                observation_summary={},
            )
        assert episode.current_step == 5
        episode.truncate()
        assert episode.status == EpisodeStatus.TRUNCATED

    def test_episode_is_terminal(self) -> None:
        """Only a finished episode reports itself as terminal."""
        episode = Episode(episode_id="ep_008", task_id="task_008")
        assert not episode.is_terminal
        episode.start()
        assert not episode.is_terminal
        episode.complete(success=True)
        assert episode.is_terminal

    def test_episode_duration(self) -> None:
        """A completed episode reports a non-negative duration."""
        episode = Episode(episode_id="ep_009", task_id="task_009")
        episode.start()
        # Brief pause so the end timestamp is taken after the start timestamp.
        time.sleep(0.01)
        episode.complete(success=True)
        assert episode.duration_seconds is not None
        assert episode.duration_seconds >= 0

    def test_episode_average_reward(self) -> None:
        """The average reward is the mean of the per-step rewards."""
        episode = Episode(episode_id="ep_010", task_id="task_010")
        episode.start()
        rewards = [0.2, 0.4, 0.6]
        for reward in rewards:
            episode.add_step(
                action_type="test",
                action_params={},
                reward=reward,
                reward_breakdown={"base": reward},
                observation_summary={},
            )
        assert episode.average_reward == pytest.approx(0.4)

    def test_episode_summary(self) -> None:
        """The summary dict exposes identifiers, status and steps."""
        episode = Episode(episode_id="ep_011", task_id="task_011")
        episode.start()
        summary = episode.get_summary()
        assert summary["episode_id"] == "ep_011"
        assert summary["task_id"] == "task_011"
        assert "status" in summary
        assert "steps" in summary

    def test_episode_cancel(self) -> None:
        """Cancelling an episode marks it CANCELLED and terminal."""
        episode = Episode(episode_id="ep_012", task_id="task_012")
        episode.start()
        episode.cancel()
        assert episode.status == EpisodeStatus.CANCELLED
        assert episode.is_terminal

    def test_episode_get_action_sequence(self) -> None:
        """Action types are returned in the order the steps were added."""
        episode = Episode(episode_id="ep_013", task_id="task_013")
        episode.start()
        self._add_simple_step(episode, "navigate", 0.1)
        self._add_simple_step(episode, "click", 0.2)
        self._add_simple_step(episode, "extract", 0.3)
        actions = episode.get_action_sequence()
        assert actions == ["navigate", "click", "extract"]

    def test_episode_get_reward_history(self) -> None:
        """Per-step rewards are returned in the order the steps were added."""
        episode = Episode(episode_id="ep_014", task_id="task_014")
        episode.start()
        self._add_simple_step(episode, "a", 0.1)
        self._add_simple_step(episode, "b", 0.2)
        self._add_simple_step(episode, "c", 0.3)
        rewards = episode.get_reward_history()
        assert rewards == [0.1, 0.2, 0.3]
class TestEpisodeStep:
    """Test constructing ``EpisodeStep`` records directly."""

    @staticmethod
    def _timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string for a step record."""
        return datetime.now(timezone.utc).isoformat()

    def test_step_creation(self) -> None:
        """A step keeps the action, reward and observation it was built with."""
        step = EpisodeStep(
            step_number=1,
            timestamp=self._timestamp(),
            action_type="click",
            action_params={"selector": "#btn"},
            reward=0.75,
            reward_breakdown={"progress": 0.75},
            observation_summary={"url": "https://example.com", "title": "Test"},
        )
        assert step.step_number == 1
        assert step.action_type == "click"
        assert step.action_params["selector"] == "#btn"
        assert step.reward == 0.75
        assert step.reward_breakdown == {"progress": 0.75}
        assert step.observation_summary["title"] == "Test"

    def test_step_with_error(self) -> None:
        """A step can carry an error message and a measured duration."""
        step = EpisodeStep(
            step_number=1,
            timestamp=self._timestamp(),
            action_type="click",
            action_params={},
            reward=-0.5,
            reward_breakdown={"error": -0.5},
            observation_summary={},
            error="Element not found",
            duration_ms=150.0,
        )
        assert step.error == "Element not found"
        assert step.duration_ms == 150.0

    def test_step_with_reasoning(self) -> None:
        """A step can carry the agent's reasoning for the chosen action."""
        step = EpisodeStep(
            step_number=1,
            timestamp=self._timestamp(),
            action_type="extract",
            action_params={"field": "price"},
            action_reasoning="Extracting price from product page",
            reward=0.5,
            reward_breakdown={"extraction": 0.5},
            observation_summary={},
        )
        assert step.action_reasoning == "Extracting price from product page"
class TestEpisodeManager:
    """Test the ``EpisodeManager`` registry: create, lookup, remove and filter."""

    def test_manager_create_episode(self) -> None:
        """Creating an episode returns it and registers it with the manager."""
        manager = EpisodeManager()
        episode = manager.create_episode("ep_100", "task_100")
        assert episode.episode_id == "ep_100"
        assert episode.task_id == "task_100"
        # The registry must hand back the very object that was created.
        assert manager.get_episode("ep_100") is episode

    def test_manager_get_episode(self) -> None:
        """A registered episode can be looked up by its ID."""
        manager = EpisodeManager()
        manager.create_episode("ep_101", "task_101")
        episode = manager.get_episode("ep_101")
        assert episode is not None
        assert episode.episode_id == "ep_101"

    def test_manager_get_nonexistent(self) -> None:
        """Looking up an unknown ID returns None instead of raising."""
        manager = EpisodeManager()
        episode = manager.get_episode("nonexistent")
        assert episode is None

    def test_manager_remove_episode(self) -> None:
        """Removing an episode reports success and makes it unreachable."""
        manager = EpisodeManager()
        manager.create_episode("ep_102", "task_102")
        removed = manager.remove_episode("ep_102")
        assert removed is True
        episode = manager.get_episode("ep_102")
        assert episode is None

    def test_manager_list_episodes(self) -> None:
        """Listing without filters returns every registered episode."""
        manager = EpisodeManager()
        manager.create_episode("ep_103", "task_103")
        manager.create_episode("ep_104", "task_104")
        manager.create_episode("ep_105", "task_105")
        episodes = manager.list_episodes()
        assert len(episodes) == 3
        assert sorted(ep.episode_id for ep in episodes) == ["ep_103", "ep_104", "ep_105"]

    def test_manager_list_episodes_by_status(self) -> None:
        """Filtering by status returns only episodes currently in that status."""
        manager = EpisodeManager()
        ep1 = manager.create_episode("ep_106", "task_106")
        ep2 = manager.create_episode("ep_107", "task_107")
        # Left untouched on purpose: it stays PENDING and must match
        # neither the RUNNING nor the COMPLETED filter.
        ep3 = manager.create_episode("ep_108", "task_108")
        ep1.start()
        ep2.start()
        ep2.complete(success=True)
        running = manager.list_episodes(status=EpisodeStatus.RUNNING)
        assert len(running) == 1
        assert running[0].episode_id == "ep_106"
        completed = manager.list_episodes(status=EpisodeStatus.COMPLETED)
        assert len(completed) == 1
        assert completed[0].episode_id == "ep_107"
        pending = manager.list_episodes(status=EpisodeStatus.PENDING)
        assert len(pending) == 1
        assert pending[0].episode_id == ep3.episode_id

    def test_manager_list_episodes_by_task(self) -> None:
        """Filtering by task ID returns exactly that task's episodes."""
        manager = EpisodeManager()
        manager.create_episode("ep_109", "task_A")
        manager.create_episode("ep_110", "task_A")
        manager.create_episode("ep_111", "task_B")
        task_a_episodes = manager.list_episodes(task_id="task_A")
        assert len(task_a_episodes) == 2
        assert {ep.episode_id for ep in task_a_episodes} == {"ep_109", "ep_110"}
        task_b_episodes = manager.list_episodes(task_id="task_B")
        assert len(task_b_episodes) == 1
        assert task_b_episodes[0].episode_id == "ep_111"