Commit
·
b1d094d
1
Parent(s):
0f049b6
feat: Implement SPEC_01 (Termination) and SPEC_02 (E2E Tests)
Browse files- SPEC_01: Added timeout (300s) and progress events to MagenticOrchestrator
- SPEC_02: Created tests/e2e/ with mocked tests for Simple and Advanced modes
- Docs: Updated specs to match codebase state
- docs/specs/SPEC_01_DEMO_TERMINATION.md +1 -1
- pyproject.toml +1 -0
- src/orchestrator_magentic.py +28 -10
- src/utils/models.py +2 -0
- tests/e2e/conftest.py +60 -0
- tests/e2e/test_advanced_mode.py +70 -0
- tests/e2e/test_simple_mode.py +65 -0
- tests/unit/test_magentic_termination.py +35 -0
docs/specs/SPEC_01_DEMO_TERMINATION.md
CHANGED
|
@@ -16,7 +16,7 @@ Advanced (Magentic) mode runs indefinitely from user perspective. The demo was m
|
|
| 16 |
### Question 1: Does max_round_count actually work?
|
| 17 |
|
| 18 |
```python
|
| 19 |
-
# Current code (src/orchestrator_magentic.py:
|
| 20 |
.with_standard_manager(
|
| 21 |
chat_client=manager_client,
|
| 22 |
max_round_count=self._max_rounds, # Default: 10
|
|
|
|
| 16 |
### Question 1: Does max_round_count actually work?
|
| 17 |
|
| 18 |
```python
|
| 19 |
+
# Current code (src/orchestrator_magentic.py:94)
|
| 20 |
.with_standard_manager(
|
| 21 |
chat_client=manager_client,
|
| 22 |
max_round_count=self._max_rounds, # Default: 10
|
pyproject.toml
CHANGED
|
@@ -129,6 +129,7 @@ markers = [
|
|
| 129 |
"unit: Unit tests (mocked)",
|
| 130 |
"integration: Integration tests (real APIs)",
|
| 131 |
"slow: Slow tests",
|
|
|
|
| 132 |
]
|
| 133 |
# Filter warnings from unittest.mock introspecting Pydantic models.
|
| 134 |
# This is a known upstream issue: https://github.com/pydantic/pydantic/issues/9927
|
|
|
|
| 129 |
"unit: Unit tests (mocked)",
|
| 130 |
"integration: Integration tests (real APIs)",
|
| 131 |
"slow: Slow tests",
|
| 132 |
+
"e2e: End-to-End tests (full pipeline)",
|
| 133 |
]
|
| 134 |
# Filter warnings from unittest.mock introspecting Pydantic models.
|
| 135 |
# This is a known upstream issue: https://github.com/pydantic/pydantic/issues/9927
|
src/orchestrator_magentic.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""Magentic-based orchestrator using ChatAgent pattern."""
|
| 2 |
|
|
|
|
| 3 |
from collections.abc import AsyncGenerator
|
| 4 |
from typing import TYPE_CHECKING, Any
|
| 5 |
|
|
@@ -169,18 +170,26 @@ The final output should be a structured research report."""
|
|
| 169 |
|
| 170 |
iteration = 0
|
| 171 |
final_event_received = False
|
|
|
|
| 172 |
|
| 173 |
try:
|
| 174 |
-
async
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
if
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
# GUARANTEE: Always emit termination event if stream ends without one
|
| 186 |
# (e.g., max rounds reached)
|
|
@@ -200,6 +209,15 @@ The final output should be a structured research report."""
|
|
| 200 |
iteration=iteration,
|
| 201 |
)
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
except Exception as e:
|
| 204 |
logger.error("Magentic workflow failed", error=str(e))
|
| 205 |
yield AgentEvent(
|
|
|
|
| 1 |
"""Magentic-based orchestrator using ChatAgent pattern."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
from collections.abc import AsyncGenerator
|
| 5 |
from typing import TYPE_CHECKING, Any
|
| 6 |
|
|
|
|
| 170 |
|
| 171 |
iteration = 0
|
| 172 |
final_event_received = False
|
| 173 |
+
demo_timeout_seconds = 300 # 5 minutes max
|
| 174 |
|
| 175 |
try:
|
| 176 |
+
async with asyncio.timeout(demo_timeout_seconds):
|
| 177 |
+
async for event in workflow.run_stream(task):
|
| 178 |
+
agent_event = self._process_event(event, iteration)
|
| 179 |
+
if agent_event:
|
| 180 |
+
if isinstance(event, MagenticAgentMessageEvent):
|
| 181 |
+
iteration += 1
|
| 182 |
+
# Yield progress update before the agent action
|
| 183 |
+
yield AgentEvent(
|
| 184 |
+
type="progress",
|
| 185 |
+
message=f"Round {iteration}/{self._max_rounds}...",
|
| 186 |
+
iteration=iteration,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if agent_event.type == "complete":
|
| 190 |
+
final_event_received = True
|
| 191 |
+
|
| 192 |
+
yield agent_event
|
| 193 |
|
| 194 |
# GUARANTEE: Always emit termination event if stream ends without one
|
| 195 |
# (e.g., max rounds reached)
|
|
|
|
| 209 |
iteration=iteration,
|
| 210 |
)
|
| 211 |
|
| 212 |
+
except TimeoutError:
|
| 213 |
+
logger.warning("Workflow timed out", iterations=iteration)
|
| 214 |
+
yield AgentEvent(
|
| 215 |
+
type="complete",
|
| 216 |
+
message="Research timed out. Synthesizing available evidence...",
|
| 217 |
+
data={"reason": "timeout", "iterations": iteration},
|
| 218 |
+
iteration=iteration,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
except Exception as e:
|
| 222 |
logger.error("Magentic workflow failed", error=str(e))
|
| 223 |
yield AgentEvent(
|
src/utils/models.py
CHANGED
|
@@ -119,6 +119,7 @@ class AgentEvent(BaseModel):
|
|
| 119 |
"hypothesizing",
|
| 120 |
"analyzing", # NEW for Phase 13
|
| 121 |
"analysis_complete", # NEW for Phase 13
|
|
|
|
| 122 |
]
|
| 123 |
message: str
|
| 124 |
data: Any = None
|
|
@@ -142,6 +143,7 @@ class AgentEvent(BaseModel):
|
|
| 142 |
"hypothesizing": "🔬", # NEW
|
| 143 |
"analyzing": "📊", # NEW
|
| 144 |
"analysis_complete": "📈", # NEW
|
|
|
|
| 145 |
}
|
| 146 |
icon = icons.get(self.type, "•")
|
| 147 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
|
|
|
| 119 |
"hypothesizing",
|
| 120 |
"analyzing", # NEW for Phase 13
|
| 121 |
"analysis_complete", # NEW for Phase 13
|
| 122 |
+
"progress", # NEW for SPEC_01
|
| 123 |
]
|
| 124 |
message: str
|
| 125 |
data: Any = None
|
|
|
|
| 143 |
"hypothesizing": "🔬", # NEW
|
| 144 |
"analyzing": "📊", # NEW
|
| 145 |
"analysis_complete": "📈", # NEW
|
| 146 |
+
"progress": "⏱️", # NEW
|
| 147 |
}
|
| 148 |
icon = icons.get(self.type, "•")
|
| 149 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
tests/e2e/conftest.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from src.utils.models import AssessmentDetails, Citation, Evidence, JudgeAssessment, SearchResult
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture
|
| 9 |
+
def mock_search_handler():
|
| 10 |
+
"""Return a mock search handler that returns fake evidence."""
|
| 11 |
+
mock = MagicMock()
|
| 12 |
+
|
| 13 |
+
async def mock_execute(query, max_results=10):
|
| 14 |
+
return SearchResult(
|
| 15 |
+
query=query,
|
| 16 |
+
evidence=[
|
| 17 |
+
Evidence(
|
| 18 |
+
content=f"Evidence content for {query}",
|
| 19 |
+
citation=Citation(
|
| 20 |
+
source="pubmed",
|
| 21 |
+
title=f"Study on {query}",
|
| 22 |
+
url="https://pubmed.example.com/123",
|
| 23 |
+
date="2025-01-01",
|
| 24 |
+
authors=["Doe J"],
|
| 25 |
+
),
|
| 26 |
+
)
|
| 27 |
+
],
|
| 28 |
+
sources_searched=["pubmed"],
|
| 29 |
+
total_found=1,
|
| 30 |
+
errors=[],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
mock.execute = mock_execute
|
| 34 |
+
return mock
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@pytest.fixture
|
| 38 |
+
def mock_judge_handler():
|
| 39 |
+
"""Return a mock judge that always says 'synthesize'."""
|
| 40 |
+
mock = MagicMock()
|
| 41 |
+
|
| 42 |
+
async def mock_assess(question, evidence):
|
| 43 |
+
return JudgeAssessment(
|
| 44 |
+
sufficient=True,
|
| 45 |
+
confidence=0.9,
|
| 46 |
+
recommendation="synthesize",
|
| 47 |
+
details=AssessmentDetails(
|
| 48 |
+
mechanism_score=8,
|
| 49 |
+
mechanism_reasoning="Strong mechanism found in mock data",
|
| 50 |
+
clinical_evidence_score=7,
|
| 51 |
+
clinical_reasoning="Good clinical evidence in mock data",
|
| 52 |
+
drug_candidates=["MockDrug A"],
|
| 53 |
+
key_findings=["Finding 1", "Finding 2"],
|
| 54 |
+
),
|
| 55 |
+
reasoning="Evidence is sufficient for synthesis.",
|
| 56 |
+
next_search_queries=[],
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
mock.assess = mock_assess
|
| 60 |
+
return mock
|
tests/e2e/test_advanced_mode.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock, patch
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
# Skip entire module if agent_framework is not installed
|
| 6 |
+
agent_framework = pytest.importorskip("agent_framework")
|
| 7 |
+
from agent_framework import MagenticAgentMessageEvent, MagenticFinalResultEvent
|
| 8 |
+
|
| 9 |
+
from src.orchestrator_magentic import MagenticOrchestrator
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MockChatMessage:
|
| 13 |
+
def __init__(self, content):
|
| 14 |
+
self.content = content
|
| 15 |
+
|
| 16 |
+
@property
|
| 17 |
+
def text(self):
|
| 18 |
+
return self.content
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@pytest.mark.asyncio
|
| 22 |
+
@pytest.mark.e2e
|
| 23 |
+
async def test_advanced_mode_completes_mocked():
|
| 24 |
+
"""Verify Advanced mode runs without crashing (mocked workflow)."""
|
| 25 |
+
|
| 26 |
+
# Initialize orchestrator (mocking requirements check)
|
| 27 |
+
with patch("src.orchestrator_magentic.check_magentic_requirements"):
|
| 28 |
+
orchestrator = MagenticOrchestrator(max_rounds=5)
|
| 29 |
+
|
| 30 |
+
# Mock the workflow
|
| 31 |
+
mock_workflow = MagicMock()
|
| 32 |
+
|
| 33 |
+
# Create fake events
|
| 34 |
+
# 1. Search Agent runs
|
| 35 |
+
mock_msg_1 = MockChatMessage("Found 5 papers on PubMed")
|
| 36 |
+
event1 = MagenticAgentMessageEvent(agent_id="SearchAgent", message=mock_msg_1)
|
| 37 |
+
|
| 38 |
+
# 2. Report Agent finishes
|
| 39 |
+
mock_result_msg = MockChatMessage("# Final Report\n\nFindings...")
|
| 40 |
+
event2 = MagenticFinalResultEvent(message=mock_result_msg)
|
| 41 |
+
|
| 42 |
+
async def mock_stream(task):
|
| 43 |
+
yield event1
|
| 44 |
+
yield event2
|
| 45 |
+
|
| 46 |
+
mock_workflow.run_stream = mock_stream
|
| 47 |
+
|
| 48 |
+
# Patch dependencies:
|
| 49 |
+
# _build_workflow: Returns our mock
|
| 50 |
+
# init_magentic_state: Avoids DB calls
|
| 51 |
+
# _init_embedding_service: Avoids loading embeddings
|
| 52 |
+
with (
|
| 53 |
+
patch.object(orchestrator, "_build_workflow", return_value=mock_workflow),
|
| 54 |
+
patch("src.orchestrator_magentic.init_magentic_state"),
|
| 55 |
+
patch.object(orchestrator, "_init_embedding_service", return_value=None),
|
| 56 |
+
):
|
| 57 |
+
events = []
|
| 58 |
+
async for event in orchestrator.run("test query"):
|
| 59 |
+
events.append(event)
|
| 60 |
+
|
| 61 |
+
# Check events
|
| 62 |
+
types = [e.type for e in events]
|
| 63 |
+
assert "started" in types
|
| 64 |
+
assert "thinking" in types
|
| 65 |
+
assert "search_complete" in types # Mapped from SearchAgent
|
| 66 |
+
assert "progress" in types # Added in SPEC_01
|
| 67 |
+
assert "complete" in types
|
| 68 |
+
|
| 69 |
+
complete_event = next(e for e in events if e.type == "complete")
|
| 70 |
+
assert "Final Report" in complete_event.message
|
tests/e2e/test_simple_mode.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from src.orchestrator import Orchestrator
|
| 4 |
+
from src.utils.models import OrchestratorConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@pytest.mark.asyncio
|
| 8 |
+
@pytest.mark.e2e
|
| 9 |
+
async def test_simple_mode_completes(mock_search_handler, mock_judge_handler):
|
| 10 |
+
"""Verify Simple mode runs without crashing using mocks."""
|
| 11 |
+
|
| 12 |
+
config = OrchestratorConfig(max_iterations=2)
|
| 13 |
+
|
| 14 |
+
orchestrator = Orchestrator(
|
| 15 |
+
search_handler=mock_search_handler,
|
| 16 |
+
judge_handler=mock_judge_handler,
|
| 17 |
+
config=config,
|
| 18 |
+
enable_analysis=False,
|
| 19 |
+
enable_embeddings=False,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
events = []
|
| 23 |
+
async for event in orchestrator.run("test query"):
|
| 24 |
+
events.append(event)
|
| 25 |
+
|
| 26 |
+
# Must complete
|
| 27 |
+
assert any(e.type == "complete" for e in events), "Did not receive complete event"
|
| 28 |
+
# Must not error
|
| 29 |
+
assert not any(e.type == "error" for e in events), "Received error event"
|
| 30 |
+
|
| 31 |
+
# Check structure of complete event
|
| 32 |
+
complete_event = next(e for e in events if e.type == "complete")
|
| 33 |
+
# The mock judge returns "MockDrug A" and "Finding 1", ensuring synthesis happens
|
| 34 |
+
assert "MockDrug A" in complete_event.message
|
| 35 |
+
assert "Finding 1" in complete_event.message
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@pytest.mark.asyncio
|
| 39 |
+
@pytest.mark.e2e
|
| 40 |
+
async def test_simple_mode_structure_validation(mock_search_handler, mock_judge_handler):
|
| 41 |
+
"""Verify output contains expected structure (citations, headings)."""
|
| 42 |
+
config = OrchestratorConfig(max_iterations=2)
|
| 43 |
+
orchestrator = Orchestrator(
|
| 44 |
+
search_handler=mock_search_handler,
|
| 45 |
+
judge_handler=mock_judge_handler,
|
| 46 |
+
config=config,
|
| 47 |
+
enable_analysis=False,
|
| 48 |
+
enable_embeddings=False,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
events = []
|
| 52 |
+
async for event in orchestrator.run("test query"):
|
| 53 |
+
events.append(event)
|
| 54 |
+
|
| 55 |
+
complete_event = next(e for e in events if e.type == "complete")
|
| 56 |
+
report = complete_event.message
|
| 57 |
+
|
| 58 |
+
# Check markdown structure
|
| 59 |
+
assert "## Drug Repurposing Analysis" in report
|
| 60 |
+
assert "### Citations" in report
|
| 61 |
+
assert "### Key Findings" in report
|
| 62 |
+
|
| 63 |
+
# Check for citations
|
| 64 |
+
assert "Study on test query" in report
|
| 65 |
+
assert "https://pubmed.example.com/123" in report
|
tests/unit/test_magentic_termination.py
CHANGED
|
@@ -109,3 +109,38 @@ async def test_no_double_termination_event(mock_magentic_requirements):
|
|
| 109 |
# Verify we didn't get a SECOND "Max iterations reached" event
|
| 110 |
fallback_events = [e for e in events if "Max iterations reached" in e.message]
|
| 111 |
assert len(fallback_events) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
# Verify we didn't get a SECOND "Max iterations reached" event
|
| 110 |
fallback_events = [e for e in events if "Max iterations reached" in e.message]
|
| 111 |
assert len(fallback_events) == 0
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@pytest.mark.asyncio
|
| 115 |
+
async def test_termination_on_timeout(mock_magentic_requirements):
|
| 116 |
+
"""
|
| 117 |
+
Verify that a termination event is emitted when the workflow times out.
|
| 118 |
+
"""
|
| 119 |
+
orchestrator = MagenticOrchestrator()
|
| 120 |
+
|
| 121 |
+
mock_workflow = MagicMock()
|
| 122 |
+
|
| 123 |
+
# Simulate a stream that times out (raises TimeoutError)
|
| 124 |
+
async def mock_stream_raises(task):
|
| 125 |
+
# Yield one event before timing out
|
| 126 |
+
yield MagenticAgentMessageEvent(
|
| 127 |
+
agent_id="SearchAgent", message=MockChatMessage("Working...")
|
| 128 |
+
)
|
| 129 |
+
raise TimeoutError()
|
| 130 |
+
|
| 131 |
+
mock_workflow.run_stream = mock_stream_raises
|
| 132 |
+
|
| 133 |
+
with patch.object(orchestrator, "_build_workflow", return_value=mock_workflow):
|
| 134 |
+
events = []
|
| 135 |
+
async for event in orchestrator.run("Research query"):
|
| 136 |
+
events.append(event)
|
| 137 |
+
|
| 138 |
+
# Check for progress/normal events
|
| 139 |
+
assert any("Working..." in e.message for e in events)
|
| 140 |
+
|
| 141 |
+
# Check for timeout completion
|
| 142 |
+
completion_events = [e for e in events if e.type == "complete"]
|
| 143 |
+
assert len(completion_events) > 0
|
| 144 |
+
last_event = completion_events[-1]
|
| 145 |
+
assert "timed out" in last_event.message
|
| 146 |
+
assert last_event.data.get("reason") == "timeout"
|