# File size: 3,992 Bytes
# d8d14f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
from dotenv import load_dotenv
from swarm_models import OpenAIChat
from swarms.structs.agent import Agent
from swarms.structs.groupchat import GroupChat, expertise_based


def setup_test_agents():
    """Build the three single-letter agents used throughout these tests.

    All agents share one ``OpenAIChat`` model (``gpt-4``, temperature 0.1)
    whose API key is read from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        A list of three ``Agent`` objects named ``Agent1``..``Agent3`` whose
        system prompts ask them to answer only 'A', 'B' and 'C' respectively.
    """
    shared_llm = OpenAIChat(
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        model_name="gpt-4",
        temperature=0.1,
    )

    # (agent_name, system_prompt) pairs, in the order the agents are returned.
    agent_specs = (
        ("Agent1", "You only respond with 'A'"),
        ("Agent2", "You only respond with 'B'"),
        ("Agent3", "You only respond with 'C'"),
    )

    return [
        Agent(agent_name=name, system_prompt=prompt, llm=shared_llm)
        for name, prompt in agent_specs
    ]


def test_round_robin_speaking():
    """Check that a default GroupChat yields 'A', 'B', 'C' once per turn.

    The messages of every response in every turn are flattened in order and
    compared against the letter cycle repeated ``len(history.turns)`` times.
    """
    group = GroupChat(agents=setup_test_agents())
    history = group.run("Say your letter")

    # Flatten turn -> responses into one ordered list of message texts.
    spoken = []
    for turn in history.turns:
        for reply in turn.responses:
            spoken.append(reply.message)

    expected_cycle = ["A", "B", "C"]
    assert spoken == expected_cycle * len(history.turns)


def test_concurrent_processing():
    """Check that ``concurrent_run`` returns one non-empty history per task."""
    prompts = ["Task1", "Task2", "Task3"]
    group = GroupChat(agents=setup_test_agents())

    results = group.concurrent_run(prompts)

    assert len(results) == len(prompts)
    # Every returned history must contain at least one message.
    assert all(result.total_messages > 0 for result in results)


def test_expertise_based_speaking():
    """Check that ``expertise_based`` picks the agent whose prompt is quoted.

    For each agent, the chat is run with a task embedding that agent's system
    prompt; the first response of the first turn must come from that agent.
    """
    roster = setup_test_agents()
    group = GroupChat(agents=roster, speaker_fn=expertise_based)

    for member in roster:
        prompt = f"Trigger {member.system_prompt}"
        opening_turn = group.run(prompt).turns[0]
        lead_reply = opening_turn.responses[0]
        assert lead_reply.agent_name == member.agent_name


def test_max_turns_limit():
    """Check that a chat built with ``max_turns=3`` produces exactly 3 turns."""
    limit = 3
    group = GroupChat(agents=setup_test_agents(), max_turns=limit)

    turn_count = len(group.run("Test message").turns)

    assert turn_count == limit


def test_error_handling():
    """Check that an agent without an LLM surfaces an error in its message.

    An ``Agent`` is built with ``llm=None`` and run alone in a GroupChat; the
    first response of the first turn is expected to contain the text "Error".
    """
    faulty = Agent(
        agent_name="BrokenAgent",
        system_prompt="You raise errors",
        llm=None,
    )
    group = GroupChat(agents=[faulty])

    first_reply = group.run("Trigger error").turns[0].responses[0]

    assert "Error" in first_reply.message


def test_conversation_context():
    """Check that every agent responds to a prompt mentioning A, B and C.

    Runs an ``expertise_based`` chat on a single multi-part prompt and asserts
    that each agent's name appears among the responders across all turns.
    """
    roster = setup_test_agents()
    complex_prompt = "Previous message refers to A. Now trigger B. Finally discuss C."

    group = GroupChat(agents=roster, speaker_fn=expertise_based)
    history = group.run(complex_prompt)

    # Names of everyone who responded, in conversation order.
    speakers = []
    for turn in history.turns:
        for reply in turn.responses:
            speakers.append(reply.agent_name)

    silent = [
        member.agent_name
        for member in roster
        if member.agent_name not in speakers
    ]
    assert not silent


def test_large_agent_group():
    """Check that a 15-agent GroupChat produces more messages than agents.

    The group is assembled from five separate calls to ``setup_test_agents``
    so that it holds 15 distinct ``Agent`` instances. Agent names still repeat
    (``Agent1``..``Agent3`` five times each), as in the original test.
    """
    # Sequence repetition (``setup_test_agents() * 5``) copies references, not
    # objects: it would place the same three Agent instances in the chat five
    # times each instead of creating 15 independent agents.
    large_group = [
        agent for _ in range(5) for agent in setup_test_agents()
    ]  # 15 agents
    chat = GroupChat(agents=large_group)
    history = chat.run("Test scaling")

    assert history.total_messages > len(large_group)


def test_long_conversations():
    """Check that a 50-turn chat runs to completion with over 100 messages."""
    turn_budget = 50
    group = GroupChat(agents=setup_test_agents(), max_turns=turn_budget)

    history = group.run("Long conversation test")

    assert len(history.turns) == turn_budget
    assert history.total_messages > 100


def test_stress_batched_runs():
    """Check ``batched_run`` over 100 identical tasks.

    Expects one history per task and, summed over all histories, more than
    three messages per task.
    """
    group = GroupChat(agents=setup_test_agents())
    batch = ["Task" for _ in range(100)]

    results = group.batched_run(batch)
    assert len(results) == len(batch)

    # Accumulate the message count across every returned history.
    message_count = 0
    for result in results:
        message_count += result.total_messages
    assert message_count > 3 * len(batch)


if __name__ == "__main__":
    # Script entry point: load environment variables from a .env file, run
    # every test in order, report each outcome, and exit non-zero if any
    # test failed so callers (shell, CI) can detect the failure.
    import traceback

    load_dotenv()

    functions = [
        test_round_robin_speaking,
        test_concurrent_processing,
        test_expertise_based_speaking,
        test_max_turns_limit,
        test_error_handling,
        test_conversation_context,
        test_large_agent_group,
        test_long_conversations,
        test_stress_batched_runs,
    ]

    failed = []
    for func in functions:
        print(f"Running {func.__name__}...")
        try:
            func()
        except Exception as e:
            # A bare ``assert`` raises AssertionError with an empty message,
            # so show the exception type and the traceback; the message alone
            # would not say what went wrong or where.
            print(f"✗ Failed: {type(e).__name__}: {e}")
            traceback.print_exc()
            failed.append(func.__name__)
        else:
            print("✓ Passed")

    if failed:
        print(f"{len(failed)} of {len(functions)} tests failed: {', '.join(failed)}")
        raise SystemExit(1)