File size: 5,220 Bytes
d36ce3c
 
 
 
 
 
7baf8ba
 
d36ce3c
 
7baf8ba
 
 
 
 
d36ce3c
 
 
 
 
 
 
 
 
 
 
 
 
 
e82d9c9
d36ce3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0257d2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Tests for Magentic Orchestrator termination guarantee."""

from unittest.mock import MagicMock, patch

import pytest

# Skip all tests if agent_framework not installed (optional dep)
# MUST come before any agent_framework imports
pytest.importorskip("agent_framework")

from agent_framework import MagenticAgentMessageEvent  # noqa: E402

from src.orchestrators.advanced import AdvancedOrchestrator as MagenticOrchestrator  # noqa: E402
from src.utils.models import AgentEvent  # noqa: E402


class MockChatMessage:
    """Minimal chat-message stand-in exposing ``content``, ``role`` and ``text``."""

    def __init__(self, content):
        # Every mock message is attributed to the assistant.
        self.role = "assistant"
        self.content = content

    @property
    def text(self):
        """Return the message body; a live alias of ``content``."""
        return self.content


@pytest.fixture
def mock_magentic_requirements():
    """Replace the orchestrator's requirements check with a mock for the test's duration."""
    requirements_patcher = patch("src.orchestrators.advanced.check_magentic_requirements")
    requirements_patcher.start()
    try:
        yield
    finally:
        # Always restore the real check, even if the test body raised.
        requirements_patcher.stop()


@pytest.mark.asyncio
async def test_termination_event_emitted_on_stream_end(mock_magentic_requirements):
    """
    Verify that a termination event is emitted when the workflow stream ends
    without a MagenticFinalResultEvent (e.g. max rounds reached).

    The mocked stream yields one real ``MagenticAgentMessageEvent`` and then
    stops. Any trailing "complete" event must therefore come from the
    orchestrator's fallback path, not from the workflow itself.
    """
    orchestrator = MagenticOrchestrator(max_rounds=2)

    # Use the real event class so the orchestrator handles a genuine event type.
    mock_agent_event = MagenticAgentMessageEvent(
        agent_id="SearchAgent", message=MockChatMessage("Thinking...")
    )

    async def mock_stream(task):
        # Yield the message event, then end the stream WITHOUT a FinalResultEvent.
        yield mock_agent_event

    mock_workflow = MagicMock()
    mock_workflow.run_stream = mock_stream

    # Only the run itself needs the patched _build_workflow.
    with patch.object(orchestrator, "_build_workflow", return_value=mock_workflow):
        events = [event async for event in orchestrator.run("Research query")]

    # Attach the event dump to failure messages rather than printing it on every run.
    dump = "\n".join(f"Event {i}: {e.type} - {e.message}" for i, e in enumerate(events))

    assert len(events) >= 2, dump
    assert events[0].type == "started", dump

    # The agent message must surface in some emitted event. Its event type depends
    # on _process_event's mapping, so only the text is checked.
    assert any("Thinking..." in e.message for e in events), dump

    # THE CRITICAL CHECK: the stream ended without a final result, so the last
    # event must be the fallback termination event.
    last_event = events[-1]
    assert last_event.type == "complete", dump
    assert "Max iterations reached" in last_event.message, dump
    assert last_event.data.get("reason") == "max_rounds_reached", dump


@pytest.mark.asyncio
async def test_no_double_termination_event(mock_magentic_requirements):
    """
    Verify that we DO NOT emit a fallback event if the workflow finished normally.

    ``_process_event`` is patched to turn the two raw stream items into a
    "thinking" event followed by a natural "complete" event. The orchestrator
    must not append its own termination event after that.
    """
    orchestrator = MagenticOrchestrator()

    async def two_raw_events(task):
        # Opaque payloads; _process_event is patched, so their content is irrelevant.
        for raw in ("raw_event_1", "raw_event_2"):
            yield raw

    workflow_stub = MagicMock()
    workflow_stub.run_stream = two_raw_events

    # One processed event per raw item; the second is the natural completion.
    processed = [
        AgentEvent(type="thinking", message="Working...", iteration=1),
        AgentEvent(type="complete", message="Done!", iteration=2),
    ]

    build_patch = patch.object(orchestrator, "_build_workflow", return_value=workflow_stub)
    process_patch = patch.object(orchestrator, "_process_event", side_effect=processed)
    with build_patch, process_patch:
        collected = [event async for event in orchestrator.run("Research query")]

    final = collected[-1]
    assert final.message == "Done!"
    assert final.type == "complete"

    # No SECOND "Max iterations reached" fallback event may follow the natural completion.
    assert not [e for e in collected if "Max iterations reached" in e.message]


@pytest.mark.asyncio
async def test_termination_on_timeout(mock_magentic_requirements):
    """
    Verify that a termination event is emitted when the workflow times out.

    The mocked stream yields one progress event and then raises ``TimeoutError``.
    The orchestrator is expected to surface the progress and then report the
    timeout as a "complete" event.
    """
    orchestrator = MagenticOrchestrator()

    async def stream_then_timeout(task):
        # One progress event, then fail as if the workflow exceeded its deadline.
        progress = MagenticAgentMessageEvent(
            agent_id="SearchAgent", message=MockChatMessage("Working...")
        )
        yield progress
        raise TimeoutError()

    workflow_stub = MagicMock()
    workflow_stub.run_stream = stream_then_timeout

    with patch.object(orchestrator, "_build_workflow", return_value=workflow_stub):
        collected = [event async for event in orchestrator.run("Research query")]

    # Progress emitted before the timeout must still reach the caller.
    assert any("Working..." in e.message for e in collected)

    # At least one completion event exists, and the latest reports the timeout.
    completions = [e for e in collected if e.type == "complete"]
    assert completions
    timeout_event = completions[-1]
    assert "timed out" in timeout_event.message
    assert timeout_event.data.get("reason") == "timeout"