# gMAS / tests/test_visualization.py
# Author: Артём Боярских
# Commit 3193174: "chore: initial commit"
"""Tests for src/core/visualization.py — GraphVisualizer and helper functions."""
import tempfile
from pathlib import Path
import pytest
import rustworkx as rx
import torch
from core.graph import RoleGraph
from core.visualization import (
EdgeStyle,
GraphVisualizer,
ImageFormat,
MermaidDirection,
NodeShape,
NodeStyle,
VisualizationStyle,
print_graph,
to_ascii,
to_dot,
to_mermaid,
)
# ============================================================================
# Helpers
# ============================================================================
def _make_simple_graph(num_nodes: int = 2, add_edge: bool = True) -> RoleGraph:
    """Build a small RoleGraph whose node ids are "a", "b", "c", ...

    Args:
        num_nodes: Number of agent nodes to create.
        add_edge: When True and there are at least two nodes, connect the
            first node to the second with weight 0.8. The edge is recorded in
            the rustworkx graph, in ``A_com`` and in ``role_connections``.

    Returns:
        A RoleGraph with one ``AgentProfile`` per node assigned to ``agents``.
    """
    from core.agent import AgentProfile

    node_ids = [chr(ord("a") + offset) for offset in range(num_nodes)]
    profiles = [
        AgentProfile(agent_id=node_id, display_name=f"Agent {node_id.upper()}")
        for node_id in node_ids
    ]

    digraph = rx.PyDiGraph()
    index_of = {node_id: digraph.add_node({"id": node_id}) for node_id in node_ids}
    adjacency = torch.zeros((num_nodes, num_nodes))
    connections = {node_id: [] for node_id in node_ids}

    if add_edge and num_nodes >= 2:
        src, dst = node_ids[0], node_ids[1]
        digraph.add_edge(index_of[src], index_of[dst], {"weight": 0.8})
        adjacency[0, 1] = 0.8
        connections[src] = [dst]

    result = RoleGraph(
        node_ids=node_ids,
        role_connections=connections,
        graph=digraph,
        A_com=adjacency,
    )
    result.agents = profiles
    return result
def _make_graph_with_tools() -> RoleGraph:
    """Build a two-node RoleGraph (researcher -> writer) whose agents carry tools.

    The researcher agent has three tools and the writer agent has one. The
    single edge has weight 1.0.
    """
    from core.agent import AgentProfile

    researcher = AgentProfile(
        agent_id="researcher",
        display_name="Researcher",
        tools=["web_search", "file_search", "code_exec"],
        description="A research agent",
    )
    writer = AgentProfile(
        agent_id="writer",
        display_name="Writer",
        tools=["text_tool"],
        description="A writing agent",
    )

    digraph = rx.PyDiGraph()
    src_idx = digraph.add_node({"id": "researcher"})
    dst_idx = digraph.add_node({"id": "writer"})
    digraph.add_edge(src_idx, dst_idx, {"weight": 1.0})

    graph = RoleGraph(
        node_ids=["researcher", "writer"],
        role_connections={"researcher": ["writer"], "writer": []},
        graph=digraph,
        A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
    )
    graph.agents = [researcher, writer]
    return graph
# ============================================================================
# Tests for enums and models
# ============================================================================
class TestEnumsAndModels:
    """Enum values, ImageFormat path detection and style-model defaults."""

    def test_mermaid_direction_values(self):
        expected = (
            (MermaidDirection.TOP_BOTTOM, "TB"),
            (MermaidDirection.LEFT_RIGHT, "LR"),
            (MermaidDirection.BOTTOM_TOP, "BT"),
            (MermaidDirection.RIGHT_LEFT, "RL"),
        )
        for member, code in expected:
            assert member.value == code

    def test_image_format_values(self):
        expected = (
            (ImageFormat.PNG, "png"),
            (ImageFormat.SVG, "svg"),
            (ImageFormat.PDF, "pdf"),
        )
        for member, extension in expected:
            assert member.value == extension

    def test_image_format_from_path_png(self):
        assert ImageFormat.from_path("graph.png") == ImageFormat.PNG

    def test_image_format_from_path_svg(self):
        assert ImageFormat.from_path("diagram.svg") == ImageFormat.SVG

    def test_image_format_from_path_jpeg(self):
        assert ImageFormat.from_path("image.jpeg") == ImageFormat.JPEG

    def test_image_format_from_path_unknown_defaults_to_png(self):
        # Unrecognised extensions fall back to PNG.
        assert ImageFormat.from_path("graph.xyz") == ImageFormat.PNG

    def test_image_format_from_path_no_extension(self):
        # A path without any extension also falls back to PNG.
        assert ImageFormat.from_path("output") == ImageFormat.PNG

    def test_node_style_defaults(self):
        defaults = NodeStyle()
        assert defaults.shape == NodeShape.ROUND
        assert defaults.fill_color.startswith("#")
        assert defaults.stroke_color.startswith("#")

    def test_edge_style_defaults(self):
        defaults = EdgeStyle()
        assert defaults.line_style == "solid"
        assert defaults.arrow_head == "normal"

    def test_visualization_style_defaults(self):
        defaults = VisualizationStyle()
        assert defaults.direction == MermaidDirection.TOP_BOTTOM
        assert defaults.show_tools is True
        assert defaults.show_weights is False
# ============================================================================
# Tests for GraphVisualizer
# ============================================================================
class TestGraphVisualizerInit:
    """Construction of GraphVisualizer with and without an explicit style."""

    def test_default_style(self):
        role_graph = _make_simple_graph()
        visualizer = GraphVisualizer(role_graph)
        # The visualizer keeps a reference to the very graph it was given.
        assert visualizer.graph is role_graph
        assert isinstance(visualizer.style, VisualizationStyle)

    def test_custom_style(self):
        role_graph = _make_simple_graph()
        custom = VisualizationStyle(direction=MermaidDirection.LEFT_RIGHT)
        visualizer = GraphVisualizer(role_graph, style=custom)
        assert visualizer.style.direction == MermaidDirection.LEFT_RIGHT
class TestToMermaid:
    """Mermaid flowchart output of GraphVisualizer.to_mermaid."""

    def test_basic_output(self):
        text = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "flowchart" in text
        assert "TB" in text

    def test_with_title(self):
        text = GraphVisualizer(_make_simple_graph()).to_mermaid(title="My Graph")
        # The title is emitted inside a "---" front-matter section.
        assert "title: My Graph" in text
        assert "---" in text

    def test_different_direction(self):
        visualizer = GraphVisualizer(_make_simple_graph())
        text = visualizer.to_mermaid(direction=MermaidDirection.LEFT_RIGHT)
        assert "LR" in text

    def test_nodes_in_output(self):
        text = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert any(marker in text for marker in ("Agent A", "a("))

    def test_edges_in_output(self):
        text = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "-->" in text

    def test_with_weights(self):
        role_graph = _make_simple_graph()
        weighted = VisualizationStyle(show_weights=True)
        text = GraphVisualizer(role_graph, style=weighted).to_mermaid()
        # The helper's single edge has weight 0.8, which differs from 1.0 and
        # is therefore labelled.
        assert "w=0.80" in text

    def test_empty_graph(self):
        text = GraphVisualizer(RoleGraph()).to_mermaid()
        assert "flowchart" in text

    def test_with_tools(self):
        role_graph = _make_graph_with_tools()
        with_tools = VisualizationStyle(show_tools=True)
        text = GraphVisualizer(role_graph, style=with_tools).to_mermaid()
        assert "web_search" in text

    def test_classdefs_included(self):
        text = GraphVisualizer(_make_simple_graph()).to_mermaid()
        assert "classDef" in text
class TestToAscii:
    """Box-drawing text output of GraphVisualizer.to_ascii."""

    def test_basic_output(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii()
        assert "╔" in rendered
        assert "Graph" in rendered

    def test_without_edges(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(show_edges=False)
        assert "Edges:" not in rendered

    def test_with_edges(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(show_edges=True)
        assert "Edges:" in rendered

    def test_shows_tools(self):
        role_graph = _make_graph_with_tools()
        with_tools = VisualizationStyle(show_tools=True)
        rendered = GraphVisualizer(role_graph, style=with_tools).to_ascii()
        assert "🔧" in rendered

    def test_custom_box_width(self):
        rendered = GraphVisualizer(_make_simple_graph()).to_ascii(box_width=30)
        assert "╔" in rendered

    def test_long_node_name_truncated(self):
        from core.agent import AgentProfile

        profile = AgentProfile(agent_id="short", display_name="A" * 50)
        single = rx.PyDiGraph()
        single.add_node({"id": "short"})
        role_graph = RoleGraph(node_ids=["short"], graph=single)
        role_graph.agents = [profile]
        rendered = GraphVisualizer(role_graph).to_ascii(box_width=20)
        # A 50-character name cannot fit a 20-wide box and must be elided.
        assert "..." in rendered

    def test_many_edges_displayed(self):
        from core.agent import AgentProfile

        count = 15
        names = [f"node{k}" for k in range(count)]
        profiles = [AgentProfile(agent_id=name, display_name=name) for name in names]

        # Chain node0 -> node1 -> ... -> node14 consistently in every representation.
        chain = rx.PyDiGraph()
        position = {name: chain.add_node({"id": name}) for name in names}
        adjacency = torch.zeros((count, count))
        links = {name: [] for name in names}
        for k, (head, tail) in enumerate(zip(names, names[1:])):
            chain.add_edge(position[head], position[tail], {"weight": 1.0})
            adjacency[k, k + 1] = 1.0
            links[head] = [tail]

        role_graph = RoleGraph(
            node_ids=names, role_connections=links, graph=chain, A_com=adjacency
        )
        role_graph.agents = profiles
        rendered = GraphVisualizer(role_graph).to_ascii(show_edges=True)
        assert "Edges:" in rendered
        assert "→" in rendered
class TestToDot:
    """Graphviz DOT output of GraphVisualizer.to_dot."""

    def test_basic_output(self):
        dot_src = GraphVisualizer(_make_simple_graph()).to_dot()
        for token in ("digraph", "AgentGraph", "->"):
            assert token in dot_src

    def test_custom_graph_name(self):
        dot_src = GraphVisualizer(_make_simple_graph()).to_dot(graph_name="MyCustomGraph")
        assert "MyCustomGraph" in dot_src

    def test_custom_rankdir(self):
        dot_src = GraphVisualizer(_make_simple_graph()).to_dot(rankdir="LR")
        assert "rankdir=LR" in dot_src

    def test_with_dpi(self):
        dot_src = GraphVisualizer(_make_simple_graph()).to_dot(dpi=150)
        assert "dpi=150" in dot_src

    def test_with_weights(self):
        role_graph = _make_simple_graph()
        weighted = VisualizationStyle(show_weights=True)
        dot_src = GraphVisualizer(role_graph, style=weighted).to_dot()
        assert "0.80" in dot_src

    def test_with_tools_in_label(self):
        role_graph = _make_graph_with_tools()
        with_tools = VisualizationStyle(show_tools=True)
        dot_src = GraphVisualizer(role_graph, style=with_tools).to_dot()
        assert "web_search" in dot_src

    def test_empty_graph(self):
        dot_src = GraphVisualizer(RoleGraph()).to_dot()
        assert "digraph" in dot_src
        assert "}" in dot_src
class TestToAdjacencyMatrix:
    """Tests for GraphVisualizer.to_adjacency_matrix."""

    @staticmethod
    def _two_node_visualizer(weight):
        """Return a visualizer for nodes "a" -> "b" with the given edge weight."""
        from core.agent import AgentProfile

        graph = RoleGraph(
            node_ids=["a", "b"],
            A_com=torch.tensor([[0.0, weight], [0.0, 0.0]]),
        )
        graph.agents = [AgentProfile(agent_id=aid, display_name=aid) for aid in ["a", "b"]]
        return GraphVisualizer(graph)

    def test_empty_graph(self):
        result = GraphVisualizer(RoleGraph()).to_adjacency_matrix()
        # NOTE(review): per an earlier note, the source's empty-matrix check
        # (`a_com.size == 0`) compares a bound method rather than the element
        # count, so an empty graph yields an empty header string instead of the
        # "Empty adjacency matrix" message. Only the return type is pinned here
        # so the test keeps passing once that check is fixed.
        assert isinstance(result, str)

    def test_basic_matrix(self):
        result = self._two_node_visualizer(0.8).to_adjacency_matrix()
        assert "0.80" in result

    def test_without_labels(self):
        viz = self._two_node_visualizer(0.5)
        unlabeled = viz.to_adjacency_matrix(show_labels=False)
        assert "0.50" in unlabeled
        # Disabling labels must actually change the rendering (headers dropped).
        assert unlabeled != viz.to_adjacency_matrix(show_labels=True)
class TestSafeId:
    """Identifier sanitisation performed by GraphVisualizer._safe_id."""

    @staticmethod
    def _safe(raw):
        """Run _safe_id on *raw* using a visualizer over an empty graph."""
        return GraphVisualizer(RoleGraph())._safe_id(raw)

    def test_simple_id(self):
        assert self._safe("agent") == "agent"

    def test_id_with_hyphens(self):
        assert self._safe("agent-name") == "agent_name"

    def test_id_with_spaces(self):
        assert self._safe("my agent") == "my_agent"

    def test_id_starting_with_digit(self):
        # Ids that start with a digit get an "n_" prefix.
        assert self._safe("123agent").startswith("n_")

    def test_id_with_dots(self):
        assert "." not in self._safe("agent.name")

    def test_empty_id(self):
        assert self._safe("") == "unknown"
class TestFormatNodeLabel:
    """Node label text built by GraphVisualizer._format_node_label."""

    @staticmethod
    def _label(agent, style=None):
        """Format *agent* with the agent node style of a visualizer over an empty graph."""
        if style is None:
            visualizer = GraphVisualizer(RoleGraph())
        else:
            visualizer = GraphVisualizer(RoleGraph(), style=style)
        return visualizer._format_node_label(agent, visualizer.style.agent_style)

    def test_basic_label(self):
        from core.agent import AgentProfile

        agent = AgentProfile(agent_id="a", display_name="My Agent")
        assert "My Agent" in self._label(agent)

    def test_long_name_truncated(self):
        from core.agent import AgentProfile

        agent = AgentProfile(agent_id="a", display_name="A" * 50)
        label = self._label(agent, VisualizationStyle(max_label_length=20))
        assert "..." in label

    def test_with_tools(self):
        from core.agent import AgentProfile

        agent = AgentProfile(agent_id="a", display_name="Agent", tools=["tool1", "tool2", "tool3"])
        label = self._label(agent, VisualizationStyle(show_tools=True))
        assert "tool1" in label

    def test_with_many_tools_shows_ellipsis(self):
        from core.agent import AgentProfile

        agent = AgentProfile(agent_id="a", display_name="Agent", tools=["t1", "t2", "t3", "t4"])
        label = self._label(agent, VisualizationStyle(show_tools=True))
        # Four tools exceed what the label lists, so it ends in an ellipsis.
        assert "..." in label
class TestSaveMermaid:
    """Saving Mermaid (.mmd / .md) and DOT diagrams to disk."""

    def test_save_mermaid_file(self, tmp_path):
        target = tmp_path / "test.mmd"
        GraphVisualizer(_make_simple_graph()).save_mermaid(target)
        assert "flowchart" in target.read_text()

    def test_save_mermaid_md_file(self, tmp_path):
        target = tmp_path / "test.md"
        GraphVisualizer(_make_simple_graph()).save_mermaid(target, title="Test Graph")
        content = target.read_text()
        # A .md target wraps the diagram in a fenced ```mermaid block; both the
        # opening and the closing fence must be present.
        assert "```mermaid" in content
        assert content.count("```") >= 2
        # The requested title must make it into the written file.
        assert "Test Graph" in content

    def test_save_dot_file(self, tmp_path):
        target = tmp_path / "test.dot"
        GraphVisualizer(_make_simple_graph()).save_dot(target)
        assert "digraph" in target.read_text()
class TestRenderImageMissingGraphviz:
    """render_image / show_interactive raise ImportError when graphviz is absent.

    Named distinctly so that it is not shadowed by the later ``TestRenderImage``
    class; a duplicate class name would prevent pytest from collecting it.
    """

    def test_render_raises_import_error_without_graphviz(self, tmp_path):
        import sys
        import unittest.mock as mock

        viz = GraphVisualizer(_make_simple_graph())
        # A None entry in sys.modules makes ``import graphviz`` raise ImportError;
        # patch.dict restores the previous state afterwards.
        with mock.patch.dict(sys.modules, {"graphviz": None}):
            with pytest.raises(ImportError):
                viz.render_image(tmp_path / "output.png")

    def test_show_interactive_raises_without_graphviz(self):
        import sys
        import unittest.mock as mock

        viz = GraphVisualizer(_make_simple_graph())
        with mock.patch.dict(sys.modules, {"graphviz": None}):
            with pytest.raises(ImportError):
                viz.show_interactive()
class TestPrintColoredSmoke:
    """Smoke test of print_colored against the real environment.

    Named distinctly so that it is not shadowed by the later ``TestPrintColored``
    class; a duplicate class name would prevent pytest from collecting it.
    """

    def test_print_colored_no_error(self):
        """print_colored must not raise, whether or not rich is installed.

        Without rich, print_colored is expected to fall back to plain output
        rather than raise, so any exception here is a genuine failure. It is
        therefore not swallowed.
        """
        viz = GraphVisualizer(_make_simple_graph())
        viz.print_colored()
# ============================================================================
# Tests for module-level convenience functions
# ============================================================================
class TestConvenienceFunctions:
    """Module-level wrappers: to_mermaid, to_ascii, to_dot and print_graph."""

    def test_to_mermaid_function(self):
        assert "flowchart" in to_mermaid(_make_simple_graph())

    def test_to_mermaid_with_direction(self):
        text = to_mermaid(_make_simple_graph(), direction=MermaidDirection.LEFT_RIGHT)
        assert "LR" in text

    def test_to_mermaid_with_title(self):
        text = to_mermaid(_make_simple_graph(), title="Test")
        assert "title: Test" in text

    def test_to_mermaid_with_custom_style(self):
        role_graph = _make_simple_graph()
        weighted = VisualizationStyle(show_weights=True)
        assert "flowchart" in to_mermaid(role_graph, style=weighted)

    def test_to_ascii_function(self):
        assert "╔" in to_ascii(_make_simple_graph())

    def test_to_ascii_no_edges(self):
        assert "Edges:" not in to_ascii(_make_simple_graph(), show_edges=False)

    def test_to_dot_function(self):
        assert "digraph" in to_dot(_make_simple_graph())

    def test_to_dot_custom_name(self):
        assert "Custom" in to_dot(_make_simple_graph(), graph_name="Custom")

    def test_print_graph_auto_format(self):
        """print_graph(output_format="auto") runs without error when rich is (mock-)importable."""
        import sys
        from unittest.mock import MagicMock

        role_graph = _make_simple_graph()

        branch = MagicMock()
        branch.add.return_value = MagicMock()
        tree = MagicMock()
        tree.add.return_value = branch

        console_mod = MagicMock()
        console_mod.Console.return_value = MagicMock()
        table_mod = MagicMock()
        table_mod.Table.return_value = MagicMock()
        tree_mod = MagicMock()
        tree_mod.Tree.return_value = tree

        fake_modules = {
            "rich.console": console_mod,
            "rich.table": table_mod,
            "rich.tree": tree_mod,
            "rich": MagicMock(),
        }
        with pytest.MonkeyPatch.context() as mp:
            for module_name, module in fake_modules.items():
                mp.setitem(sys.modules, module_name, module)
            # With rich importable, "auto" is expected to render via print_colored.
            print_graph(role_graph, output_format="auto")

    def test_print_graph_colored_format(self):
        """print_graph(output_format="colored") runs without error."""
        print_graph(_make_simple_graph(), output_format="colored")

    def test_print_graph_ascii_format(self):
        """print_graph(output_format="ascii") runs without error."""
        print_graph(_make_simple_graph(), output_format="ascii")

    def test_print_graph_mermaid_format(self):
        """print_graph(output_format="mermaid") runs without error."""
        print_graph(_make_simple_graph(), output_format="mermaid")
class TestEdgeShortNameTruncation:
    """Shortening of long source/target names in the ASCII edge listing."""

    def test_long_source_name_truncated_in_ascii(self):
        from core.agent import AgentProfile

        long_id = "very_long_source_name_here"
        source_agent = AgentProfile(agent_id=long_id, display_name="Very Long Source")
        target_agent = AgentProfile(agent_id="b", display_name="B")

        digraph = rx.PyDiGraph()
        src_idx = digraph.add_node({"id": long_id})
        dst_idx = digraph.add_node({"id": "b"})
        digraph.add_edge(src_idx, dst_idx, {"weight": 1.0})

        role_graph = RoleGraph(
            node_ids=[long_id, "b"],
            role_connections={long_id: ["b"], "b": []},
            graph=digraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = [source_agent, target_agent]

        rendered = GraphVisualizer(role_graph).to_ascii(show_edges=True)
        # The long source id is abbreviated in the "Edges:" section.
        assert ".." in rendered
class TestTaskNodeVisualization:
    """Rendering of task nodes and task-typed edges."""

    def test_task_node_in_mermaid(self):
        from core.agent import AgentProfile

        class TaskAgent(AgentProfile):
            type: str = "task"

        task = TaskAgent(agent_id="task1", display_name="My Task")
        worker = AgentProfile(agent_id="agent1", display_name="My Agent")

        dag = rx.PyDiGraph()
        task_idx = dag.add_node({"id": "task1"})
        worker_idx = dag.add_node({"id": "agent1"})
        dag.add_edge(task_idx, worker_idx, {"type": "task", "weight": 1.0})

        role_graph = RoleGraph(
            node_ids=["task1", "agent1"],
            role_connections={"task1": ["agent1"], "agent1": []},
            graph=dag,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = [task, worker]

        diagram = GraphVisualizer(role_graph).to_mermaid()
        # NOTE(review): only the presence of the task node id is checked, not
        # the diamond-style shape a task node is intended to use.
        assert "task1" in diagram

    def test_task_edge_in_dot(self):
        from core.agent import AgentProfile

        first = AgentProfile(agent_id="a", display_name="A")
        second = AgentProfile(agent_id="b", display_name="B")

        dag = rx.PyDiGraph()
        first_idx = dag.add_node({"id": "a"})
        second_idx = dag.add_node({"id": "b"})
        dag.add_edge(first_idx, second_idx, {"type": "task_edge", "weight": 1.0})

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            role_connections={"a": ["b"], "b": []},
            graph=dag,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = [first, second]

        # A "task_edge"-typed edge is drawn dashed in DOT output.
        assert "dashed" in GraphVisualizer(role_graph).to_dot()
class TestRenderImage:
    """render_image and show_interactive with the graphviz package mocked or hidden."""

    @staticmethod
    def _fake_graphviz(render_error=None):
        """Return ``(module, source)`` mocks standing in for the graphviz package.

        ``module.Source(...)`` returns ``source``. When *render_error* is given,
        ``source.render`` raises it.
        """
        from unittest.mock import MagicMock

        source = MagicMock()
        if render_error is not None:
            source.render.side_effect = render_error
        module = MagicMock()
        module.Source.return_value = source
        return module, source

    def test_render_image_no_graphviz_raises(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        # A None entry in sys.modules makes ``import graphviz`` raise ImportError.
        # patch.dict restores the prior state, whether the key existed or not.
        with patch.dict("sys.modules", {"graphviz": None}):
            with pytest.raises(ImportError, match="[Gg]raphviz"):
                viz.render_image(tmp_path / "test.png")

    def test_render_image_with_mock_graphviz(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            viz.render_image(tmp_path / "test.png")
        source.render.assert_called_once()

    def test_render_image_with_explicit_format(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            viz.render_image(tmp_path / "test", image_format=ImageFormat.SVG)
        source.render.assert_called_once()

    def test_render_image_with_dpi(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            viz.render_image(tmp_path / "test.png", dpi=300)
        source.render.assert_called_once()

    def test_render_image_svg_ignores_dpi(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            viz.render_image(tmp_path / "test.svg", dpi=300)
        # NOTE(review): dpi is meant to be ignored for SVG, but only a single
        # render call is verified here.
        source.render.assert_called_once()

    def test_render_image_raises_on_render_error(self, tmp_path):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, _ = self._fake_graphviz(render_error=Exception("render failed"))
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            # Backend failures surface as RuntimeError carrying the original message.
            with pytest.raises(RuntimeError, match="render failed"):
                viz.render_image(tmp_path / "test.png")

    def test_show_interactive_no_graphviz_raises(self):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        with patch.dict("sys.modules", {"graphviz": None}):
            with pytest.raises(ImportError, match="[Gg]raphviz"):
                viz.show_interactive()

    def test_show_interactive_with_mock_graphviz(self):
        from unittest.mock import patch

        viz = GraphVisualizer(_make_simple_graph())
        graphviz_mod, _ = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            viz.show_interactive()  # Should not raise
class TestPrintColored:
    """print_colored with ``rich`` replaced by mocks or made unimportable."""

    @staticmethod
    def _install_fake_rich(monkeypatch):
        """Register mock ``rich`` modules in ``sys.modules`` via *monkeypatch*.

        Returns the mock console instance that ``rich.console.Console()`` yields,
        so tests can assert on its ``print`` calls. The entries are undone when
        the monkeypatch is torn down.
        """
        import sys
        from unittest.mock import MagicMock

        console = MagicMock()
        tree = MagicMock()
        branch = MagicMock()
        node = MagicMock()
        tree.add.return_value = branch
        branch.add.return_value = node
        node.add.return_value = MagicMock()

        console_mod = MagicMock()
        console_mod.Console.return_value = console
        table_mod = MagicMock()
        table_mod.Table.return_value = MagicMock()
        tree_mod = MagicMock()
        tree_mod.Tree.return_value = tree

        monkeypatch.setitem(sys.modules, "rich.console", console_mod)
        monkeypatch.setitem(sys.modules, "rich.table", table_mod)
        monkeypatch.setitem(sys.modules, "rich.tree", tree_mod)
        monkeypatch.setitem(sys.modules, "rich", MagicMock())
        return console

    def test_print_colored_with_mocked_rich(self, monkeypatch):
        """print_colored renders through the (mocked) rich console."""
        viz = GraphVisualizer(_make_simple_graph())
        console = self._install_fake_rich(monkeypatch)
        viz.print_colored()
        console.print.assert_called()

    def test_print_colored_with_graph_with_tools_and_description(self, monkeypatch):
        """An agent with a long description and tools is rendered without error."""
        from core.agent import AgentProfile

        agent = AgentProfile(
            agent_id="a",
            display_name="Agent A",
            description="A helpful agent " * 10,  # long description
            tools=["tool1", "tool2"],
        )
        single = rx.PyDiGraph()
        single.add_node({"id": "a"})
        graph = RoleGraph(node_ids=["a"], graph=single, A_com=torch.zeros((1, 1)))
        graph.agents = [agent]
        viz = GraphVisualizer(graph)

        console = self._install_fake_rich(monkeypatch)
        viz.print_colored()
        console.print.assert_called()

    def test_print_colored_with_many_edges_shows_table(self, monkeypatch):
        """A 15-node chain prints both the tree and the edge table."""
        from core.agent import AgentProfile

        n = 15
        ids = [f"n{i}" for i in range(n)]
        agents = [AgentProfile(agent_id=nid, display_name=nid) for nid in ids]

        g = rx.PyDiGraph()
        idx_map = {nid: g.add_node({"id": nid}) for nid in ids}
        A = torch.zeros((n, n))
        for i, (src, dst) in enumerate(zip(ids, ids[1:])):
            g.add_edge(idx_map[src], idx_map[dst], {"weight": 1.0, "source": src, "target": dst})
            A[i, i + 1] = 1.0

        graph = RoleGraph(node_ids=ids, graph=g, A_com=A)
        graph.agents = agents
        viz = GraphVisualizer(graph)

        console = self._install_fake_rich(monkeypatch)
        viz.print_colored()
        # At least the tree and the edge table are printed.
        assert console.print.call_count >= 2

    def test_print_colored_without_rich(self, monkeypatch):
        """With rich unimportable, print_colored falls back without raising."""
        import sys

        viz = GraphVisualizer(_make_simple_graph())
        # None entries make ``import rich...`` raise ImportError. monkeypatch
        # restores every entry exactly as it was: it re-inserts the original
        # module objects, or removes keys that were absent before.
        for module_name in ("rich", "rich.console", "rich.table", "rich.tree"):
            monkeypatch.setitem(sys.modules, module_name, None)
        viz.print_colored()  # Should not raise (fallback)
class TestSafeidDoubleUnderscore:
    """_safe_id collapsing and trimming of underscores."""

    @staticmethod
    def _safe(raw):
        """Run _safe_id on *raw* using a visualizer over the simple two-node graph."""
        return GraphVisualizer(_make_simple_graph())._safe_id(raw)

    def test_double_underscore_removed(self):
        assert "__" not in self._safe("agent__name__here")

    def test_leading_trailing_underscores_removed(self):
        cleaned = self._safe("_agent_name_")
        assert not cleaned.startswith("_")
        assert not cleaned.endswith("_")

    def test_empty_string_returns_unknown(self):
        assert self._safe("") == "unknown"

    def test_only_special_chars(self):
        # Nothing usable remains after sanitising, so the fallback id is returned.
        assert self._safe("---") == "unknown"
class TestToMermaidEdgeDedup:
    """Parallel edges between the same pair of nodes appear once in Mermaid output."""

    def test_duplicate_edges_deduped(self):
        from core.agent import AgentProfile

        first = AgentProfile(agent_id="a", display_name="A")
        second = AgentProfile(agent_id="b", display_name="B")

        multigraph = rx.PyDiGraph()
        a_idx = multigraph.add_node({"id": "a"})
        b_idx = multigraph.add_node({"id": "b"})
        for weight in (1.0, 0.5):  # two parallel a -> b edges
            multigraph.add_edge(a_idx, b_idx, {"weight": weight})

        role_graph = RoleGraph(
            node_ids=["a", "b"],
            role_connections={"a": ["b"], "b": []},
            graph=multigraph,
            A_com=torch.tensor([[0.0, 1.0], [0.0, 0.0]]),
        )
        role_graph.agents = [first, second]

        diagram = GraphVisualizer(role_graph).to_mermaid()
        occurrences = sum(diagram.count(arrow) for arrow in ("a --> b", "a->b"))
        assert occurrences <= 1
class TestToMermaidEdgeEmptySourceTarget:
    """Edges whose payload carries empty source/target names are tolerated."""

    def test_empty_source_edge_skipped(self):
        from core.agent import AgentProfile

        only = AgentProfile(agent_id="a", display_name="A")

        loop_graph = rx.PyDiGraph()
        a_idx = loop_graph.add_node({"id": "a"})
        # Self-loop whose data has blank source/target fields.
        loop_graph.add_edge(a_idx, a_idx, {"source": "", "target": ""})

        role_graph = RoleGraph(
            node_ids=["a"],
            role_connections={"a": []},
            graph=loop_graph,
            A_com=torch.zeros((1, 1)),
        )
        role_graph.agents = [only]

        # Rendering must not raise.
        assert "flowchart" in GraphVisualizer(role_graph).to_mermaid()
class TestRenderToImageConvenience:
    """Module-level render_to_image / show_graph_interactive with graphviz mocked."""

    @staticmethod
    def _fake_graphviz():
        """Return ``(module, source)`` mocks where ``module.Source()`` yields ``source``."""
        from unittest.mock import MagicMock

        source = MagicMock()
        module = MagicMock()
        module.Source.return_value = source
        return module, source

    def test_render_to_image_with_mock(self, tmp_path):
        from unittest.mock import patch

        from core.visualization import render_to_image

        role_graph = _make_simple_graph()
        graphviz_mod, source = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            render_to_image(role_graph, tmp_path / "test.png")
        source.render.assert_called_once()

    def test_show_graph_interactive_with_mock(self):
        from unittest.mock import patch

        from core.visualization import show_graph_interactive

        role_graph = _make_simple_graph()
        graphviz_mod, _ = self._fake_graphviz()
        with patch.dict("sys.modules", {"graphviz": graphviz_mod}):
            show_graph_interactive(role_graph)  # Should not raise
if __name__ == "__main__":
    # Propagate pytest's exit status so that failures are visible to the shell
    # or CI when this file is run directly.
    raise SystemExit(pytest.main([__file__, "-v"]))