File size: 3,869 Bytes
d36ce3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Tests for Magentic Orchestrator termination guarantee."""

from unittest.mock import MagicMock, patch

import pytest

# Skip this whole module when agent_framework is not installed. This call must
# come BEFORE agent_framework is imported below (and before the src modules,
# which presumably depend on it, are loaded); otherwise a missing package makes
# collection fail with ImportError instead of producing a skip.
pytest.importorskip("agent_framework")

from agent_framework import MagenticAgentMessageEvent  # noqa: E402

from src.orchestrator_magentic import MagenticOrchestrator  # noqa: E402
from src.utils.models import AgentEvent  # noqa: E402


class MockChatMessage:
    """Lightweight stand-in for a chat message passed to agent events.

    Stores the message body in ``content``, always reports ``role`` as
    ``"assistant"``, and exposes ``text`` as a read-only view of ``content``.
    """

    def __init__(self, content):
        self.content = content
        self.role = "assistant"

    # Read-only alias: always reflects the current value of ``content``.
    text = property(lambda self: self.content)


@pytest.fixture
def mock_magentic_requirements():
    """Replace ``check_magentic_requirements`` with a mock for one test.

    The patch is active while the test body runs and is always undone
    afterwards, even if the test fails.
    """
    patcher = patch("src.orchestrator_magentic.check_magentic_requirements")
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


@pytest.mark.asyncio
async def test_termination_event_emitted_on_stream_end(mock_magentic_requirements):
    """
    Verify that a termination event is emitted when the workflow stream ends
    without a MagenticFinalResultEvent (e.g. max rounds reached).

    The workflow is replaced by a mock whose ``run_stream`` yields one real
    ``MagenticAgentMessageEvent`` and then stops. The orchestrator therefore
    never sees a final-result event and must emit its own fallback
    ``complete`` event.
    """
    orchestrator = MagenticOrchestrator(max_rounds=2)

    # Use the real agent_framework event class, wrapping a minimal message stub.
    mock_message = MockChatMessage("Thinking...")
    mock_agent_event = MagenticAgentMessageEvent(agent_id="SearchAgent", message=mock_message)

    async def mock_stream(task):
        # Yield a single agent message, then end the stream WITHOUT a
        # MagenticFinalResultEvent.
        yield mock_agent_event

    mock_workflow = MagicMock()
    mock_workflow.run_stream = mock_stream

    # Route the orchestrator to the mock workflow instead of building a real one.
    with patch.object(orchestrator, "_build_workflow", return_value=mock_workflow):
        events = [event async for event in orchestrator.run("Research query")]

    # Compact (type, message) listing attached to every assertion so failures
    # are diagnosable without debug printing.
    summary = [(e.type, e.message) for e in events]

    assert len(events) >= 2, summary
    assert events[0].type == "started", summary

    # The agent message must surface in the emitted events. Which event type it
    # maps to is decided by _process_event, so only its text is checked here.
    assert any("Thinking..." in e.message for e in events), summary

    # Critical check: the fallback termination event closes the stream.
    last_event = events[-1]
    assert last_event.type == "complete", summary
    assert "Max iterations reached" in last_event.message, summary
    assert last_event.data.get("reason") == "max_rounds_reached", summary


@pytest.mark.asyncio
async def test_no_double_termination_event(mock_magentic_requirements):
    """
    Verify that we DO NOT emit a fallback event if the workflow finished normally.

    ``_process_event`` is patched so the two raw stream items become a
    "thinking" event followed by a natural "complete" event. The run must end
    with that completion and contain exactly one termination event: neither
    the "Max iterations reached" fallback nor any other extra "complete".
    """
    orchestrator = MagenticOrchestrator()

    async def mock_stream_with_yields(task):
        # Raw items are opaque placeholders; the patched _process_event maps them.
        yield "raw_event_1"
        yield "raw_event_2"

    mock_workflow = MagicMock()
    mock_workflow.run_stream = mock_stream_with_yields

    with patch.object(orchestrator, "_build_workflow", return_value=mock_workflow):
        # Simulate a natural completion coming out of event processing.
        with patch.object(orchestrator, "_process_event") as mock_process:
            mock_process.side_effect = [
                AgentEvent(type="thinking", message="Working...", iteration=1),
                AgentEvent(type="complete", message="Done!", iteration=2),
            ]
            events = [event async for event in orchestrator.run("Research query")]

    # Compact (type, message) listing attached to assertions for diagnosis.
    summary = [(e.type, e.message) for e in events]

    # Guard so an empty stream fails with a clear message instead of IndexError.
    assert events, "orchestrator emitted no events"
    assert events[-1].message == "Done!", summary
    assert events[-1].type == "complete", summary

    # No fallback "Max iterations reached" event was appended...
    fallback_events = [e for e in events if "Max iterations reached" in e.message]
    assert not fallback_events, summary

    # ...and no second termination event of any wording was emitted either.
    complete_events = [e for e in events if e.type == "complete"]
    assert len(complete_events) == 1, summary