|
|
"""Integration tests for the research graph.""" |
|
|
|
|
|
import pytest |
|
|
from pydantic_ai.models.test import TestModel |
|
|
|
|
|
from src.agents.graph.workflow import create_research_graph |
|
|
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_graph_execution_flow(mocker):
    """Test the graph runs from start to finish (simulated).

    No real model is contacted:

    * ``src.agents.graph.nodes.get_model`` is patched to return a
      ``TestModel``.
    * ``pydantic_ai.Agent.run`` is patched so every agent run yields a result
      whose ``output`` is a minimal ``ResearchReport``.

    The graph built by ``create_research_graph(llm=None)`` is streamed from an
    initial state with ``max_iterations=2``. The test then checks that:

    * at least three events were emitted;
    * the supervisor, search and synthesize nodes all ran;
    * synthesis produced the final event;
    * the synthesis update contains ``messages``.
    """
    from src.utils.models import ReportSection, ResearchReport

    mocker.patch("src.agents.graph.nodes.get_model", return_value=TestModel())

    # Minimal report returned by every mocked agent run.
    dummy_section = ReportSection(title="Dummy", content="Content")
    mock_report = ResearchReport(
        title="Test Report",
        # NOTE(review): repeated text presumably satisfies a minimum-length
        # constraint on executive_summary — confirm against the model.
        executive_summary="Summary " * 20,
        research_question="Question",
        methodology=dummy_section,
        hypotheses_tested=[],
        mechanistic_findings=dummy_section,
        clinical_findings=dummy_section,
        drug_candidates=[],
        limitations=["None"],
        conclusion="Conclusion",
        references=[],
        confidence_score=0.5,
    )

    # If Agent.run is async, mocker.patch produces an AsyncMock, so awaiting
    # the patched call yields ``mock_result``.
    mock_run = mocker.patch("pydantic_ai.Agent.run")
    mock_result = mocker.Mock()
    mock_result.output = mock_report
    mock_run.return_value = mock_result

    graph = create_research_graph(llm=None)

    initial_state = {
        "query": "test query",
        "hypotheses": [],
        "conflicts": [],
        "evidence_ids": [],
        "messages": [],
        "next_step": "search",
        "iteration_count": 0,
        "max_iterations": 2,
    }

    events = [event async for event in graph.astream(initial_state)]

    assert len(events) >= 3, f"Expected at least 3 events, got {len(events)}"

    # Collect every node key from every event, not just the first key, so an
    # event carrying several node updates is fully accounted for.
    node_names = {name for event in events for name in event}
    assert "supervisor" in node_names, "Supervisor node should have executed"
    assert "search" in node_names, "Search node should have executed"
    assert "synthesize" in node_names, "Synthesize node should have executed"

    final_event = events[-1]
    assert "synthesize" in final_event, (
        f"Final event should be synthesis, got: {list(final_event.keys())}"
    )

    # Guard against a None update so a missing payload fails as an assertion
    # rather than a TypeError on the membership test.
    synth_output = final_event["synthesize"] or {}
    assert "messages" in synth_output, "Synthesis should produce messages"
|
|
|