| """Integration tests for the research graph.""" |
|
|
| import pytest |
| from pydantic_ai.models.test import TestModel |
|
|
| from src.agents.graph.workflow import create_research_graph |
|
|
|
|
@pytest.mark.integration
@pytest.mark.asyncio
async def test_graph_execution_flow(mocker):
    """Test the graph runs from start to finish (simulated).

    Every LLM touch point is stubbed:

    * ``src.agents.graph.nodes.get_model`` is patched to return a pydantic-ai
      ``TestModel``.
    * ``pydantic_ai.Agent.run`` is patched so every agent invocation returns
      the same canned ``ResearchReport``.

    The graph is then streamed from an initial "search" step. The test checks
    three things:

    * the supervisor, search and synthesize nodes all executed;
    * synthesis produced the final event;
    * the synthesis state update contains ``messages``.
    """
    from src.utils.models import ReportSection, ResearchReport

    # Nodes must never construct a real model.
    mocker.patch("src.agents.graph.nodes.get_model", return_value=TestModel())

    # patch() substitutes an AsyncMock when the target is an async function,
    # so awaiting the patched Agent.run yields ``return_value``.
    # NOTE(review): every agent call shares this single canned output. A node
    # that expects a different output type (e.g. a hypotheses list) would
    # receive the ResearchReport instead. Switch to ``side_effect`` if a node
    # needs its own shape.
    mock_run = mocker.patch("pydantic_ai.Agent.run")

    dummy_section = ReportSection(title="Dummy", content="Content")
    mock_report = ResearchReport(
        title="Test Report",
        # Repeated to produce a longer summary, presumably to satisfy a
        # minimum-length constraint on the model -- confirm.
        executive_summary="Summary " * 20,
        research_question="Question",
        methodology=dummy_section,
        hypotheses_tested=[],
        mechanistic_findings=dummy_section,
        clinical_findings=dummy_section,
        drug_candidates=[],
        limitations=["None"],
        conclusion="Conclusion",
        references=[],
        confidence_score=0.5,
    )

    mock_result = mocker.Mock()
    mock_result.output = mock_report
    mock_run.return_value = mock_result

    graph = create_research_graph(llm=None)

    initial_state = {
        "query": "test query",
        "hypotheses": [],
        "conflicts": [],
        "evidence_ids": [],
        "messages": [],
        "next_step": "search",
        "iteration_count": 0,
        "max_iterations": 2,
    }

    events = [event async for event in graph.astream(initial_state)]

    assert len(events) >= 3, f"Expected at least 3 events, got {len(events)}"

    # Assumes each streamed event is a dict keyed by node name (the default
    # "updates" stream mode) -- confirm. Collect every key rather than only
    # the first, so a multi-node event does not hide a node.
    node_names = {name for event in events for name in event}
    assert "supervisor" in node_names, "Supervisor node should have executed"
    assert "search" in node_names, "Search node should have executed"
    assert "synthesize" in node_names, "Synthesize node should have executed"

    final_event = events[-1]
    assert "synthesize" in final_event, (
        f"Final event should be synthesis, got: {list(final_event.keys())}"
    )

    # The key is known to exist here. Normalize a None update so the check
    # below fails with its message instead of a TypeError.
    synth_output = final_event["synthesize"] or {}
    assert "messages" in synth_output, "Synthesis should produce messages"
|
|