File size: 4,142 Bytes
3193174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""Pytest fixtures and configuration."""

import sys
from pathlib import Path

import pytest
import torch

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

import rustworkx as rx


@pytest.fixture
def empty_digraph():
    """Empty directed graph."""
    return rx.PyDiGraph()


@pytest.fixture
def simple_digraph():
    """Simple directed graph with 3 nodes."""
    g = rx.PyDiGraph()
    g.add_node({"id": "a", "role": "agent"})
    g.add_node({"id": "b", "role": "agent"})
    g.add_node({"id": "c", "role": "agent"})
    g.add_edge(0, 1, {"weight": 0.8})
    g.add_edge(1, 2, {"weight": 0.9})
    return g


@pytest.fixture
def cyclic_digraph():
    """Graph with a cycle."""
    g = rx.PyDiGraph()
    g.add_node({"id": "a"})
    g.add_node({"id": "b"})
    g.add_node({"id": "c"})
    g.add_edge(0, 1, {"weight": 1.0})
    g.add_edge(1, 2, {"weight": 1.0})
    g.add_edge(2, 0, {"weight": 1.0})  # Cycle
    return g


@pytest.fixture
def parallel_digraph():
    """Graph with parallel branches."""
    g = rx.PyDiGraph()
    # Task -> A, B (parallel) -> C
    g.add_node({"id": "task", "role": "task"})
    g.add_node({"id": "a", "role": "agent"})
    g.add_node({"id": "b", "role": "agent"})
    g.add_node({"id": "c", "role": "agent"})
    g.add_edge(0, 1, {"weight": 0.5})
    g.add_edge(0, 2, {"weight": 0.5})
    g.add_edge(1, 3, {"weight": 0.8})
    g.add_edge(2, 3, {"weight": 0.8})
    return g


@pytest.fixture
def sample_agent_configs():
    """Sample agent configurations."""
    return [
        {
            "agent_id": "coordinator",
            "role_name": "Coordinator",
            "persona": "Manages workflow",
            "state": {},
        },
        {
            "agent_id": "researcher",
            "role_name": "Researcher",
            "persona": "Finds information",
            "state": {},
        },
        {
            "agent_id": "writer",
            "role_name": "Writer",
            "persona": "Creates content",
            "state": {},
        },
    ]


@pytest.fixture
def sample_connections():
    """Sample agent connections."""
    return {
        "coordinator": ["researcher", "writer"],
        "researcher": ["writer"],
        "writer": [],
    }


@pytest.fixture
def mock_agent_callable():
    """Mock callable for an agent."""

    async def agent_fn(query: str, context: dict) -> str:
        del context  # Unused in mock
        return f"Response to: {query}"

    return agent_fn


@pytest.fixture
def mock_failing_agent():
    """Mock agent that always fails."""

    async def failing_agent(query: str, context: dict) -> str:
        del query, context  # Unused in mock
        msg = "Agent failure"
        raise RuntimeError(msg)

    return failing_agent


@pytest.fixture
def mock_slow_agent():
    """Mock slow agent."""
    import asyncio

    async def slow_agent(query: str, context: dict) -> str:
        del query, context  # Unused in mock
        await asyncio.sleep(5.0)  # Intentionally slow
        return "Slow response"

    return slow_agent


def create_role_graph(node_ids, connections):
    """Helper for creating a RoleGraph."""
    from core.agent import AgentProfile
    from core.graph import RoleGraph

    g = rx.PyDiGraph()

    id_to_idx = {}
    agents = []
    for nid in node_ids:
        idx = g.add_node({"id": nid})
        id_to_idx[nid] = idx
        agent = AgentProfile(agent_id=nid, display_name=f"Agent {nid.upper()}")
        agents.append(agent)

    for src, targets in connections.items():
        for tgt in targets:
            if src in id_to_idx and tgt in id_to_idx:
                g.add_edge(id_to_idx[src], id_to_idx[tgt], {"weight": 1.0})

    n = len(node_ids)
    a_com = torch.zeros((n, n), dtype=torch.float32)
    for src, targets in connections.items():
        i = node_ids.index(src)
        for tgt in targets:
            if tgt in node_ids:
                j = node_ids.index(tgt)
                a_com[i, j] = 1.0

    graph = RoleGraph(
        node_ids=node_ids,
        role_connections=connections,
        graph=g,
        A_com=a_com,
    )
    graph.agents = agents
    return graph