Spaces:
Running
Running
| """Tests for src/core/visualization.py — GraphVisualizer and helper functions.""" | |
| import tempfile | |
| from pathlib import Path | |
| import pytest | |
| import rustworkx as rx | |
| import torch | |
| from core.graph import RoleGraph | |
| from core.visualization import ( | |
| EdgeStyle, | |
| GraphVisualizer, | |
| ImageFormat, | |
| MermaidDirection, | |
| NodeShape, | |
| NodeStyle, | |
| VisualizationStyle, | |
| print_graph, | |
| to_ascii, | |
| to_dot, | |
| to_mermaid, | |
| ) | |
| # ============================================================================ | |
| # Helpers | |
| # ============================================================================ | |
def _make_simple_graph(num_nodes: int = 2, add_edge: bool = True) -> RoleGraph:
    """Build a small RoleGraph fixture.

    Nodes are named "a", "b", ... in order. When ``add_edge`` is true and
    there are at least two nodes, one edge from the first node to the second
    (weight 0.8) is recorded consistently in the rustworkx graph, the
    adjacency tensor and the role-connection mapping.
    """
    from core.agent import AgentProfile

    node_names = [chr(ord("a") + offset) for offset in range(num_nodes)]
    profiles = [
        AgentProfile(agent_id=name, display_name=f"Agent {name.upper()}")
        for name in node_names
    ]
    wants_edge = add_edge and num_nodes >= 2

    digraph = rx.PyDiGraph()
    index_of = {name: digraph.add_node({"id": name}) for name in node_names}
    adjacency = torch.zeros((num_nodes, num_nodes))
    connections = {name: [] for name in node_names}

    if wants_edge:
        first, second = node_names[0], node_names[1]
        digraph.add_edge(index_of[first], index_of[second], {"weight": 0.8})
        adjacency[0, 1] = 0.8
        connections[first] = [second]

    role_graph = RoleGraph(
        node_ids=node_names,
        role_connections=connections,
        graph=digraph,
        A_com=adjacency,
    )
    role_graph.agents = profiles
    return role_graph
def _make_graph_with_tools() -> RoleGraph:
    """Build a two-node fixture (researcher -> writer) whose agents carry tools."""
    from core.agent import AgentProfile

    profiles = [
        AgentProfile(
            agent_id="researcher",
            display_name="Researcher",
            tools=["web_search", "file_search", "code_exec"],
            description="A research agent",
        ),
        AgentProfile(
            agent_id="writer",
            display_name="Writer",
            tools=["text_tool"],
            description="A writing agent",
        ),
    ]

    digraph = rx.PyDiGraph()
    for node_name in ("researcher", "writer"):
        digraph.add_node({"id": node_name})
    digraph.add_edge(0, 1, {"weight": 1.0})

    role_graph = RoleGraph(
        node_ids=["researcher", "writer"],
        role_connections={"researcher": ["writer"], "writer": []},
        graph=digraph,
        A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
    )
    role_graph.agents = profiles
    return role_graph
| # ============================================================================ | |
| # Tests for enums and models | |
| # ============================================================================ | |
class TestEnumsAndModels:
    """Value and default checks for the visualization enums and style models."""

    def test_mermaid_direction_values(self):
        expected = [
            (MermaidDirection.TOP_BOTTOM, "TB"),
            (MermaidDirection.LEFT_RIGHT, "LR"),
            (MermaidDirection.BOTTOM_TOP, "BT"),
            (MermaidDirection.RIGHT_LEFT, "RL"),
        ]
        for member, value in expected:
            assert member.value == value

    def test_image_format_values(self):
        expected = [
            (ImageFormat.PNG, "png"),
            (ImageFormat.SVG, "svg"),
            (ImageFormat.PDF, "pdf"),
        ]
        for member, value in expected:
            assert member.value == value

    def test_image_format_from_path_png(self):
        assert ImageFormat.from_path("graph.png") == ImageFormat.PNG

    def test_image_format_from_path_svg(self):
        assert ImageFormat.from_path("diagram.svg") == ImageFormat.SVG

    def test_image_format_from_path_jpeg(self):
        assert ImageFormat.from_path("image.jpeg") == ImageFormat.JPEG

    def test_image_format_from_path_unknown_defaults_to_png(self):
        assert ImageFormat.from_path("graph.xyz") == ImageFormat.PNG

    def test_image_format_from_path_no_extension(self):
        assert ImageFormat.from_path("output") == ImageFormat.PNG

    def test_node_style_defaults(self):
        defaults = NodeStyle()
        assert defaults.shape == NodeShape.ROUND
        # Both colours are expected to be hex strings.
        assert defaults.fill_color.startswith("#")
        assert defaults.stroke_color.startswith("#")

    def test_edge_style_defaults(self):
        defaults = EdgeStyle()
        assert defaults.line_style == "solid"
        assert defaults.arrow_head == "normal"

    def test_visualization_style_defaults(self):
        defaults = VisualizationStyle()
        assert defaults.direction == MermaidDirection.TOP_BOTTOM
        assert defaults.show_tools is True
        assert defaults.show_weights is False
| # ============================================================================ | |
| # Tests for GraphVisualizer | |
| # ============================================================================ | |
class TestGraphVisualizerInit:
    """Constructor behaviour of GraphVisualizer."""

    def test_default_style(self):
        role_graph = _make_simple_graph()
        visualizer = GraphVisualizer(role_graph)
        # The visualizer keeps the very same graph object and builds a default style.
        assert visualizer.graph is role_graph
        assert isinstance(visualizer.style, VisualizationStyle)

    def test_custom_style(self):
        left_to_right = VisualizationStyle(direction=MermaidDirection.LEFT_RIGHT)
        visualizer = GraphVisualizer(_make_simple_graph(), style=left_to_right)
        assert visualizer.style.direction == MermaidDirection.LEFT_RIGHT
class TestToMermaid:
    """Mermaid flowchart output of GraphVisualizer.to_mermaid."""

    def test_basic_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "flowchart" in rendered
        assert "TB" in rendered

    def test_with_title(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_mermaid(title="My Graph")
        assert "title: My Graph" in rendered
        assert "---" in rendered

    def test_different_direction(self):
        visualizer = GraphVisualizer(_make_simple_graph())
        rendered = visualizer.to_mermaid(direction=MermaidDirection.LEFT_RIGHT)
        assert "LR" in rendered

    def test_nodes_in_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_mermaid()
        # Either the display name or the node declaration must be present.
        assert any(token in rendered for token in ("Agent A", "a("))

    def test_edges_in_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "-->" in rendered

    def test_with_weights(self):
        weighted = VisualizationStyle(show_weights=True)
        rendered = GraphVisualizer(_make_simple_graph(), style=weighted).to_mermaid()
        # Weight 0.8 should appear since it's != 1.0
        assert "w=0.80" in rendered

    def test_empty_graph(self):
        rendered = GraphVisualizer(RoleGraph()).to_mermaid()
        assert "flowchart" in rendered

    def test_with_tools(self):
        with_tools = VisualizationStyle(show_tools=True)
        rendered = GraphVisualizer(_make_graph_with_tools(), style=with_tools).to_mermaid()
        assert "web_search" in rendered

    def test_classdefs_included(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "classDef" in rendered
class TestToAscii:
    """ASCII rendering: header box, edges section, tools marker, truncation."""

    def test_basic_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii()
        assert "╔" in rendered
        assert "Graph" in rendered

    def test_without_edges(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(show_edges=False)
        assert "Edges:" not in rendered

    def test_with_edges(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(show_edges=True)
        assert "Edges:" in rendered

    def test_shows_tools(self):
        with_tools = VisualizationStyle(show_tools=True)
        rendered = GraphVisualizer(_make_graph_with_tools(), style=with_tools).to_ascii()
        assert "🔧" in rendered

    def test_custom_box_width(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(box_width=30)
        assert "╔" in rendered

    def test_long_node_name_truncated(self):
        from core.agent import AgentProfile

        digraph = rx.PyDiGraph()
        digraph.add_node({"id": "short"})
        role_graph = RoleGraph(node_ids=["short"], graph=digraph)
        # A 50-character display name cannot fit a 20-column box.
        role_graph.agents = [AgentProfile(agent_id="short", display_name="A" * 50)]

        rendered = GraphVisualizer(role_graph).to_ascii(box_width=20)
        assert "..." in rendered

    def test_many_edges_displayed(self):
        from core.agent import AgentProfile

        # Chain of 15 nodes: node0 -> node1 -> ... -> node14
        node_count = 15
        names = [f"node{k}" for k in range(node_count)]
        profiles = [AgentProfile(agent_id=name, display_name=name) for name in names]

        digraph = rx.PyDiGraph()
        index_of = {name: digraph.add_node({"id": name}) for name in names}
        adjacency = torch.zeros((node_count, node_count))
        connections = {name: [] for name in names}
        for pos, (src, dst) in enumerate(zip(names, names[1:])):
            digraph.add_edge(index_of[src], index_of[dst], {"weight": 1.0})
            adjacency[pos, pos + 1] = 1.0
            connections[src] = [dst]

        role_graph = RoleGraph(
            node_ids=names,
            role_connections=connections,
            graph=digraph,
            A_com=adjacency,
        )
        role_graph.agents = profiles

        rendered = GraphVisualizer(role_graph).to_ascii(show_edges=True)
        # ASCII output should contain edges section
        assert "Edges:" in rendered
        assert "→" in rendered
class TestToDot:
    """Graphviz DOT output of GraphVisualizer.to_dot."""

    def test_basic_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_dot()
        for fragment in ("digraph", "AgentGraph", "->"):
            assert fragment in rendered

    def test_custom_graph_name(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_dot(graph_name="MyCustomGraph")
        assert "MyCustomGraph" in rendered

    def test_custom_rankdir(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_dot(rankdir="LR")
        assert "rankdir=LR" in rendered

    def test_with_dpi(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_dot(dpi=150)
        assert "dpi=150" in rendered

    def test_with_weights(self):
        weighted = VisualizationStyle(show_weights=True)
        rendered = GraphVisualizer(_make_simple_graph(), style=weighted).to_dot()
        assert "0.80" in rendered

    def test_with_tools_in_label(self):
        with_tools = VisualizationStyle(show_tools=True)
        rendered = GraphVisualizer(_make_graph_with_tools(), style=with_tools).to_dot()
        assert "web_search" in rendered

    def test_empty_graph(self):
        rendered = GraphVisualizer(RoleGraph()).to_dot()
        assert "digraph" in rendered
        assert "}" in rendered
class TestToAdjacencyMatrix:
    """Textual adjacency-matrix output of GraphVisualizer.to_adjacency_matrix."""

    def test_empty_graph(self):
        matrix_text = GraphVisualizer(RoleGraph()).to_adjacency_matrix()
        # Empty graph returns a string (the "Empty adjacency matrix" check in source
        # uses a_com.size == 0 which compares a method reference, so falls through
        # to produce an empty header string)
        assert isinstance(matrix_text, str)

    def test_basic_matrix(self):
        from core.agent import AgentProfile

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            A_com=torch.tensor([[0.0, 0.8], [0.0, 0.0]]),
        )
        role_graph.agents = [
            AgentProfile(agent_id=name, display_name=name) for name in ("a", "b")
        ]

        matrix_text = GraphVisualizer(role_graph).to_adjacency_matrix()
        assert "0.80" in matrix_text

    def test_without_labels(self):
        from core.agent import AgentProfile

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            A_com=torch.tensor([[0.0, 0.5], [0.0, 0.0]]),
        )
        role_graph.agents = [
            AgentProfile(agent_id=name, display_name=name) for name in ("a", "b")
        ]

        matrix_text = GraphVisualizer(role_graph).to_adjacency_matrix(show_labels=False)
        # Only the cell value is checked here; header absence is not asserted.
        assert "0.50" in matrix_text
class TestSafeId:
    """Identifier sanitising performed by GraphVisualizer._safe_id."""

    def test_simple_id(self):
        assert GraphVisualizer(RoleGraph())._safe_id("agent") == "agent"

    def test_id_with_hyphens(self):
        assert GraphVisualizer(RoleGraph())._safe_id("agent-name") == "agent_name"

    def test_id_with_spaces(self):
        assert GraphVisualizer(RoleGraph())._safe_id("my agent") == "my_agent"

    def test_id_starting_with_digit(self):
        sanitized = GraphVisualizer(RoleGraph())._safe_id("123agent")
        assert sanitized.startswith("n_")

    def test_id_with_dots(self):
        sanitized = GraphVisualizer(RoleGraph())._safe_id("agent.name")
        assert "." not in sanitized

    def test_empty_id(self):
        sanitized = GraphVisualizer(RoleGraph())._safe_id("")
        assert sanitized == "unknown"
class TestFormatNodeLabel:
    """Label text produced by GraphVisualizer._format_node_label."""

    def test_basic_label(self):
        from core.agent import AgentProfile

        visualizer = GraphVisualizer(RoleGraph())
        profile = AgentProfile(agent_id="a", display_name="My Agent")
        text = visualizer._format_node_label(profile, visualizer.style.agent_style)
        assert "My Agent" in text

    def test_long_name_truncated(self):
        from core.agent import AgentProfile

        visualizer = GraphVisualizer(
            RoleGraph(), style=VisualizationStyle(max_label_length=20)
        )
        profile = AgentProfile(agent_id="a", display_name="A" * 50)
        text = visualizer._format_node_label(profile, visualizer.style.agent_style)
        assert "..." in text

    def test_with_tools(self):
        from core.agent import AgentProfile

        visualizer = GraphVisualizer(RoleGraph(), style=VisualizationStyle(show_tools=True))
        profile = AgentProfile(
            agent_id="a", display_name="Agent", tools=["tool1", "tool2", "tool3"]
        )
        text = visualizer._format_node_label(profile, visualizer.style.agent_style)
        assert "tool1" in text

    def test_with_many_tools_shows_ellipsis(self):
        from core.agent import AgentProfile

        visualizer = GraphVisualizer(RoleGraph(), style=VisualizationStyle(show_tools=True))
        profile = AgentProfile(
            agent_id="a", display_name="Agent", tools=["t1", "t2", "t3", "t4"]
        )
        text = visualizer._format_node_label(profile, visualizer.style.agent_style)
        assert "..." in text
class TestSaveMermaid:
    """File-writing helpers: save_mermaid (.mmd / .md) and save_dot."""

    def test_save_mermaid_file(self, tmp_path):
        target = tmp_path / "test.mmd"
        GraphVisualizer(_make_simple_graph()).save_mermaid(target)
        assert "flowchart" in target.read_text()

    def test_save_mermaid_md_file(self, tmp_path):
        target = tmp_path / "test.md"
        GraphVisualizer(_make_simple_graph()).save_mermaid(target, title="Test Graph")
        written = target.read_text()
        # Markdown targets are expected to wrap the diagram in a fenced block.
        assert "```mermaid" in written
        assert "```" in written

    def test_save_dot_file(self, tmp_path):
        target = tmp_path / "test.dot"
        GraphVisualizer(_make_simple_graph()).save_dot(target)
        assert "digraph" in target.read_text()
class TestRenderImageWithoutGraphviz:
    """Rendering entry points must raise ImportError when graphviz is missing.

    This class used to be named ``TestRenderImage``; a later class in this
    module has the same name and rebound it, so these tests were never
    collected. The distinct name makes pytest run them again.
    """

    def test_render_raises_import_error_without_graphviz(self, tmp_path):
        import sys
        import unittest.mock as mock

        viz = GraphVisualizer(_make_simple_graph())
        # A None entry in sys.modules makes ``import graphviz`` fail, which
        # simulates the package not being installed.
        with mock.patch.dict(sys.modules, {"graphviz": None}):
            # ImportError only: also accepting Exception would let any
            # unrelated failure pass this test.
            with pytest.raises(ImportError):
                viz.render_image(tmp_path / "output.png")

    def test_show_interactive_raises_without_graphviz(self):
        import sys
        import unittest.mock as mock

        viz = GraphVisualizer(_make_simple_graph())
        with mock.patch.dict(sys.modules, {"graphviz": None}):
            with pytest.raises(ImportError):
                viz.show_interactive()
class TestPrintColoredSmoke:
    """Smoke test for print_colored against whatever environment is installed.

    This class used to be named ``TestPrintColored``; a later class in this
    module has the same name and rebound it, so this test was never collected.
    """

    def test_print_colored_no_error(self):
        """print_colored should not raise even if rich is unavailable."""
        viz = GraphVisualizer(_make_simple_graph())
        # Called without a try/except: swallowing every exception would make
        # this test unable to fail, defeating the "should not raise" contract.
        viz.print_colored()
| # ============================================================================ | |
| # Tests for module-level convenience functions | |
| # ============================================================================ | |
class TestConvenienceFunctions:
    """Module-level wrappers: to_mermaid, to_ascii, to_dot and print_graph."""

    def test_to_mermaid_function(self):
        assert "flowchart" in to_mermaid(_make_simple_graph())

    def test_to_mermaid_with_direction(self):
        rendered = to_mermaid(_make_simple_graph(), direction=MermaidDirection.LEFT_RIGHT)
        assert "LR" in rendered

    def test_to_mermaid_with_title(self):
        rendered = to_mermaid(_make_simple_graph(), title="Test")
        assert "title: Test" in rendered

    def test_to_mermaid_with_custom_style(self):
        weighted = VisualizationStyle(show_weights=True)
        rendered = to_mermaid(_make_simple_graph(), style=weighted)
        assert "flowchart" in rendered

    def test_to_ascii_function(self):
        assert "╔" in to_ascii(_make_simple_graph())

    def test_to_ascii_no_edges(self):
        rendered = to_ascii(_make_simple_graph(), show_edges=False)
        assert "Edges:" not in rendered

    def test_to_dot_function(self):
        assert "digraph" in to_dot(_make_simple_graph())

    def test_to_dot_custom_name(self):
        rendered = to_dot(_make_simple_graph(), graph_name="Custom")
        assert "Custom" in rendered

    def test_print_graph_auto_format(self):
        """print_graph with auto format should not raise."""
        import sys
        from unittest.mock import MagicMock

        role_graph = _make_simple_graph()

        # Objects handed back by the mocked rich constructors.
        console = MagicMock()
        branch = MagicMock()
        branch.add.return_value = MagicMock()
        tree = MagicMock()
        tree.add.return_value = branch
        table = MagicMock()

        # Module-level stand-ins exposing Console / Table / Tree.
        console_module = MagicMock()
        console_module.Console.return_value = console
        tree_module = MagicMock()
        tree_module.Tree.return_value = tree
        table_module = MagicMock()
        table_module.Table.return_value = table

        fake_modules = {
            "rich.console": console_module,
            "rich.table": table_module,
            "rich.tree": tree_module,
            "rich": MagicMock(),
        }
        with pytest.MonkeyPatch.context() as mp:
            for module_name, fake in fake_modules.items():
                mp.setitem(sys.modules, module_name, fake)
            print_graph(role_graph, output_format="auto")  # Should call print_colored via rich

    def test_print_graph_colored_format(self):
        """print_graph with colored format should not raise."""
        print_graph(_make_simple_graph(), output_format="colored")

    def test_print_graph_ascii_format(self):
        """print_graph with ascii format should not raise."""
        print_graph(_make_simple_graph(), output_format="ascii")

    def test_print_graph_mermaid_format(self):
        """print_graph with mermaid format should not raise."""
        print_graph(_make_simple_graph(), output_format="mermaid")
class TestEdgeShortNameTruncation:
    """Edge endpoints with long ids are shortened in the ASCII edges section."""

    def test_long_source_name_truncated_in_ascii(self):
        from core.agent import AgentProfile

        long_id = "very_long_source_name_here"
        profiles = [
            AgentProfile(agent_id=long_id, display_name="Very Long Source"),
            AgentProfile(agent_id="b", display_name="B"),
        ]

        digraph = rx.PyDiGraph()
        for node_name in (long_id, "b"):
            digraph.add_node({"id": node_name})
        digraph.add_edge(0, 1, {"weight": 1.0})

        role_graph = RoleGraph(
            node_ids=[long_id, "b"],
            role_connections={long_id: ["b"], "b": []},
            graph=digraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = profiles

        rendered = GraphVisualizer(role_graph).to_ascii(show_edges=True)
        # Long source name should be truncated in edges section
        assert ".." in rendered
class TestTaskNodeVisualization:
    """Rendering of task-typed nodes and task-typed edges."""

    def test_task_node_in_mermaid(self):
        from core.agent import AgentProfile

        class TaskAgent(AgentProfile):
            type: str = "task"

        profiles = [
            TaskAgent(agent_id="task1", display_name="My Task"),
            AgentProfile(agent_id="agent1", display_name="My Agent"),
        ]

        digraph = rx.PyDiGraph()
        for node_name in ("task1", "agent1"):
            digraph.add_node({"id": node_name})
        digraph.add_edge(0, 1, {"type": "task", "weight": 1.0})

        role_graph = RoleGraph(
            node_ids=["task1", "agent1"],
            role_connections={"task1": ["agent1"], "agent1": []},
            graph=digraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = profiles

        rendered = GraphVisualizer(role_graph).to_mermaid()
        # Only the node's presence is asserted; its shape is not checked here.
        assert "task1" in rendered

    def test_task_edge_in_dot(self):
        from core.agent import AgentProfile

        profiles = [
            AgentProfile(agent_id="a", display_name="A"),
            AgentProfile(agent_id="b", display_name="B"),
        ]

        digraph = rx.PyDiGraph()
        for node_name in ("a", "b"):
            digraph.add_node({"id": node_name})
        digraph.add_edge(0, 1, {"type": "task_edge", "weight": 1.0})

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            role_connections={"a": ["b"], "b": []},
            graph=digraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = profiles

        rendered = GraphVisualizer(role_graph).to_dot()
        assert "dashed" in rendered
class TestRenderImage:
    """Test render_image and show_interactive (with mocked graphviz).

    graphviz is never imported for real: each test installs either a
    MagicMock module or a ``None`` entry in ``sys.modules`` through
    ``patch.dict``, which restores the previous state when the block exits.
    """

    @staticmethod
    def _fake_graphviz():
        """Return ``(module, source)`` mocks standing in for the graphviz package.

        ``module.Source(...)`` returns ``source``, so tests can assert on
        ``source.render``.
        """
        from unittest.mock import MagicMock

        source = MagicMock()
        module = MagicMock()
        module.Source.return_value = source
        return module, source

    def test_render_image_no_graphviz_raises(self):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        # A None entry makes ``import graphviz`` raise ImportError.
        with patch.dict("sys.modules", {"graphviz": None}):
            with pytest.raises(ImportError, match="[Gg]raphviz"):
                viz.render_image("test.png")

    def test_render_image_with_mock_graphviz(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, fake_source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            viz.render_image(tmp_path / "test.png")
        fake_source.render.assert_called_once()

    def test_render_image_with_explicit_format(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, fake_source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            viz.render_image(tmp_path / "test", image_format=ImageFormat.SVG)
        fake_source.render.assert_called_once()

    def test_render_image_with_dpi(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, fake_source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            viz.render_image(tmp_path / "test.png", dpi=300)
        fake_source.render.assert_called_once()

    def test_render_image_svg_ignores_dpi(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, fake_source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            viz.render_image(tmp_path / "test.svg", dpi=300)
        # NOTE(review): only the render call is verified; that dpi is dropped
        # for SVG output is not asserted here.
        fake_source.render.assert_called_once()

    def test_render_image_raises_on_render_error(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, fake_source = self._fake_graphviz()
        fake_source.render.side_effect = Exception("render failed")
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            # The underlying failure is expected to surface as RuntimeError.
            with pytest.raises(RuntimeError, match="render failed"):
                viz.render_image(tmp_path / "test.png")

    def test_show_interactive_no_graphviz_raises(self):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        with patch.dict("sys.modules", {"graphviz": None}):
            with pytest.raises(ImportError, match="[Gg]raphviz"):
                viz.show_interactive()

    def test_show_interactive_with_mock_graphviz(self):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        fake_module, _ = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": fake_module}):
            viz.show_interactive()  # Should not raise
class TestPrintColored:
    """Test print_colored method.

    rich is replaced in ``sys.modules`` through ``pytest.MonkeyPatch`` so each
    entry is restored exactly (or removed if it was absent) after the test.
    """

    _RICH_MODULE_NAMES = ("rich", "rich.console", "rich.table", "rich.tree")

    @staticmethod
    def _fake_rich():
        """Build MagicMock stand-ins for the rich modules.

        Returns ``(modules, console)``. ``modules`` maps module names to the
        mocks to install in ``sys.modules``; ``console`` is the object the
        mocked ``rich.console.Console()`` returns, for asserting on ``print``.
        """
        from unittest.mock import MagicMock

        console = MagicMock()
        # Tree().add() -> branch, branch.add() -> node, node.add() -> leaf
        node = MagicMock()
        node.add.return_value = MagicMock()
        branch = MagicMock()
        branch.add.return_value = node
        tree = MagicMock()
        tree.add.return_value = branch

        console_module = MagicMock()
        console_module.Console.return_value = console
        tree_module = MagicMock()
        tree_module.Tree.return_value = tree
        table_module = MagicMock()
        table_module.Table.return_value = MagicMock()

        modules = {
            "rich.console": console_module,
            "rich.table": table_module,
            "rich.tree": tree_module,
            "rich": MagicMock(),
        }
        return modules, console

    def _print_with_fake_rich(self, viz):
        """Call ``viz.print_colored()`` under mocked rich; return the mock console."""
        import sys

        modules, console = self._fake_rich()
        with pytest.MonkeyPatch.context() as mp:
            for module_name, fake in modules.items():
                mp.setitem(sys.modules, module_name, fake)
            viz.print_colored()
        return console

    def test_print_colored_with_mocked_rich(self):
        """print_colored prints through the rich console when rich is importable."""
        viz = GraphVisualizer(_make_simple_graph())
        console = self._print_with_fake_rich(viz)
        console.print.assert_called()

    def test_print_colored_with_graph_with_tools_and_description(self):
        """Exercises print_colored for an agent with a long description and tools."""
        from core.agent import AgentProfile

        agent = AgentProfile(
            agent_id="a",
            display_name="Agent A",
            description="A helpful agent " * 10,  # long description
            tools=["tool1", "tool2"],
        )
        g = rx.PyDiGraph()
        g.add_node({"id": "a"})
        graph = RoleGraph(node_ids=["a"], graph=g, A_com=torch.zeros((1, 1)))
        graph.agents = [agent]

        console = self._print_with_fake_rich(GraphVisualizer(graph))
        console.print.assert_called()

    def test_print_colored_with_many_edges_shows_table(self):
        """Exercises print_colored on a 15-node chain, expecting tree and table output."""
        from core.agent import AgentProfile

        n = 15
        ids = [f"n{i}" for i in range(n)]
        agents = [AgentProfile(agent_id=nid, display_name=nid) for nid in ids]

        g = rx.PyDiGraph()
        idx_map = {nid: g.add_node({"id": nid}) for nid in ids}
        adjacency = torch.zeros((n, n))
        for i in range(n - 1):
            src, dst = ids[i], ids[i + 1]
            g.add_edge(
                idx_map[src],
                idx_map[dst],
                {"weight": 1.0, "source": src, "target": dst},
            )
            adjacency[i, i + 1] = 1.0

        graph = RoleGraph(node_ids=ids, graph=g, A_com=adjacency)
        graph.agents = agents

        console = self._print_with_fake_rich(GraphVisualizer(graph))
        # Should have printed tree + table
        assert console.print.call_count >= 2

    def test_print_colored_without_rich(self):
        """If rich is not available, print_colored falls back gracefully."""
        import sys

        viz = GraphVisualizer(_make_simple_graph())
        with pytest.MonkeyPatch.context() as mp:
            # None entries make every ``import rich...`` raise ImportError.
            # setitem records each key's previous state, so already-imported
            # rich submodules are put back instead of being evicted from
            # sys.modules when the test ends.
            for module_name in self._RICH_MODULE_NAMES:
                mp.setitem(sys.modules, module_name, None)
            viz.print_colored()  # Should not raise (fallback)
class TestSafeidDoubleUnderscore:
    """_safe_id handling of repeated, leading and trailing underscores."""

    def test_double_underscore_removed(self):
        sanitized = GraphVisualizer(_make_simple_graph())._safe_id("agent__name__here")
        assert "__" not in sanitized

    def test_leading_trailing_underscores_removed(self):
        sanitized = GraphVisualizer(_make_simple_graph())._safe_id("_agent_name_")
        assert not sanitized.startswith("_")
        assert not sanitized.endswith("_")

    def test_empty_string_returns_unknown(self):
        assert GraphVisualizer(_make_simple_graph())._safe_id("") == "unknown"

    def test_only_special_chars(self):
        assert GraphVisualizer(_make_simple_graph())._safe_id("---") == "unknown"
class TestToMermaidEdgeDedup:
    """Parallel edges between the same pair appear at most once in mermaid output."""

    def test_duplicate_edges_deduped(self):
        from core.agent import AgentProfile

        profiles = [
            AgentProfile(agent_id="a", display_name="A"),
            AgentProfile(agent_id="b", display_name="B"),
        ]

        digraph = rx.PyDiGraph()
        for node_name in ("a", "b"):
            digraph.add_node({"id": node_name})
        # Two parallel edges a -> b; the second one is the duplicate.
        for weight in (1.0, 0.5):
            digraph.add_edge(0, 1, {"weight": weight})

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            role_connections={"a": ["b"], "b": []},
            graph=digraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = profiles

        rendered = GraphVisualizer(role_graph).to_mermaid()
        # Count occurrences of a --> b or a -> b
        occurrences = sum(rendered.count(arrow) for arrow in ("a --> b", "a->b"))
        assert occurrences <= 1  # Should be deduped
class TestToMermaidEdgeEmptySourceTarget:
    """Edges whose payload has empty source/target must not break to_mermaid."""

    def test_empty_source_edge_skipped(self):
        from core.agent import AgentProfile

        digraph = rx.PyDiGraph()
        digraph.add_node({"id": "a"})
        # Self-loop whose payload carries empty source/target strings.
        digraph.add_edge(0, 0, {"source": "", "target": ""})

        role_graph = RoleGraph(
            node_ids=["a"],
            role_connections={"a": []},
            graph=digraph,
            A_com=torch.zeros((1, 1)),
        )
        role_graph.agents = [AgentProfile(agent_id="a", display_name="A")]

        rendered = GraphVisualizer(role_graph).to_mermaid()
        # Should not raise
        assert "flowchart" in rendered
class TestRenderToImageConvenience:
    """Module-level render_to_image / show_graph_interactive with mocked graphviz."""

    def test_render_to_image_with_mock(self, tmp_path):
        from unittest.mock import MagicMock, patch

        from core.visualization import render_to_image

        fake_source = MagicMock()
        fake_graphviz = MagicMock()
        fake_graphviz.Source.return_value = fake_source
        role_graph = _make_simple_graph()

        with patch.dict("sys.modules", {"graphviz": fake_graphviz}):
            render_to_image(role_graph, tmp_path / "test.png")
        fake_source.render.assert_called_once()

    def test_show_graph_interactive_with_mock(self):
        from unittest.mock import MagicMock, patch

        from core.visualization import show_graph_interactive

        fake_source = MagicMock()
        fake_graphviz = MagicMock()
        fake_graphviz.Source.return_value = fake_source
        role_graph = _make_simple_graph()

        with patch.dict("sys.modules", {"graphviz": fake_graphviz}):
            show_graph_interactive(role_graph)  # Should not raise
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))