# agentic-rag-gym / tests/test_models.py
# williyam's picture
# fix: remove dead code (EXECUTOR, MySQL, bleach), fix PlannerAgent receiver
# 201faa7
"""Tests for core domain models."""
from __future__ import annotations
import pytest
from rag_master.models import (
AgentMessage,
AgentRole,
DifficultyLevel,
Document,
EpisodeState,
RetrievalResult,
StepRecord,
TaskDefinition,
Trajectory,
)
class TestDocument:
    """Checks construction and default field values of ``Document``."""

    def test_create_document(self) -> None:
        document = Document(content="Test content", source="test")
        # Explicitly supplied fields must round-trip unchanged.
        assert (document.content, document.source) == ("Test content", "test")
        # Fields not passed to the constructor must still be populated.
        assert document.doc_id is not None
        assert document.created_at is not None

    def test_document_with_metadata(self) -> None:
        meta = {"topic": "propulsion"}
        document = Document(content="x", metadata=meta)
        assert document.metadata["topic"] == "propulsion"

    def test_document_default_embedding(self) -> None:
        # No embedding is passed, so the attribute should stay unset.
        assert Document(content="x").embedding is None
class TestRetrievalResult:
    """Checks construction and score validation of ``RetrievalResult``."""

    def test_create_result(self) -> None:
        doc = Document(content="test")
        result = RetrievalResult(document=doc, score=0.85, rank=0)
        assert result.score == 0.85
        assert result.rank == 0

    def test_score_bounds(self) -> None:
        """Scores above 1.0 or below 0.0 must be rejected at construction.

        ``pytest.raises(Exception)`` would also be satisfied by any unrelated
        error raised while building the object, which hides real bugs. The
        expectation is therefore narrowed to ``ValueError``. NOTE(review):
        this assumes the model validates via pydantic, whose
        ``ValidationError`` subclasses ``ValueError``; confirm against
        ``rag_master.models``.
        """
        doc = Document(content="test")
        for bad_score in (1.5, -0.1):
            with pytest.raises(ValueError):
                RetrievalResult(document=doc, score=bad_score, rank=0)
class TestAgentMessage:
    """Checks the fields of inter-agent ``AgentMessage`` objects."""

    def test_create_message(self) -> None:
        sender, receiver = AgentRole.RETRIEVER, AgentRole.REASONER
        message = AgentMessage(
            sender=sender,
            receiver=receiver,
            content="Found 3 documents",
        )
        assert message.sender == sender
        assert message.receiver == receiver

    def test_message_types(self) -> None:
        # No receiver is passed here; only the explicit message type is checked.
        critique = AgentMessage(
            sender=AgentRole.CRITIC,
            content="Needs improvement",
            message_type="critique",
        )
        assert critique.message_type == "critique"
class TestStepRecord:
    """Checks construction of per-step ``StepRecord`` entries."""

    def test_create_step(self) -> None:
        record = StepRecord(step_index=1, action_type="retrieve")
        assert record.step_index == 1
        # No reward is supplied, so the default must be zero.
        assert record.intermediate_reward == 0.0

    def test_step_with_trace(self) -> None:
        trace = "Based on the evidence..."
        record = StepRecord(step_index=2, action_type="reason", reasoning_trace=trace)
        assert "evidence" in record.reasoning_trace
class TestTrajectory:
    """Checks how ``Trajectory`` holds step records and rewards."""

    def test_create_empty_trajectory(self) -> None:
        trajectory = Trajectory(task_id="test_task")
        assert len(trajectory.steps) == 0
        assert trajectory.completed is False

    def test_trajectory_with_steps(self) -> None:
        plan = [("retrieve", 0.3), ("reason", 0.5)]
        records = [
            StepRecord(step_index=index, action_type=action, intermediate_reward=reward)
            for index, (action, reward) in enumerate(plan, start=1)
        ]
        trajectory = Trajectory(task_id="test", steps=records, total_reward=0.8)
        assert len(trajectory.steps) == 2
        assert trajectory.total_reward == 0.8
class TestTaskDefinition:
    """Checks construction and defaults of ``TaskDefinition``."""

    def test_create_task(self) -> None:
        task = TaskDefinition(
            task_id="test_easy",
            name="Test Task",
            description="A test task",
            difficulty=DifficultyLevel.EASY,
        )
        assert task.difficulty == DifficultyLevel.EASY
        # max_steps is not passed, so this pins its default value.
        assert task.max_steps == 20

    def test_task_with_rubric(self) -> None:
        task = TaskDefinition(
            task_id="test",
            name="Test",
            description="Test",
            difficulty=DifficultyLevel.HARD,
            grading_rubric={"accuracy": 0.5, "completeness": 0.5},
        )
        # Rubric weights are floats, and their sum can miss 1.0 by rounding
        # error (e.g. 0.1 + 0.2 != 0.3). Compare with a tolerance so the test
        # does not break when the weights are changed.
        assert sum(task.grading_rubric.values()) == pytest.approx(1.0)
class TestEpisodeState:
    """Checks initial values and history bookkeeping of ``EpisodeState``."""

    @staticmethod
    def _make_task(difficulty: "DifficultyLevel") -> "TaskDefinition":
        # Shared fixture: a minimal task that differs only in difficulty.
        return TaskDefinition(
            task_id="t1", name="T1", description="D1", difficulty=difficulty
        )

    def test_create_state(self) -> None:
        state = EpisodeState(task=self._make_task(DifficultyLevel.MEDIUM))
        assert state.current_step == 0
        assert state.done is False

    def test_state_history_tracking(self) -> None:
        state = EpisodeState(task=self._make_task(DifficultyLevel.EASY))
        # The history lists are appended to directly, as a caller would do.
        state.query_history.append("ion propulsion")
        state.intermediate_rewards.append(0.4)
        assert len(state.query_history) == 1
        assert len(state.intermediate_rewards) == 1
class TestDifficultyLevel:
    """Checks that each ``DifficultyLevel`` member compares equal to its string."""

    def test_all_levels(self) -> None:
        expected = [
            (DifficultyLevel.EASY, "easy"),
            (DifficultyLevel.MEDIUM, "medium"),
            (DifficultyLevel.HARD, "hard"),
            (DifficultyLevel.EXPERT, "expert"),
        ]
        for member, value in expected:
            assert member == value
class TestAgentRole:
    """Checks that each ``AgentRole`` member compares equal to its string."""

    def test_all_roles(self) -> None:
        expected = [
            (AgentRole.RETRIEVER, "retriever"),
            (AgentRole.REASONER, "reasoner"),
            (AgentRole.CRITIC, "critic"),
            (AgentRole.PLANNER, "planner"),
            (AgentRole.VERIFIER, "verifier"),
        ]
        for member, value in expected:
            assert member == value