Rabbook / tests /test_phase2_rag_graph.py
Matcry's picture
Deploy snapshot
c76423f
Raw
History Blame Contribute Delete
11.1 kB
import unittest
from unittest.mock import patch
from agents.rag_graph import (
build_initial_graph_state,
check_grounding_node,
decide_next_action_node,
expand_context_node,
fallback_answer_node,
generate_answer_node,
prepare_input_node,
route_after_grounding,
run_rag_graph_answer,
retrieve_node,
)
class Phase2RagGraphTests(unittest.TestCase):
def test_build_initial_graph_state_sets_expected_defaults(self):
state = build_initial_graph_state(
"What is this roadmap about?",
selected_file="roadmap.pdf",
debug_mode=True,
)
self.assertEqual(state["query"], "What is this roadmap about?")
self.assertEqual(state["selected_file"], "roadmap.pdf")
self.assertTrue(state["debug_mode"])
self.assertEqual(state["retrieved_documents"], [])
self.assertIsNone(state["metadata_filter"])
self.assertIsNone(state["next_action"])
def test_prepare_input_node_builds_metadata_filter(self):
state = build_initial_graph_state(
"What is this roadmap about?",
selected_file="roadmap.pdf",
selected_file_type="pdf",
page_start="4",
page_end="2",
)
updated_state = prepare_input_node(state)
self.assertEqual(
updated_state["metadata_filter"],
{
"file_name": "roadmap.pdf",
"file_type": "pdf",
"page_range": {"start": 2, "end": 4},
},
)
@patch("agents.rag_graph.retrieve_documents_with_query_transform")
def test_retrieve_node_stores_retrieved_documents_and_debug_data(self, mock_retrieve):
state = build_initial_graph_state(
"What is this roadmap about?",
selected_file="roadmap.pdf",
debug_mode=True,
)
state["metadata_filter"] = {"file_name": "roadmap.pdf"}
mock_retrieve.return_value = (["doc-a", "doc-b"], {"stage_counts": {"search_queries": 1}})
updated_state = retrieve_node(
state,
vectorstore=object(),
reranker=object(),
bm25_index=object(),
llm=object(),
retrieval_k=4,
rerank_candidate_k=8,
bm25_candidate_k=8,
enable_query_transform=True,
)
self.assertEqual(updated_state["retrieved_documents"], ["doc-a", "doc-b"])
self.assertEqual(updated_state["debug_data"]["metadata_filter"], {"file_name": "roadmap.pdf"})
self.assertEqual(updated_state["debug_data"]["grounding"]["reason"], "not_checked")
@patch("agents.rag_graph.expand_with_context_window")
def test_expand_context_node_stores_expanded_documents(self, mock_expand):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["retrieved_documents"] = ["doc-a"]
state["debug_data"] = {"stage_counts": {}, "grounding": {"reason": "not_checked"}}
mock_expand.return_value = ["doc-a", "doc-a-next"]
updated_state = expand_context_node(
state,
chunk_registry={"by_chunk_id": {}},
context_window=1,
max_expanded_chunks=12,
)
self.assertEqual(updated_state["expanded_documents"], ["doc-a", "doc-a-next"])
self.assertEqual(updated_state["debug_data"]["expanded_hits"], ["doc-a", "doc-a-next"])
self.assertEqual(updated_state["debug_data"]["stage_counts"]["expanded_context"], 2)
@patch("agents.rag_graph.check_grounding_evidence")
def test_check_grounding_node_stores_grounding_result(self, mock_check_grounding):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["retrieved_documents"] = ["doc-a"]
state["expanded_documents"] = ["doc-a", "doc-a-next"]
state["debug_data"] = {
"stage_counts": {},
"grounding": {"stage": "retrieval", "passed": None, "reason": "not_checked"},
}
mock_check_grounding.return_value = {
"passed": True,
"reason": "enough_evidence",
"top_rerank_score": 1.2,
"retrieved_count": 1,
"expanded_count": 2,
}
updated_state = check_grounding_node(
state,
min_grounded_rerank_score=1.0,
min_grounded_chunks=1,
)
self.assertTrue(updated_state["grounding"]["passed"])
self.assertEqual(updated_state["debug_data"]["grounding"]["reason"], "enough_evidence")
self.assertEqual(updated_state["debug_data"]["grounding"]["stage"], "retrieval")
def test_route_after_grounding_chooses_refine_query_when_retry_retrieval(self):
state = build_initial_graph_state("What is this roadmap about?")
state["next_action"] = "retry_retrieval"
next_node = route_after_grounding(state)
self.assertEqual(next_node, "refine_query")
def test_route_after_grounding_chooses_fallback_when_no_evidence(self):
state = build_initial_graph_state("What is this roadmap about?")
state["next_action"] = "fallback"
next_node = route_after_grounding(state)
self.assertEqual(next_node, "fallback_answer")
def test_decide_next_action_node_marks_retry_when_local_evidence_is_partial(self):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["retrieved_documents"] = ["doc-a"]
state["expanded_documents"] = ["doc-a", "doc-a-next"]
state["grounding"] = {"passed": False}
state["debug_data"] = {"grounding": {}, "stage_counts": {}}
updated_state = decide_next_action_node(state)
self.assertEqual(updated_state["next_action"], "retry_retrieval")
self.assertEqual(updated_state["decision_reason"], "partial_local_evidence_retry_1")
self.assertEqual(updated_state["debug_data"]["next_action"], "retry_retrieval")
def test_decide_next_action_node_marks_fallback_when_no_local_evidence_exists(self):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["grounding"] = {"passed": False}
state["debug_data"] = {"grounding": {}, "stage_counts": {}}
updated_state = decide_next_action_node(state)
self.assertEqual(updated_state["next_action"], "fallback")
self.assertEqual(updated_state["decision_reason"], "no_evidence_found_research_disabled")
def test_decide_next_action_node_marks_answer_when_grounding_passes(self):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["grounding"] = {"passed": True}
state["debug_data"] = {"grounding": {}, "stage_counts": {}}
updated_state = decide_next_action_node(state)
self.assertEqual(updated_state["next_action"], "answer")
self.assertEqual(updated_state["decision_reason"], "grounding_passed")
@patch("agents.rag_graph.build_sources")
def test_fallback_answer_node_stores_fallback_answer_and_sources(self, mock_build_sources):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["retrieved_documents"] = ["doc-a"]
state["debug_data"] = {"grounding": {}, "stage_counts": {}}
mock_build_sources.return_value = [{"chunk_id": "c1"}]
updated_state = fallback_answer_node(
state,
grounded_fallback_message="fallback",
)
self.assertEqual(updated_state["answer"], "fallback")
self.assertEqual(updated_state["sources"], [{"chunk_id": "c1"}])
self.assertEqual(updated_state["citations"], [])
self.assertEqual(updated_state["debug_data"]["graph_path"], "fallback_answer")
@patch("agents.rag_graph.build_citations")
@patch("agents.rag_graph.answer_is_grounded")
@patch("agents.rag_graph.generate_answer")
@patch("agents.rag_graph.format_context")
@patch("agents.rag_graph.build_sources")
def test_generate_answer_node_stores_answer_sources_and_citations(
self,
mock_build_sources,
mock_format_context,
mock_generate_answer,
mock_answer_is_grounded,
mock_build_citations,
):
state = build_initial_graph_state("What is this roadmap about?", debug_mode=True)
state["retrieved_documents"] = ["doc-a"]
state["expanded_documents"] = ["doc-a", "doc-a-next"]
state["grounding"] = {
"passed": True,
"top_rerank_score": 1.2,
"retrieved_count": 1,
"expanded_count": 2,
}
state["debug_data"] = {"grounding": {}, "stage_counts": {}}
mock_build_sources.return_value = [{"chunk_id": "c1"}]
mock_format_context.return_value = "[1] expanded context"
mock_generate_answer.return_value = "Grounded answer [1]"
mock_answer_is_grounded.return_value = True
mock_build_citations.return_value = [{"number": 1}]
updated_state = generate_answer_node(
state,
llm=object(),
grounded_fallback_message="fallback",
)
self.assertEqual(updated_state["answer"], "Grounded answer [1]")
self.assertEqual(updated_state["sources"], [{"chunk_id": "c1"}])
self.assertEqual(updated_state["citations"], [{"number": 1}])
self.assertEqual(updated_state["debug_data"]["graph_path"], "generate_answer")
self.assertEqual(updated_state["debug_data"]["grounding"]["reason"], "answer_is_grounded")
@patch("agents.rag_graph.build_rag_graph")
def test_run_rag_graph_answer_returns_answer_result_shape(self, mock_build_graph):
fake_graph = mock_build_graph.return_value
fake_graph.invoke.return_value = {
"answer": "Graph answer [1]",
"sources": [{"chunk_id": "c1"}],
"citations": [{"number": 1}],
"debug_data": {"grounding": {"reason": "answer_is_grounded"}},
}
result = run_rag_graph_answer(
"What is this roadmap about?",
vectorstore=object(),
chunk_registry={"by_chunk_id": {}},
reranker=object(),
bm25_index=object(),
llm=object(),
retrieval_k=4,
rerank_candidate_k=8,
bm25_candidate_k=8,
context_window=1,
max_expanded_chunks=12,
min_grounded_rerank_score=1.0,
min_grounded_chunks=1,
grounded_fallback_message="fallback",
enable_query_transform=True,
selected_file="roadmap.pdf",
debug_mode=True,
)
self.assertEqual(result.answer, "Graph answer [1]")
self.assertEqual(result.sources, [{"chunk_id": "c1"}])
self.assertEqual(result.citations, [{"number": 1}])
self.assertEqual(result.debug_data["pipeline_mode"], "langgraph_rag")
self.assertEqual(result.debug_data["grounding"]["reason"], "answer_is_grounded")
if __name__ == "__main__":
unittest.main()