Agentic-RagBot / tests /test_agentic_rag.py
T0X1N's picture
chore: codebase audit and fixes (ruff, mypy, pytest)
9659593
"""
Tests for src/services/agents/ — agentic RAG pipeline.
"""
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock
# -----------------------------------------------------------------------
# Mock context and LLM
# -----------------------------------------------------------------------
class MockMessage:
def __init__(self, content: str):
self.content = content
class MockLLM:
"""Programmable mock LLM that returns canned responses."""
def __init__(self, responses: list[str] | None = None):
self._responses = responses or []
self._call_count = 0
def invoke(self, messages: list) -> MockMessage:
if self._call_count < len(self._responses):
resp = self._responses[self._call_count]
else:
resp = '{"score": 80}'
self._call_count += 1
return MockMessage(resp)
@dataclass
class MockContext:
llm: Any | None = None
embedding_service: Any | None = None
opensearch_client: Any | None = None
cache: Any | None = None
tracer: Any | None = None
# -----------------------------------------------------------------------
# Guardrail node
# -----------------------------------------------------------------------
class TestGuardrailNode:
def test_in_scope_query(self):
from src.services.agents.nodes.guardrail_node import guardrail_node
ctx = MockContext(llm=MockLLM(['{"score": 85}']))
state = {"query": "What does high HbA1c mean?"}
result = guardrail_node(state, context=ctx)
assert result["is_in_scope"] is True
assert result["guardrail_score"] == 85.0
def test_out_of_scope_query(self):
from src.services.agents.nodes.guardrail_node import guardrail_node
ctx = MockContext(llm=MockLLM(['{"score": 10}']))
state = {"query": "What is the weather today?"}
result = guardrail_node(state, context=ctx)
assert result["is_in_scope"] is False
assert result["routing_decision"] == "out_of_scope"
def test_biomarkers_bypass(self):
from src.services.agents.nodes.guardrail_node import guardrail_node
ctx = MockContext(llm=MockLLM())
state = {"query": "analyze", "biomarkers": {"Glucose": 185}}
result = guardrail_node(state, context=ctx)
assert result["is_in_scope"] is True
assert result["guardrail_score"] == 95.0
def test_llm_failure_defaults_in_scope(self):
from src.services.agents.nodes.guardrail_node import guardrail_node
broken_llm = MagicMock()
broken_llm.invoke.side_effect = Exception("LLM down")
ctx = MockContext(llm=broken_llm)
state = {"query": "What is HbA1c?"}
result = guardrail_node(state, context=ctx)
assert result["is_in_scope"] is True # benefit of the doubt
# -----------------------------------------------------------------------
# Out-of-scope node
# -----------------------------------------------------------------------
class TestOutOfScopeNode:
def test_returns_rejection(self):
from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
from src.services.agents.prompts import OUT_OF_SCOPE_RESPONSE
ctx = MockContext()
result = out_of_scope_node({}, context=ctx)
assert result["final_answer"] == OUT_OF_SCOPE_RESPONSE
# -----------------------------------------------------------------------
# Grade documents node
# -----------------------------------------------------------------------
class TestGradeDocumentsNode:
def test_grades_relevant(self):
from src.services.agents.nodes.grade_documents_node import grade_documents_node
ctx = MockContext(llm=MockLLM(['{"relevant": true}', '{"relevant": false}']))
state = {
"query": "diabetes treatment",
"retrieved_documents": [
{"id": "1", "text": "Diabetes is treated with insulin."},
{"id": "2", "text": "The weather is sunny today."},
],
}
result = grade_documents_node(state, context=ctx)
assert len(result["relevant_documents"]) == 1
assert result["grading_results"][0]["relevant"] is True
assert result["grading_results"][1]["relevant"] is False
def test_empty_docs_needs_rewrite(self):
from src.services.agents.nodes.grade_documents_node import grade_documents_node
ctx = MockContext()
state = {"query": "test", "retrieved_documents": []}
result = grade_documents_node(state, context=ctx)
assert result["needs_rewrite"] is True
# -----------------------------------------------------------------------
# Rewrite query node
# -----------------------------------------------------------------------
class TestRewriteQueryNode:
def test_rewrites(self):
from src.services.agents.nodes.rewrite_query_node import rewrite_query_node
ctx = MockContext(llm=MockLLM(["diabetes HbA1c glucose management guidelines"]))
state = {"query": "sugar problems"}
result = rewrite_query_node(state, context=ctx)
assert "diabetes" in result["rewritten_query"].lower() or result["rewritten_query"]
def test_llm_failure_keeps_original(self):
from src.services.agents.nodes.rewrite_query_node import rewrite_query_node
broken_llm = MagicMock()
broken_llm.invoke.side_effect = Exception("timeout")
ctx = MockContext(llm=broken_llm)
state = {"query": "original query"}
result = rewrite_query_node(state, context=ctx)
assert result["rewritten_query"] == "original query"
# -----------------------------------------------------------------------
# Generate answer node
# -----------------------------------------------------------------------
class TestGenerateAnswerNode:
def test_generates_answer(self):
from src.services.agents.nodes.generate_answer_node import generate_answer_node
ctx = MockContext(llm=MockLLM(["Based on the evidence, HbA1c of 8.2% indicates poor glycemic control."]))
state = {
"query": "What does HbA1c 8.2 mean?",
"relevant_documents": [
{"title": "Diabetes Guide", "section": "Diagnosis", "text": "HbA1c above 6.5% indicates diabetes."}
],
}
result = generate_answer_node(state, context=ctx)
assert "final_answer" in result
assert len(result["final_answer"]) > 10
def test_llm_failure_returns_fallback(self):
from src.services.agents.nodes.generate_answer_node import generate_answer_node
broken_llm = MagicMock()
broken_llm.invoke.side_effect = Exception("dead")
ctx = MockContext(llm=broken_llm)
state = {"query": "test", "relevant_documents": []}
result = generate_answer_node(state, context=ctx)
assert "apologize" in result["final_answer"].lower()
assert len(result["errors"]) > 0
# -----------------------------------------------------------------------
# Agentic RAG state
# -----------------------------------------------------------------------
class TestAgenticRAGState:
def test_state_is_typed_dict(self):
from src.services.agents.state import AgenticRAGState
# Should be usable as a dict type hint
state: AgenticRAGState = {
"query": "test",
"errors": [],
}
assert state["query"] == "test"