"""Tests for core domain models.""" from __future__ import annotations import pytest from rag_master.models import ( AgentMessage, AgentRole, DifficultyLevel, Document, EpisodeState, RetrievalResult, StepRecord, TaskDefinition, Trajectory, ) class TestDocument: """Test suite for Document model.""" def test_create_document(self) -> None: doc = Document(content="Test content", source="test") assert doc.content == "Test content" assert doc.source == "test" assert doc.doc_id is not None assert doc.created_at is not None def test_document_with_metadata(self) -> None: doc = Document(content="x", metadata={"topic": "propulsion"}) assert doc.metadata["topic"] == "propulsion" def test_document_default_embedding(self) -> None: doc = Document(content="x") assert doc.embedding is None class TestRetrievalResult: """Test suite for RetrievalResult model.""" def test_create_result(self) -> None: doc = Document(content="test") result = RetrievalResult(document=doc, score=0.85, rank=0) assert result.score == 0.85 assert result.rank == 0 def test_score_bounds(self) -> None: doc = Document(content="test") with pytest.raises(Exception): RetrievalResult(document=doc, score=1.5, rank=0) with pytest.raises(Exception): RetrievalResult(document=doc, score=-0.1, rank=0) class TestAgentMessage: """Test suite for AgentMessage model.""" def test_create_message(self) -> None: msg = AgentMessage( sender=AgentRole.RETRIEVER, receiver=AgentRole.REASONER, content="Found 3 documents", ) assert msg.sender == AgentRole.RETRIEVER assert msg.receiver == AgentRole.REASONER def test_message_types(self) -> None: msg = AgentMessage( sender=AgentRole.CRITIC, content="Needs improvement", message_type="critique", ) assert msg.message_type == "critique" class TestStepRecord: """Test suite for StepRecord model.""" def test_create_step(self) -> None: step = StepRecord(step_index=1, action_type="retrieve") assert step.step_index == 1 assert step.intermediate_reward == 0.0 def test_step_with_trace(self) -> None: step = StepRecord( step_index=2, action_type="reason", reasoning_trace="Based on the evidence...", ) assert "evidence" in step.reasoning_trace class TestTrajectory: """Test suite for Trajectory model.""" def test_create_empty_trajectory(self) -> None: traj = Trajectory(task_id="test_task") assert len(traj.steps) == 0 assert traj.completed is False def test_trajectory_with_steps(self) -> None: steps = [ StepRecord(step_index=1, action_type="retrieve", intermediate_reward=0.3), StepRecord(step_index=2, action_type="reason", intermediate_reward=0.5), ] traj = Trajectory(task_id="test", steps=steps, total_reward=0.8) assert len(traj.steps) == 2 assert traj.total_reward == 0.8 class TestTaskDefinition: """Test suite for TaskDefinition model.""" def test_create_task(self) -> None: task = TaskDefinition( task_id="test_easy", name="Test Task", description="A test task", difficulty=DifficultyLevel.EASY, ) assert task.difficulty == DifficultyLevel.EASY assert task.max_steps == 20 def test_task_with_rubric(self) -> None: task = TaskDefinition( task_id="test", name="Test", description="Test", difficulty=DifficultyLevel.HARD, grading_rubric={"accuracy": 0.5, "completeness": 0.5}, ) assert sum(task.grading_rubric.values()) == 1.0 class TestEpisodeState: """Test suite for EpisodeState model.""" def test_create_state(self) -> None: task = TaskDefinition( task_id="t1", name="T1", description="D1", difficulty=DifficultyLevel.MEDIUM ) state = EpisodeState(task=task) assert state.current_step == 0 assert state.done is False def test_state_history_tracking(self) -> None: task = TaskDefinition( task_id="t1", name="T1", description="D1", difficulty=DifficultyLevel.EASY ) state = EpisodeState(task=task) state.query_history.append("ion propulsion") state.intermediate_rewards.append(0.4) assert len(state.query_history) == 1 assert len(state.intermediate_rewards) == 1 class TestDifficultyLevel: """Test suite for DifficultyLevel enum.""" def test_all_levels(self) -> None: assert DifficultyLevel.EASY == "easy" assert DifficultyLevel.MEDIUM == "medium" assert DifficultyLevel.HARD == "hard" assert DifficultyLevel.EXPERT == "expert" class TestAgentRole: """Test suite for AgentRole enum.""" def test_all_roles(self) -> None: assert AgentRole.RETRIEVER == "retriever" assert AgentRole.REASONER == "reasoner" assert AgentRole.CRITIC == "critic" assert AgentRole.PLANNER == "planner" assert AgentRole.VERIFIER == "verifier"