Spaces:
Running
Running
| """Pytest fixtures and configuration.""" | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) | |
| import rustworkx as rx | |
| def empty_digraph(): | |
| """Empty directed graph.""" | |
| return rx.PyDiGraph() | |
| def simple_digraph(): | |
| """Simple directed graph with 3 nodes.""" | |
| g = rx.PyDiGraph() | |
| g.add_node({"id": "a", "role": "agent"}) | |
| g.add_node({"id": "b", "role": "agent"}) | |
| g.add_node({"id": "c", "role": "agent"}) | |
| g.add_edge(0, 1, {"weight": 0.8}) | |
| g.add_edge(1, 2, {"weight": 0.9}) | |
| return g | |
| def cyclic_digraph(): | |
| """Graph with a cycle.""" | |
| g = rx.PyDiGraph() | |
| g.add_node({"id": "a"}) | |
| g.add_node({"id": "b"}) | |
| g.add_node({"id": "c"}) | |
| g.add_edge(0, 1, {"weight": 1.0}) | |
| g.add_edge(1, 2, {"weight": 1.0}) | |
| g.add_edge(2, 0, {"weight": 1.0}) # Cycle | |
| return g | |
| def parallel_digraph(): | |
| """Graph with parallel branches.""" | |
| g = rx.PyDiGraph() | |
| # Task -> A, B (parallel) -> C | |
| g.add_node({"id": "task", "role": "task"}) | |
| g.add_node({"id": "a", "role": "agent"}) | |
| g.add_node({"id": "b", "role": "agent"}) | |
| g.add_node({"id": "c", "role": "agent"}) | |
| g.add_edge(0, 1, {"weight": 0.5}) | |
| g.add_edge(0, 2, {"weight": 0.5}) | |
| g.add_edge(1, 3, {"weight": 0.8}) | |
| g.add_edge(2, 3, {"weight": 0.8}) | |
| return g | |
| def sample_agent_configs(): | |
| """Sample agent configurations.""" | |
| return [ | |
| { | |
| "agent_id": "coordinator", | |
| "role_name": "Coordinator", | |
| "persona": "Manages workflow", | |
| "state": {}, | |
| }, | |
| { | |
| "agent_id": "researcher", | |
| "role_name": "Researcher", | |
| "persona": "Finds information", | |
| "state": {}, | |
| }, | |
| { | |
| "agent_id": "writer", | |
| "role_name": "Writer", | |
| "persona": "Creates content", | |
| "state": {}, | |
| }, | |
| ] | |
| def sample_connections(): | |
| """Sample agent connections.""" | |
| return { | |
| "coordinator": ["researcher", "writer"], | |
| "researcher": ["writer"], | |
| "writer": [], | |
| } | |
| def mock_agent_callable(): | |
| """Mock callable for an agent.""" | |
| async def agent_fn(query: str, context: dict) -> str: | |
| del context # Unused in mock | |
| return f"Response to: {query}" | |
| return agent_fn | |
| def mock_failing_agent(): | |
| """Mock agent that always fails.""" | |
| async def failing_agent(query: str, context: dict) -> str: | |
| del query, context # Unused in mock | |
| msg = "Agent failure" | |
| raise RuntimeError(msg) | |
| return failing_agent | |
| def mock_slow_agent(): | |
| """Mock slow agent.""" | |
| import asyncio | |
| async def slow_agent(query: str, context: dict) -> str: | |
| del query, context # Unused in mock | |
| await asyncio.sleep(5.0) # Intentionally slow | |
| return "Slow response" | |
| return slow_agent | |
| def create_role_graph(node_ids, connections): | |
| """Helper for creating a RoleGraph.""" | |
| from core.agent import AgentProfile | |
| from core.graph import RoleGraph | |
| g = rx.PyDiGraph() | |
| id_to_idx = {} | |
| agents = [] | |
| for nid in node_ids: | |
| idx = g.add_node({"id": nid}) | |
| id_to_idx[nid] = idx | |
| agent = AgentProfile(agent_id=nid, display_name=f"Agent {nid.upper()}") | |
| agents.append(agent) | |
| for src, targets in connections.items(): | |
| for tgt in targets: | |
| if src in id_to_idx and tgt in id_to_idx: | |
| g.add_edge(id_to_idx[src], id_to_idx[tgt], {"weight": 1.0}) | |
| n = len(node_ids) | |
| a_com = torch.zeros((n, n), dtype=torch.float32) | |
| for src, targets in connections.items(): | |
| i = node_ids.index(src) | |
| for tgt in targets: | |
| if tgt in node_ids: | |
| j = node_ids.index(tgt) | |
| a_com[i, j] = 1.0 | |
| graph = RoleGraph( | |
| node_ids=node_ids, | |
| role_connections=connections, | |
| graph=g, | |
| A_com=a_com, | |
| ) | |
| graph.agents = agents | |
| return graph | |