|
|
"""Tests for Magentic Orchestrator termination guarantee.""" |
|
|
|
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
|
|
|
pytest.importorskip("agent_framework") |
|
|
|
|
|
from agent_framework import MagenticAgentMessageEvent |
|
|
|
|
|
from src.orchestrators.advanced import AdvancedOrchestrator as MagenticOrchestrator |
|
|
from src.utils.models import AgentEvent |
|
|
|
|
|
|
|
|
class MockChatMessage:
    """Lightweight stand-in for a chat message attached to agent events.

    Exposes the three attributes used by the tests in this module:
    ``content``, ``role`` and the read-only ``text`` alias of ``content``.

    Args:
        content: Message body; returned unchanged by ``text``.
        role: Role label stored on the message. Keyword-only; defaults to
            ``"assistant"`` so existing ``MockChatMessage(content)`` calls
            behave exactly as before.
    """

    def __init__(self, content: str, *, role: str = "assistant") -> None:
        self.content = content
        self.role = role

    @property
    def text(self) -> str:
        """Return the message body (same object as ``content``)."""
        return self.content
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_magentic_requirements():
    """Replace the requirements check in the orchestrator module with a mock.

    The patch is active for the duration of the test that requests this
    fixture and is undone when the test finishes.
    """
    requirements_check = "src.orchestrators.advanced.check_magentic_requirements"
    with patch(requirements_check):
        yield
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_termination_event_emitted_on_stream_end(mock_magentic_requirements):
    """
    Check that a closing "complete" event is emitted when the workflow stream
    simply ends without a MagenticFinalResultEvent (e.g. max rounds reached).
    """
    orchestrator = MagenticOrchestrator(max_rounds=2)

    agent_message_event = MagenticAgentMessageEvent(
        agent_id="SearchAgent", message=MockChatMessage("Thinking...")
    )

    async def single_message_stream(task):
        # One agent message, then the stream ends with no final-result event.
        yield agent_message_event

    workflow = MagicMock()
    workflow.run_stream = single_message_stream

    with patch.object(orchestrator, "_build_workflow", return_value=workflow):
        events = [event async for event in orchestrator.run("Research query")]

    # Dump the event sequence; pytest shows captured output when the test fails.
    for idx, evt in enumerate(events):
        print(f"Event {idx}: {evt.type} - {evt.message}")

    assert len(events) >= 2
    assert events[0].type == "started"

    # The agent message from the mocked stream must have been surfaced.
    assert any("Thinking..." in evt.message for evt in events)

    # The last event must be the fallback termination event.
    final = events[-1]
    assert final.type == "complete"
    assert "Max iterations reached" in final.message
    assert final.data.get("reason") == "max_rounds_reached"
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_double_termination_event(mock_magentic_requirements):
    """
    Check that NO fallback event is emitted when the workflow finished normally.
    """
    orchestrator = MagenticOrchestrator()

    async def two_raw_events(task):
        # The payloads are placeholders: _process_event is mocked below, so
        # only the number of yielded items matters.
        yield "raw_event_1"
        yield "raw_event_2"

    workflow = MagicMock()
    workflow.run_stream = two_raw_events

    # Each call to the mocked _process_event returns the next item in order,
    # ending with a normal "complete" event.
    processed = [
        AgentEvent(type="thinking", message="Working...", iteration=1),
        AgentEvent(type="complete", message="Done!", iteration=2),
    ]

    with patch.object(orchestrator, "_build_workflow", return_value=workflow):
        with patch.object(orchestrator, "_process_event", side_effect=processed):
            events = [event async for event in orchestrator.run("Research query")]

    assert events[-1].message == "Done!"
    assert events[-1].type == "complete"

    # No fallback "max iterations" event may appear anywhere in the stream.
    assert not [evt for evt in events if "Max iterations reached" in evt.message]
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_termination_on_timeout(mock_magentic_requirements):
    """
    Check that a "complete" event is emitted when the workflow times out.
    """
    orchestrator = MagenticOrchestrator()

    async def timing_out_stream(task):
        # Emit one agent message, then fail the stream with a timeout.
        yield MagenticAgentMessageEvent(
            agent_id="SearchAgent", message=MockChatMessage("Working...")
        )
        raise TimeoutError()

    workflow = MagicMock()
    workflow.run_stream = timing_out_stream

    with patch.object(orchestrator, "_build_workflow", return_value=workflow):
        events = [event async for event in orchestrator.run("Research query")]

    # The message yielded before the timeout must still be surfaced.
    assert any("Working..." in evt.message for evt in events)

    # At least one completion event exists, and the last one reports the timeout.
    completions = [evt for evt in events if evt.type == "complete"]
    assert completions
    final = completions[-1]
    assert "timed out" in final.message
    assert final.data.get("reason") == "timeout"
|
|
|