"""Tests for the FastAPI server endpoints.""" from __future__ import annotations from typing import Any, Dict, List, Tuple import pytest from httpx import ASGITransport, AsyncClient from rag_master.adapters import BaseDomainConfig, BaseGrader, BaseLLMClient, BaseRewardFunction, BaseRetriever from rag_master.models import ( DifficultyLevel, Document, EpisodeState, RetrievalResult, StepRecord, TaskDefinition, Trajectory, ) from rag_master.orchestrator import Orchestrator from rag_master.rewards import _SCORE_MAX, _SCORE_MIN, clamp_score import server.app as server_app_module from server.app import app class _MockRetriever(BaseRetriever): async def index_documents(self, documents: List[Document]) -> int: return len(documents) async def retrieve(self, query: str, top_k: int = 5) -> List[RetrievalResult]: return [ RetrievalResult( document=Document(content="Test document about propulsion systems", source="test"), score=0.75, rank=0, ) ] async def clear_index(self) -> None: pass class _MockLLM(BaseLLMClient): async def generate(self, messages: List[Dict[str, str]], **kwargs: Any) -> str: return "Based on analysis, ion propulsion offers higher specific impulse." 
async def generate_with_metadata(self, messages: List[Dict[str, str]], **kwargs: Any) -> Dict[str, Any]: return {"content": "Test response", "total_tokens": 50, "model": "test"} class _MockRewardFn(BaseRewardFunction): async def compute_step_reward(self, state: EpisodeState, step: StepRecord) -> float: return clamp_score(0.5) async def compute_episode_reward(self, trajectory: Trajectory, state: EpisodeState) -> float: return clamp_score(0.7) def get_reward_bounds(self) -> Tuple[float, float]: return (0.01, 0.99) class _MockGrader(BaseGrader): async def grade(self, state: EpisodeState, trajectory: Trajectory) -> float: return clamp_score(0.6) class _MockDomain(BaseDomainConfig): def get_tasks(self) -> List[TaskDefinition]: return [ TaskDefinition(task_id="test_easy", name="Test Easy", description="Easy test", difficulty=DifficultyLevel.EASY, max_steps=10), TaskDefinition(task_id="test_medium", name="Test Medium", description="Medium test", difficulty=DifficultyLevel.MEDIUM, max_steps=15), TaskDefinition(task_id="test_hard", name="Test Hard", description="Hard test", difficulty=DifficultyLevel.HARD, max_steps=20), TaskDefinition(task_id="test_expert1", name="Test Expert1", description="Expert test 1", difficulty=DifficultyLevel.EXPERT, max_steps=20), TaskDefinition(task_id="test_expert2", name="Test Expert2", description="Expert test 2", difficulty=DifficultyLevel.EXPERT, max_steps=20), ] def get_documents(self) -> List[Document]: return [Document(content="Test doc", source="test")] def get_grader(self, task_id: str) -> BaseGrader: return _MockGrader() def get_reward_function(self) -> BaseRewardFunction: return _MockRewardFn() def get_system_prompt(self) -> str: return "You are a test assistant." 
@pytest.fixture(autouse=True)
def _inject_mock_orchestrator(monkeypatch: pytest.MonkeyPatch) -> None:
    """Install a mock-backed orchestrator into the server module for every test.

    The module globals are patched through ``monkeypatch`` so that whatever
    values ``server.app`` held before the test are restored afterwards,
    instead of being overwritten with hard-coded values on teardown.
    """
    orch = Orchestrator(
        domain_config=_MockDomain(),
        retriever=_MockRetriever(),
        llm_client=_MockLLM(),
        reward_function=_MockRewardFn(),
    )
    # raising=False keeps the original "assign unconditionally" behaviour even
    # if the attribute is not yet defined on the module.
    monkeypatch.setattr(server_app_module, "_orchestrators", {"aerospace": orch}, raising=False)
    monkeypatch.setattr(server_app_module, "_active_domain", "aerospace", raising=False)


@pytest.fixture
async def client():
    """Yield an ``AsyncClient`` that talks to ``app`` in-process via ``ASGITransport``."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac


async def _reset_episode(client: AsyncClient) -> None:
    """POST an empty body to ``/reset`` and fail fast if it is not accepted.

    Used as test setup so that a broken reset is reported at its source
    rather than as a confusing failure in a later request.
    """
    resp = await client.post("/reset", json={})
    assert resp.status_code == 200


class TestHealthEndpoint:
    """Tests for the health check endpoint."""

    @pytest.mark.asyncio
    async def test_health(self, client: AsyncClient) -> None:
        """``GET /health`` reports a healthy status and version 1.0.0."""
        resp = await client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "healthy"
        assert data["version"] == "1.0.0"


class TestTasksEndpoint:
    """Tests for the tasks listing endpoint."""

    @pytest.mark.asyncio
    async def test_list_tasks(self, client: AsyncClient) -> None:
        """``GET /tasks`` returns a ``tasks`` list with at least five entries."""
        resp = await client.get("/tasks")
        assert resp.status_code == 200
        data = resp.json()
        assert "tasks" in data
        assert len(data["tasks"]) >= 5


class TestResetEndpoint:
    """Tests for the reset endpoint."""

    @pytest.mark.asyncio
    async def test_reset_default(self, client: AsyncClient) -> None:
        """Resetting with an empty body succeeds and returns an observation."""
        resp = await client.post("/reset", json={})
        assert resp.status_code == 200
        data = resp.json()
        assert "observation" in data

    @pytest.mark.asyncio
    async def test_reset_specific_task(self, client: AsyncClient) -> None:
        """Resetting with a known ``task_id`` echoes that id in the observation."""
        resp = await client.post("/reset", json={"task_id": "test_easy"})
        assert resp.status_code == 200
        data = resp.json()
        obs = data["observation"]
        assert obs["task"]["id"] == "test_easy"

    @pytest.mark.asyncio
    async def test_reset_invalid_task(self, client: AsyncClient) -> None:
        """Resetting with an unknown ``task_id`` is rejected with HTTP 400."""
        resp = await client.post("/reset", json={"task_id": "nonexistent"})
        assert resp.status_code == 400


class TestStepEndpoint:
    """Tests for the step endpoint."""

    @pytest.mark.asyncio
    async def test_step_retrieve(self, client: AsyncClient) -> None:
        """A retrieve step returns an observation and a reward within score bounds."""
        await _reset_episode(client)
        resp = await client.post("/step", json={"type": "retrieve", "query": "propulsion"})
        assert resp.status_code == 200
        data = resp.json()
        assert "observation" in data
        assert "reward" in data
        assert _SCORE_MIN <= data["reward"] <= _SCORE_MAX

    @pytest.mark.asyncio
    async def test_step_after_reset(self, client: AsyncClient) -> None:
        """A step issued right after a successful reset is accepted."""
        reset_resp = await client.post("/reset", json={})
        assert reset_resp.status_code == 200
        # Actually take the step the test name promises; previously only the
        # reset was exercised.
        step_resp = await client.post("/step", json={"type": "retrieve", "query": "propulsion"})
        assert step_resp.status_code == 200
        assert "observation" in step_resp.json()


class TestStateEndpoint:
    """Tests for the state endpoint."""

    @pytest.mark.asyncio
    async def test_state(self, client: AsyncClient) -> None:
        """``GET /state`` after a reset exposes a ``task_id`` field."""
        await _reset_episode(client)
        resp = await client.get("/state")
        assert resp.status_code == 200
        data = resp.json()
        assert "task_id" in data


class TestGradeEndpoint:
    """Tests for the grade endpoint."""

    @pytest.mark.asyncio
    async def test_grade(self, client: AsyncClient) -> None:
        """Grading after one retrieve step yields a score within score bounds."""
        await _reset_episode(client)
        await client.post("/step", json={"type": "retrieve", "query": "test"})
        resp = await client.post("/grade", json={})
        assert resp.status_code == 200
        data = resp.json()
        assert "score" in data
        assert _SCORE_MIN <= data["score"] <= _SCORE_MAX


class TestServerModels:
    """Tests for server Pydantic models."""

    def test_action_model(self) -> None:
        """An ``Action`` keeps the ``ActionType`` it was constructed with."""
        from server.models import Action, ActionType

        action = Action(type=ActionType.RETRIEVE, query="test")
        assert action.type == ActionType.RETRIEVE

    def test_observation_model(self) -> None:
        """An ``Observation`` built without ``done`` has ``done`` equal to ``False``."""
        from server.models import Observation, TaskObservation

        task = TaskObservation(id="t1", name="T1", description="D1", difficulty="easy", max_steps=10)
        obs = Observation(task=task, step=0)
        assert obs.done is False

    def test_health_status(self) -> None:
        """A ``HealthStatus`` built with no arguments has status ``"healthy"``."""
        from server.models import HealthStatus

        h = HealthStatus()
        assert h.status == "healthy"