"""
Tests for src/services/agents/ — agentic RAG pipeline.
"""
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock
# -----------------------------------------------------------------------
# Mock context and LLM
# -----------------------------------------------------------------------
class MockMessage:
    """Minimal stand-in for an LLM chat message carrying only text."""

    def __init__(self, content: str):
        # Pipeline nodes only ever read `.content`, so nothing else is modelled.
        self.content = content
class MockLLM:
    """Programmable mock LLM that returns canned responses."""

    def __init__(self, responses: list[str] | None = None):
        # Responses are served in order; an empty/None list means every
        # call falls through to the generic default payload in invoke().
        self._responses = responses or []
        self._call_count = 0

    def invoke(self, messages: list) -> MockMessage:
        # Grab the slot for this call, then advance the counter so each
        # invocation consumes at most one canned response.
        slot = self._call_count
        self._call_count = slot + 1
        if slot < len(self._responses):
            return MockMessage(self._responses[slot])
        # Default: a comfortably "in scope" guardrail score.
        return MockMessage('{"score": 80}')
@dataclass
class MockContext:
llm: Any | None = None
embedding_service: Any | None = None
opensearch_client: Any | None = None
cache: Any | None = None
tracer: Any | None = None
# -----------------------------------------------------------------------
# Guardrail node
# -----------------------------------------------------------------------
class TestGuardrailNode:
    """Behaviour of the scope-checking guardrail node."""

    def test_in_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        context = MockContext(llm=MockLLM(['{"score": 85}']))
        out = guardrail_node({"query": "What does high HbA1c mean?"}, context=context)
        assert out["is_in_scope"] is True
        assert out["guardrail_score"] == 85.0

    def test_out_of_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        context = MockContext(llm=MockLLM(['{"score": 10}']))
        out = guardrail_node({"query": "What is the weather today?"}, context=context)
        assert out["is_in_scope"] is False
        assert out["routing_decision"] == "out_of_scope"

    def test_biomarkers_bypass(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        # Structured biomarker input should skip the LLM scope check entirely.
        context = MockContext(llm=MockLLM())
        out = guardrail_node(
            {"query": "analyze", "biomarkers": {"Glucose": 185}}, context=context
        )
        assert out["is_in_scope"] is True
        assert out["guardrail_score"] == 95.0

    def test_llm_failure_defaults_in_scope(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        broken = MagicMock()
        broken.invoke.side_effect = Exception("LLM down")
        out = guardrail_node(
            {"query": "What is HbA1c?"}, context=MockContext(llm=broken)
        )
        # On guardrail failure the pipeline gives the query the benefit
        # of the doubt rather than rejecting it.
        assert out["is_in_scope"] is True
# -----------------------------------------------------------------------
# Out-of-scope node
# -----------------------------------------------------------------------
class TestOutOfScopeNode:
    """The terminal node that rejects out-of-domain queries."""

    def test_returns_rejection(self):
        from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
        from src.services.agents.prompts import OUT_OF_SCOPE_RESPONSE

        # The node needs no LLM or state: it always emits the canned rejection.
        out = out_of_scope_node({}, context=MockContext())
        assert out["final_answer"] == OUT_OF_SCOPE_RESPONSE
# -----------------------------------------------------------------------
# Grade documents node
# -----------------------------------------------------------------------
class TestGradeDocumentsNode:
    """Relevance grading of retrieved documents against the query."""

    def test_grades_relevant(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        # One canned verdict per retrieved document, in order.
        context = MockContext(
            llm=MockLLM(['{"relevant": true}', '{"relevant": false}'])
        )
        documents = [
            {"id": "1", "text": "Diabetes is treated with insulin."},
            {"id": "2", "text": "The weather is sunny today."},
        ]
        out = grade_documents_node(
            {"query": "diabetes treatment", "retrieved_documents": documents},
            context=context,
        )
        # Only the document graded relevant survives the filter.
        assert len(out["relevant_documents"]) == 1
        assert out["grading_results"][0]["relevant"] is True
        assert out["grading_results"][1]["relevant"] is False

    def test_empty_docs_needs_rewrite(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        # No retrieved documents at all should trigger a query rewrite.
        out = grade_documents_node(
            {"query": "test", "retrieved_documents": []}, context=MockContext()
        )
        assert out["needs_rewrite"] is True
# -----------------------------------------------------------------------
# Rewrite query node
# -----------------------------------------------------------------------
class TestRewriteQueryNode:
    """Query-rewrite node: expands vague queries; falls back on LLM failure."""

    def test_rewrites(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        ctx = MockContext(llm=MockLLM(["diabetes HbA1c glucose management guidelines"]))
        state = {"query": "sugar problems"}
        result = rewrite_query_node(state, context=ctx)
        # BUG FIX: the previous assertion ended in
        # `... or result["rewritten_query"]`, which made it vacuously pass
        # for ANY non-empty rewrite. Assert that the canned LLM expansion
        # actually surfaced in the rewritten query.
        assert "diabetes" in result["rewritten_query"].lower()

    def test_llm_failure_keeps_original(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = Exception("timeout")
        ctx = MockContext(llm=broken_llm)
        state = {"query": "original query"}
        result = rewrite_query_node(state, context=ctx)
        # On failure the node must not lose the user's query.
        assert result["rewritten_query"] == "original query"
# -----------------------------------------------------------------------
# Generate answer node
# -----------------------------------------------------------------------
class TestGenerateAnswerNode:
    """Final answer synthesis from the graded evidence."""

    def test_generates_answer(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        canned = "Based on the evidence, HbA1c of 8.2% indicates poor glycemic control."
        context = MockContext(llm=MockLLM([canned]))
        evidence = [
            {
                "title": "Diabetes Guide",
                "section": "Diagnosis",
                "text": "HbA1c above 6.5% indicates diabetes.",
            }
        ]
        out = generate_answer_node(
            {"query": "What does HbA1c 8.2 mean?", "relevant_documents": evidence},
            context=context,
        )
        assert "final_answer" in out
        # A real answer, not an empty/trivial string.
        assert len(out["final_answer"]) > 10

    def test_llm_failure_returns_fallback(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        broken = MagicMock()
        broken.invoke.side_effect = Exception("dead")
        out = generate_answer_node(
            {"query": "test", "relevant_documents": []},
            context=MockContext(llm=broken),
        )
        # Failure path: apologetic fallback answer plus a recorded error.
        assert "apologize" in out["final_answer"].lower()
        assert len(out["errors"]) > 0
# -----------------------------------------------------------------------
# Agentic RAG state
# -----------------------------------------------------------------------
class TestAgenticRAGState:
    """Shape of the shared pipeline state object."""

    def test_state_is_typed_dict(self):
        from src.services.agents.state import AgenticRAGState

        # The state must behave as a plain dict under a TypedDict annotation.
        state: AgenticRAGState = {"query": "test", "errors": []}
        assert state["query"] == "test"