gMAS / src /core /graph.py
Артём Боярских
chore: initial commit
3193174
"""RoleGraph on rustworkx with dynamic topology support."""
from collections import deque
from collections.abc import Iterable, Mapping, Sequence
from enum import Enum
from typing import Any, Protocol, runtime_checkable
import rustworkx as rx
import torch
from pydantic import BaseModel, ConfigDict, Field
# Constants for magic values
EDGE_THRESHOLD = 0.5
__all__ = [
"GraphIntegrityError",
"RoleGraph",
"StateMigrationPolicy",
"StateStorage",
]
class StateMigrationPolicy(str, Enum):
DISCARD = "discard"
COPY = "copy"
ARCHIVE = "archive"
@runtime_checkable
class StateStorage(Protocol):
def save(self, node_id: str, state: dict[str, Any]) -> None: ...
def load(self, node_id: str) -> dict[str, Any] | None: ...
def delete(self, node_id: str) -> None: ...
class GraphIntegrityError(Exception):
pass
def _get_agent_id(agent: Any) -> str | None:
"""Safely get agent_id from an agent (object or dict)."""
if hasattr(agent, "agent_id"):
return agent.agent_id
if isinstance(agent, dict):
return agent.get("id") or agent.get("agent_id")
return None
class RoleGraph(BaseModel):
"""
Role graph on rustworkx with adjacency matrices and auxiliary data.
Supports conditional routing via edge_conditions.
Supports explicit start_node and end_node for execution optimisation.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
agents: list[Any] = Field(default_factory=list)
node_ids: list[str] = Field(default_factory=list)
role_connections: dict[str, list[str]] = Field(default_factory=dict)
task_node: str | None = None
query: str | None = None
answer: str | None = None
graph: rx.PyDiGraph = Field(default_factory=rx.PyDiGraph)
A_com: torch.Tensor = Field(default_factory=lambda: torch.zeros((0, 0), dtype=torch.float32))
S_tilde: torch.Tensor | None = Field(default=None)
p_matrix: torch.Tensor | None = Field(default=None)
state_storage: Any | None = Field(default=None, exclude=True)
# Explicit start/end nodes for execution path optimisation
start_node: str | None = Field(default=None)
end_node: str | None = Field(default=None)
# Inactive nodes — present in the graph but not executed
# Saves tokens without removing nodes from the structure
disabled_nodes: set[str] = Field(default_factory=set)
# Routing conditions: {(source, target): condition}
# Callable conditions (not serialized)
edge_conditions: dict[tuple[str, str], Any] = Field(default_factory=dict, exclude=True)
# String conditions from the schema
edge_condition_names: dict[tuple[str, str], str] = Field(default_factory=dict)
@property
def role_sequence(self) -> list[str]:
"""Order of roles (agent identifiers)."""
result = []
for a in self.agents:
if hasattr(a, "agent_id"):
result.append(a.agent_id)
elif isinstance(a, dict):
result.append(a.get("id", a.get("agent_id", str(a))))
else:
result.append(str(a))
return result
@property
def embeddings(self) -> torch.Tensor:
"""Stack of agent embeddings or an empty tensor."""
embs = []
for a in self.agents:
emb = getattr(a, "embedding", None) if hasattr(a, "embedding") else None
if emb is not None:
embs.append(emb)
return torch.stack(embs) if embs else torch.zeros((0, 0), dtype=torch.float32)
@property
def num_nodes(self) -> int:
"""Number of nodes in the graph."""
return self.graph.num_nodes()
@property
def num_edges(self) -> int:
"""Number of edges in the graph."""
return self.graph.num_edges()
@property
def edges(self) -> list[dict[str, Any]]:
"""List of edges with data (source, target, attr, weight...)."""
result = []
for i in self.graph.edge_indices():
s, t = self.graph.get_edge_endpoints_by_index(i)
d = self.graph.get_edge_data_by_index(i)
edge = {"source": self._nid(s), "target": self._nid(t)}
if isinstance(d, dict):
for k, v in d.items():
if isinstance(v, torch.Tensor):
edge[k] = v.tolist()
else:
edge[k] = v
result.append(edge)
return result
@property
def edge_index(self) -> torch.Tensor:
"""Edge index in PyG format (2 x E)."""
if not self.graph.num_edges():
return torch.zeros((2, 0), dtype=torch.long)
src, tgt = [], []
for i in self.graph.edge_indices():
s, t = self.graph.get_edge_endpoints_by_index(i)
src.append(s)
tgt.append(t)
return torch.tensor([src, tgt], dtype=torch.long)
@property
def edge_attr(self) -> torch.Tensor:
"""Edge feature matrix (default: weight + attr fields)."""
if not self.graph.num_edges():
return torch.zeros((0, 4), dtype=torch.float32)
attrs = []
default_attr = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
for i in self.graph.edge_indices():
d = self.graph.get_edge_data_by_index(i)
attr = d.get("attr", default_attr) if isinstance(d, dict) else default_attr
if isinstance(attr, torch.Tensor):
attrs.append(attr)
else:
attrs.append(torch.tensor(attr, dtype=torch.float32))
return torch.vstack(attrs).to(torch.float32)
@property
def has_conditional_edges(self) -> bool:
"""Whether the graph has conditional edges."""
return bool(self.edge_conditions) or bool(self.edge_condition_names)
@property
def conditional_edges(self) -> list[tuple[str, str]]:
"""List of conditional edges (source, target)."""
edges = set(self.edge_conditions.keys())
edges.update(self.edge_condition_names.keys())
return list(edges)
def get_edge_condition(self, source: str, target: str) -> Any | str | None:
"""
Get the condition for an edge (callable or string).
Returns the callable if present, otherwise the string condition, otherwise None.
"""
# First check callable
if (source, target) in self.edge_conditions:
return self.edge_conditions[(source, target)]
# Then string
if (source, target) in self.edge_condition_names:
return self.edge_condition_names[(source, target)]
return None
def get_all_edge_conditions(self) -> dict[tuple[str, str], Any]:
"""Get all edge conditions (union of callable and string conditions)."""
result: dict[tuple[str, str], Any] = {}
# First string conditions
result.update(self.edge_condition_names)
# Then callable (overwrite string ones if present)
result.update(self.edge_conditions)
return result
def set_edge_condition(
self,
source: str,
target: str,
condition: Any,
) -> bool:
"""
Set the condition for an edge.
Args:
source: Source ID.
target: Target ID.
condition: Callable or string condition.
Returns:
True if the edge exists and the condition was set.
"""
# Check that the edge exists
src_idx = self.get_node_index(source)
tgt_idx = self.get_node_index(target)
if src_idx is None or tgt_idx is None:
return False
if callable(condition):
self.edge_conditions[(source, target)] = condition
elif isinstance(condition, str):
self.edge_condition_names[(source, target)] = condition
return True
def remove_edge_condition(self, source: str, target: str) -> bool:
"""Remove the condition from an edge."""
removed = False
if (source, target) in self.edge_conditions:
del self.edge_conditions[(source, target)]
removed = True
if (source, target) in self.edge_condition_names:
del self.edge_condition_names[(source, target)]
removed = True
return removed
def _nid(self, idx: int) -> str:
"""Return the node identifier by rustworkx index."""
d = self.graph.get_node_data(idx)
return d.get("id", str(idx)) if isinstance(d, dict) else str(idx)
def get_node_index(self, node_id: str) -> int | None:
"""Find the rustworkx index of a node by its ID."""
for i in self.graph.node_indices():
d = self.graph.get_node_data(i)
if isinstance(d, dict) and d.get("id") == node_id:
return i
return None
def get_agent_by_id(self, agent_id: str) -> Any | None:
"""Return the agent object by its identifier."""
for agent in self.agents:
aid = getattr(agent, "agent_id", None)
if aid is None and isinstance(agent, dict):
aid = agent.get("id", agent.get("agent_id"))
if aid == agent_id:
return agent
return None
def add_node(
self,
agent: Any,
connections_from: Sequence[str] | None = None,
connections_to: Sequence[str] | None = None,
) -> bool:
"""Add a node/agent and optionally connect it to neighbours."""
node_id = getattr(agent, "agent_id", None)
if node_id is None and isinstance(agent, dict):
node_id = agent.get("id", agent.get("agent_id"))
if node_id is None or node_id in self.node_ids:
return False
node_type = "task" if getattr(agent, "type", None) == "task" else "agent"
self.graph.add_node({"id": node_id, "type": node_type})
self.agents.append(agent)
self.node_ids.append(node_id)
self._expand_adjacency(1)
self.role_connections[node_id] = []
for src_id in connections_from or []:
if src_id in self.node_ids:
self.add_edge(src_id, node_id)
for tgt_id in connections_to or []:
if tgt_id in self.node_ids:
self.add_edge(node_id, tgt_id)
return True
def remove_node(
self,
node_id: str,
policy: StateMigrationPolicy = StateMigrationPolicy.DISCARD,
) -> Any | None:
"""Remove a node, with optional state migration/archiving."""
if node_id not in self.node_ids:
return None
agent_idx = self.node_ids.index(node_id)
agent = self.agents[agent_idx]
rx_idx = self.get_node_index(node_id)
if policy == StateMigrationPolicy.ARCHIVE:
self._archive_state(agent)
if rx_idx is not None:
self.graph.remove_node(rx_idx)
self.agents.pop(agent_idx)
self.node_ids.pop(agent_idx)
self._shrink_adjacency(agent_idx)
self.role_connections.pop(node_id, None)
for conns in self.role_connections.values():
if node_id in conns:
conns.remove(node_id)
if self.task_node == node_id:
object.__setattr__(self, "task_node", None)
return agent
def replace_node(
self,
node_id: str,
new_agent: Any,
policy: StateMigrationPolicy = StateMigrationPolicy.COPY,
) -> Any | None:
"""Replace a node with a new agent using the selected state migration policy."""
if node_id not in self.node_ids:
return None
agent_idx = self.node_ids.index(node_id)
old_agent = self.agents[agent_idx]
rx_idx = self.get_node_index(node_id)
if policy == StateMigrationPolicy.COPY:
new_agent = self._copy_state(old_agent, new_agent)
elif policy == StateMigrationPolicy.ARCHIVE:
self._archive_state(old_agent)
new_id = _get_agent_id(new_agent)
if new_id is None:
new_id = str(id(new_agent))
node_type = "task" if getattr(new_agent, "type", None) == "task" else "agent"
if rx_idx is not None:
self.graph[rx_idx] = {"id": new_id, "type": node_type}
self.agents[agent_idx] = new_agent
self.node_ids[agent_idx] = new_id
if node_id != new_id:
if node_id in self.role_connections:
self.role_connections[new_id] = self.role_connections.pop(node_id)
for conns in self.role_connections.values():
for i, c in enumerate(conns):
if c == node_id:
conns[i] = new_id
if self.task_node == node_id:
object.__setattr__(self, "task_node", new_id)
return old_agent
def _copy_state(self, old_agent: Any, new_agent: Any) -> Any:
"""Copy state/hidden_state/embedding from the old agent to the new one."""
if hasattr(old_agent, "state") and hasattr(new_agent, "with_state"):
new_agent = new_agent.with_state(list(old_agent.state))
if (
hasattr(old_agent, "hidden_state")
and old_agent.hidden_state is not None
and hasattr(new_agent, "with_hidden_state")
):
new_agent = new_agent.with_hidden_state(old_agent.hidden_state)
if hasattr(old_agent, "embedding") and old_agent.embedding is not None and hasattr(new_agent, "with_embedding"):
new_agent = new_agent.with_embedding(old_agent.embedding)
return new_agent
def _archive_state(self, agent: Any) -> None:
"""Save the agent state to external storage if it is configured."""
if self.state_storage is None:
return
state_data = {
"state": list(getattr(agent, "state", [])),
"hidden_state": (
agent.hidden_state.cpu().tolist()
if hasattr(agent, "hidden_state") and agent.hidden_state is not None
else None
),
"embedding": (
agent.embedding.cpu().tolist() if hasattr(agent, "embedding") and agent.embedding is not None else None
),
}
self.state_storage.save(_get_agent_id(agent) or "", state_data)
def _expand_adjacency(self, count: int = 1) -> None:
"""Expand the adjacency/probability matrices when adding nodes."""
n = self.A_com.shape[0] if self.A_com.numel() > 0 else 0
new_n = n + count
new_a = torch.zeros((new_n, new_n), dtype=torch.float32)
if n > 0:
new_a[:n, :n] = self.A_com
object.__setattr__(self, "A_com", new_a)
if self.S_tilde is not None:
new_s = torch.zeros((new_n, new_n), dtype=torch.float32)
new_s[:n, :n] = self.S_tilde
object.__setattr__(self, "S_tilde", new_s)
if self.p_matrix is not None:
new_p = torch.zeros((new_n, new_n), dtype=torch.float32)
new_p[:n, :n] = self.p_matrix
object.__setattr__(self, "p_matrix", new_p)
def _shrink_adjacency(self, idx: int) -> None:
"""Remove a row/column from the matrices when removing a node."""
if self.A_com.numel() == 0:
return
mask = torch.ones(self.A_com.shape[0], dtype=torch.bool)
mask[idx] = False
object.__setattr__(self, "A_com", self.A_com[mask][:, mask])
if self.S_tilde is not None:
object.__setattr__(self, "S_tilde", self.S_tilde[mask][:, mask])
if self.p_matrix is not None:
object.__setattr__(self, "p_matrix", self.p_matrix[mask][:, mask])
def add_edge(
self,
source_id: str,
target_id: str,
weight: float = 1.0,
**edge_attrs,
) -> bool:
"""Add a directed edge and update the adjacency matrix."""
src_idx = self.get_node_index(source_id)
tgt_idx = self.get_node_index(target_id)
if src_idx is None or tgt_idx is None:
return False
self.graph.add_edge(src_idx, tgt_idx, {"weight": weight, **edge_attrs})
src_list_idx = self.node_ids.index(source_id)
tgt_list_idx = self.node_ids.index(target_id)
if self.A_com.numel() > 0:
self.A_com[src_list_idx, tgt_list_idx] = weight
return True
def remove_edge(self, source_id: str, target_id: str) -> bool:
"""Remove an edge and zero out the weight in the matrix."""
src_idx = self.get_node_index(source_id)
tgt_idx = self.get_node_index(target_id)
if src_idx is None or tgt_idx is None:
return False
for eid in self.graph.edge_indices():
s, t = self.graph.get_edge_endpoints_by_index(eid)
if s == src_idx and t == tgt_idx:
self.graph.remove_edge_from_index(eid)
src_list_idx = self.node_ids.index(source_id)
tgt_list_idx = self.node_ids.index(target_id)
if self.A_com.numel() > 0:
self.A_com[src_list_idx, tgt_list_idx] = 0.0
return True
return False
def get_neighbors(self, node_id: str, direction: str = "out") -> list[str]:
"""Return neighbouring nodes (out/in/both)."""
idx = self.get_node_index(node_id)
if idx is None:
return []
neighbors = set()
for eid in self.graph.edge_indices():
s, t = self.graph.get_edge_endpoints_by_index(eid)
if direction in ("out", "both") and s == idx:
neighbors.add(self._nid(t))
if direction in ("in", "both") and t == idx:
neighbors.add(self._nid(s))
return list(neighbors)
def update_communication(
self,
a_com: torch.Tensor,
s_tilde: torch.Tensor | None = None,
p_matrix: torch.Tensor | None = None,
) -> None:
"""Fully replace the communication matrix and graph edges."""
a_tensor = a_com.detach().cpu() if a_com.requires_grad else a_com.cpu()
for eid in list(self.graph.edge_indices()):
self.graph.remove_edge_from_index(eid)
n_nodes = a_tensor.shape[0]
node_indices = list(self.graph.node_indices())
for i in range(n_nodes):
for j in range(n_nodes):
if a_tensor[i, j].item() > EDGE_THRESHOLD and i < len(node_indices) and j < len(node_indices):
edge_data = {"weight": float(a_tensor[i, j].item()), "from_update": True}
if s_tilde is not None:
s_tensor = s_tilde.detach().cpu() if s_tilde.requires_grad else s_tilde.cpu()
edge_data["score"] = float(s_tensor[i, j].item())
if p_matrix is not None:
p_tensor = p_matrix.detach().cpu() if p_matrix.requires_grad else p_matrix.cpu()
edge_data["p_ij"] = float(p_tensor[i, j].item())
self.graph.add_edge(node_indices[i], node_indices[j], edge_data)
object.__setattr__(self, "A_com", a_tensor.to(torch.float32))
if s_tilde is not None:
s_tensor = s_tilde.detach().cpu() if s_tilde.requires_grad else s_tilde.cpu()
object.__setattr__(self, "S_tilde", s_tensor.to(torch.float32))
if p_matrix is not None:
p_tensor = p_matrix.detach().cpu() if p_matrix.requires_grad else p_matrix.cpu()
object.__setattr__(self, "p_matrix", p_tensor.to(torch.float32))
def verify_integrity(self, raise_on_error: bool = True) -> list[str]:
"""Check consistency of the agent list, nodes, and matrices."""
errors: list[str] = []
n_agents = len(self.agents)
n_ids = len(self.node_ids)
n_rx = self.graph.num_nodes()
n_matrix = self.A_com.shape[0] if self.A_com.numel() > 0 else 0
if n_agents != n_ids:
errors.append(f"agents ({n_agents}) != node_ids ({n_ids})")
if n_agents != n_rx:
errors.append(f"agents ({n_agents}) != rustworkx nodes ({n_rx})")
if n_agents != n_matrix:
errors.append(f"agents ({n_agents}) != matrix size ({n_matrix})")
role_seq = set(self.role_sequence)
node_ids_set = set(self.node_ids)
if role_seq != node_ids_set:
diff = role_seq.symmetric_difference(node_ids_set)
errors.append(f"role_sequence != node_ids, diff: {diff}")
rx_ids = set()
for i in self.graph.node_indices():
data = self.graph.get_node_data(i)
if isinstance(data, dict) and "id" in data:
rx_ids.add(data["id"])
if rx_ids != node_ids_set:
diff = rx_ids.symmetric_difference(node_ids_set)
errors.append(f"rustworkx IDs != node_ids, diff: {diff}")
for src, targets in self.role_connections.items():
if src not in node_ids_set:
errors.append(f"connection source '{src}' not in nodes")
errors.extend(f"connection target '{t}' not in nodes" for t in targets if t not in node_ids_set)
if self.task_node is not None and self.task_node not in node_ids_set:
errors.append(f"task_node '{self.task_node}' not in nodes")
if errors and raise_on_error:
raise GraphIntegrityError("; ".join(errors))
return errors
def is_consistent(self) -> bool:
"""Quick size consistency check without a detailed report."""
n = len(self.agents)
return (
len(self.node_ids) == n
and self.graph.num_nodes() == n
and (self.A_com.shape[0] if self.A_com.numel() > 0 else 0) == n
)
def to_dict(self) -> dict[str, Any]:
"""Serialize the graph to a dict (for saving or debugging)."""
emb = self.embeddings
return {
"role_sequence": list(self.role_sequence),
"node_ids": list(self.node_ids),
"role_connections": {k: list(v) for k, v in self.role_connections.items()},
"task_node": self.task_node,
"query": self.query,
"answer": self.answer,
"agents": [
{
"agent_id": _get_agent_id(a),
"display_name": getattr(a, "display_name", None),
"persona": getattr(a, "persona", ""),
"description": getattr(a, "description", ""),
"llm_backbone": getattr(a, "llm_backbone", None),
"tools": list(getattr(a, "tools", [])),
"embedding": a.embedding.cpu().tolist() if a.embedding is not None else None,
"state": list(getattr(a, "state", [])),
}
for a in self.agents
],
"edges": self.edges,
"embeddings": emb.cpu().tolist() if emb.numel() > 0 else [],
"edge_index": self.edge_index.tolist() if self.edge_index.numel() > 0 else [[], []],
"edge_attr": self.edge_attr.tolist() if self.edge_attr.numel() > 0 else [],
"adjacency": self.A_com.tolist() if self.A_com.numel() > 0 else [],
"num_nodes": self.num_nodes,
"num_edges": self.num_edges,
}
@classmethod
def from_dict(
cls,
data: dict[str, Any],
agent_factory: Any = None,
verify: bool = True,
) -> "RoleGraph":
"""Create a RoleGraph from a dict with agents and edges."""
from core.agent import AgentProfile
factory = agent_factory or AgentProfile
agents = []
for a_data in data.get("agents", []):
emb = a_data.get("embedding")
embedding = torch.tensor(emb, dtype=torch.float32) if emb else None
aid = a_data.get("agent_id")
agent = factory(
agent_id=aid,
display_name=a_data.get("display_name", aid),
persona=a_data.get("persona", ""),
description=a_data.get("description", ""),
llm_backbone=a_data.get("llm_backbone"),
tools=a_data.get("tools", []),
state=a_data.get("state", []),
embedding=embedding,
)
agents.append(agent)
graph = rx.PyDiGraph()
idx_map = {}
for agent in agents:
aid = _get_agent_id(agent)
idx_map[aid] = graph.add_node(
{
"id": aid,
"type": "agent",
}
)
for edge in data.get("edges", []):
src_id = edge.get("source")
tgt_id = edge.get("target")
if src_id in idx_map and tgt_id in idx_map:
edge_data = {k: v for k, v in edge.items() if k not in ("source", "target")}
graph.add_edge(idx_map[src_id], idx_map[tgt_id], edge_data)
adj = data.get("adjacency", [])
a_com = (
torch.tensor(adj, dtype=torch.float32)
if adj
else torch.zeros((len(agents), len(agents)), dtype=torch.float32)
)
rg = cls(
agents=agents,
node_ids=data.get("node_ids", [_get_agent_id(a) for a in agents]),
role_connections=data.get("role_connections", {}),
task_node=data.get("task_node"),
query=data.get("query"),
answer=data.get("answer"),
graph=graph,
A_com=a_com,
)
if verify:
rg.verify_integrity()
return rg
@classmethod
def from_graph(
cls,
agents: Sequence[Any],
graph: rx.PyDiGraph,
a_com: torch.Tensor,
connections: Mapping[str, Iterable[str]],
task_node: str | None = None,
query: str | None = None,
answer: str | None = None,
verify: bool = True,
) -> "RoleGraph":
"""Create a RoleGraph from an existing PyDiGraph and adjacency matrix."""
agents_list = list(agents)
a_tensor = a_com if isinstance(a_com, torch.Tensor) else torch.tensor(a_com, dtype=torch.float32)
node_ids_raw = [_get_agent_id(a) for a in agents_list]
node_ids_filtered = [nid for nid in node_ids_raw if nid is not None]
rg = cls(
agents=agents_list,
node_ids=node_ids_filtered,
role_connections={k: list(v) for k, v in connections.items()},
task_node=task_node,
query=query,
answer=answer,
graph=graph,
A_com=a_tensor.to(torch.float32),
)
if verify:
rg.verify_integrity()
return rg
def to_pyg_data(
self,
node_features: dict[str, torch.Tensor] | None = None,
edge_features: dict[str, torch.Tensor] | None = None,
include_embeddings: bool = True,
include_default_edge_attr: bool = True,
) -> Any:
"""Convert the graph to torch_geometric.data.Data with features."""
from torch_geometric.data import Data
n = len(self.role_sequence)
num_edges = self.num_edges
x_parts = []
if include_embeddings:
emb = self.embeddings
if emb.numel() > 0:
x_parts.append(emb)
if node_features:
for node_feat in node_features.values():
if node_feat.shape[0] == n:
feat_to_add = node_feat.unsqueeze(1) if node_feat.dim() == 1 else node_feat
x_parts.append(feat_to_add)
x = torch.cat(x_parts, dim=1) if x_parts else torch.zeros((n, 0), dtype=torch.float32)
ei = self.edge_index if self.edge_index.numel() > 0 else torch.zeros((2, 0), dtype=torch.long)
ea_parts = []
if include_default_edge_attr:
default_ea = self.edge_attr if self.edge_attr.numel() > 0 else None
if default_ea is not None and default_ea.numel() > 0:
ea_parts.append(default_ea)
if edge_features:
for edge_feat in edge_features.values():
if edge_feat.shape[0] == num_edges:
feat_to_add = edge_feat.unsqueeze(1) if edge_feat.dim() == 1 else edge_feat
ea_parts.append(feat_to_add)
ea = torch.cat(ea_parts, dim=1) if ea_parts else torch.zeros((ei.shape[1], 0), dtype=torch.float32)
data = Data(x=x, edge_index=ei, edge_attr=ea, num_nodes=n)
data.node_ids = self.node_ids
data.role_sequence = self.role_sequence
if self.p_matrix is not None:
data.p_matrix = self.p_matrix.clone()
return data
def get_edge_features_from_schema(self) -> dict[str, torch.Tensor]:
"""Extract edge feature tensors from the saved schema."""
features = {
"weight": [],
"probability": [],
"trust": [],
}
for eid in self.graph.edge_indices():
data = self.graph.get_edge_data_by_index(eid)
if isinstance(data, dict):
features["weight"].append(data.get("weight", 1.0))
features["probability"].append(data.get("probability", 1.0))
schema = data.get("schema", {})
cost = schema.get("cost", {})
features["trust"].append(cost.get("trust", 1.0))
else:
features["weight"].append(1.0)
features["probability"].append(1.0)
features["trust"].append(1.0)
return {name: torch.tensor(values, dtype=torch.float32) for name, values in features.items()}
def get_node_features_from_schema(self) -> dict[str, torch.Tensor]:
"""Extract node feature tensors from the rustworkx data schema."""
features = {
"trust_score": [],
"quality_score": [],
}
for node_id in self.node_ids:
idx = self.get_node_index(node_id)
if idx is not None:
data = self.graph.get_node_data(idx)
if isinstance(data, dict):
schema = data.get("schema", {})
features["trust_score"].append(schema.get("trust_score", 1.0))
features["quality_score"].append(schema.get("quality_score", 1.0))
else:
features["trust_score"].append(1.0)
features["quality_score"].append(1.0)
else:
features["trust_score"].append(1.0)
features["quality_score"].append(1.0)
return {name: torch.tensor(values, dtype=torch.float32) for name, values in features.items()}
def subgraph(self, node_ids: list[str]) -> "RoleGraph":
"""Build a subgraph containing only the selected nodes and their connections."""
agents = [a for a in self.agents if _get_agent_id(a) in node_ids]
id_set = set(node_ids)
new_graph = rx.PyDiGraph()
idx_map = {}
for agent in agents:
agent_id = _get_agent_id(agent)
if agent_id is None:
continue
old_idx = self.get_node_index(agent_id)
if old_idx is not None:
node_data = self.graph.get_node_data(old_idx)
new_idx = new_graph.add_node(node_data)
idx_map[old_idx] = new_idx
for eid in self.graph.edge_indices():
s, t = self.graph.get_edge_endpoints_by_index(eid)
if s in idx_map and t in idx_map:
edge_data = self.graph.get_edge_data_by_index(eid)
new_graph.add_edge(idx_map[s], idx_map[t], edge_data)
indices = [self.node_ids.index(nid) for nid in node_ids if nid in self.node_ids]
if indices and self.A_com.numel() > 0:
indices_tensor = torch.tensor(indices)
new_a = self.A_com[indices_tensor][:, indices_tensor]
else:
new_a = torch.zeros((len(agents), len(agents)), dtype=torch.float32)
new_connections = {k: [v for v in vs if v in id_set] for k, vs in self.role_connections.items() if k in id_set}
node_ids_raw = [_get_agent_id(a) for a in agents]
node_ids_filtered = [nid for nid in node_ids_raw if nid is not None]
return RoleGraph(
agents=agents,
node_ids=node_ids_filtered,
role_connections=new_connections,
task_node=self.task_node if self.task_node in id_set else None,
query=self.query,
answer=self.answer,
graph=new_graph,
A_com=new_a,
start_node=self.start_node if self.start_node in id_set else None,
end_node=self.end_node if self.end_node in id_set else None,
)
def set_start_node(self, node_id: str) -> bool:
"""
Set the start node for execution.
Args:
node_id: ID of the node from which execution starts.
Returns:
True if the node exists and was set.
"""
if node_id not in self.node_ids:
return False
object.__setattr__(self, "start_node", node_id)
return True
def set_end_node(self, node_id: str) -> bool:
"""
Set the end node for execution.
Args:
node_id: ID of the node at which execution ends.
Returns:
True if the node exists and was set.
"""
if node_id not in self.node_ids:
return False
object.__setattr__(self, "end_node", node_id)
return True
def set_execution_bounds(self, start_node: str | None, end_node: str | None) -> bool:
"""
Set start and end nodes simultaneously.
Args:
start_node: ID of the start node (None for auto-detection).
end_node: ID of the end node (None for auto-detection).
Returns:
True if both nodes are valid (or None).
"""
if start_node is not None and start_node not in self.node_ids:
return False
if end_node is not None and end_node not in self.node_ids:
return False
object.__setattr__(self, "start_node", start_node)
object.__setattr__(self, "end_node", end_node)
return True
# =========================================================================
# INACTIVE NODES (disabled nodes)
# =========================================================================
def disable(self, node_ids: str | list[str]) -> int:
"""
Deactivate nodes — they remain in the graph but will not be executed.
Args:
node_ids: Node ID or list of node IDs to deactivate.
Returns:
Number of successfully deactivated nodes.
Example:
graph.disable("agent1") # Single node
graph.disable(["a1", "a2", "a3"]) # Multiple nodes
"""
if isinstance(node_ids, str):
node_ids = [node_ids]
count = 0
for node_id in node_ids:
if node_id in self.node_ids:
self.disabled_nodes.add(node_id)
count += 1
return count
def enable(self, node_ids: str | list[str] | None = None) -> int:
"""
Activate nodes.
Args:
node_ids: Node ID, list of node IDs, or None to activate all.
Returns:
Number of activated nodes.
Example:
graph.enable("agent1") # Single node
graph.enable(["a1", "a2"]) # Multiple nodes
graph.enable() # All nodes
"""
if node_ids is None:
count = len(self.disabled_nodes)
self.disabled_nodes.clear()
return count
if isinstance(node_ids, str):
node_ids = [node_ids]
count = 0
for node_id in node_ids:
if node_id in self.disabled_nodes:
self.disabled_nodes.remove(node_id)
count += 1
return count
def is_enabled(self, node_id: str) -> bool:
"""Check whether a node is active."""
return node_id in self.node_ids and node_id not in self.disabled_nodes
def get_enabled(self) -> list[str]:
"""Get the list of active nodes."""
return [nid for nid in self.node_ids if nid not in self.disabled_nodes]
def get_disabled(self) -> list[str]:
"""Get the list of deactivated nodes."""
return list(self.disabled_nodes)
def get_reachable_from(self, source_id: str, threshold: float = EDGE_THRESHOLD) -> set[str]:
"""
Get all nodes reachable from source_id (forward BFS).
Args:
source_id: ID of the start node.
threshold: Minimum edge weight to consider a connection.
Returns:
Set of reachable node IDs (including source_id).
"""
if source_id not in self.node_ids:
return set()
reachable = {source_id}
queue = deque([source_id])
while queue:
current = queue.popleft()
current_idx = self.node_ids.index(current)
for j, node_id in enumerate(self.node_ids):
if node_id in reachable:
continue
if self.A_com.numel() > 0 and self.A_com[current_idx, j].item() > threshold:
reachable.add(node_id)
queue.append(node_id)
return reachable
def get_nodes_reaching(self, target_id: str, threshold: float = EDGE_THRESHOLD) -> set[str]:
"""
Get all nodes from which target_id is reachable (backward BFS).
Args:
target_id: ID of the target node.
threshold: Minimum edge weight to consider a connection.
Returns:
Set of node IDs from which target_id is reachable (including target_id itself).
"""
if target_id not in self.node_ids:
return set()
reaching = {target_id}
queue = deque([target_id])
while queue:
current = queue.popleft()
current_idx = self.node_ids.index(current)
for i, node_id in enumerate(self.node_ids):
if node_id in reaching:
continue
if self.A_com.numel() > 0 and self.A_com[i, current_idx].item() > threshold:
reaching.add(node_id)
queue.append(node_id)
return reaching
def get_relevant_nodes(
self,
start_node: str | None = None,
end_node: str | None = None,
threshold: float = EDGE_THRESHOLD,
) -> set[str]:
"""
Get nodes that lie on paths from start to end.
This is the intersection of:
- Nodes reachable from start_node
- Nodes from which end_node is reachable
Nodes not in this set are isolated and not needed for execution.
Args:
start_node: ID of the start node (or self.start_node, or the first by order).
end_node: ID of the end node (or self.end_node, or the last by order).
threshold: Minimum edge weight.
Returns:
Set of relevant node IDs.
"""
# Determine start
effective_start = start_node or self.start_node
if effective_start is None and self.node_ids:
# First node with no incoming edges
for node_id in self.node_ids:
idx = self.node_ids.index(node_id)
if self.A_com.numel() > 0:
in_degree = (self.A_com[:, idx] > threshold).sum().item()
if in_degree == 0:
effective_start = node_id
break
if effective_start is None:
effective_start = self.node_ids[0]
# Determine end
effective_end = end_node or self.end_node
if effective_end is None and self.node_ids:
# Last node with no outgoing edges
for node_id in reversed(self.node_ids):
idx = self.node_ids.index(node_id)
if self.A_com.numel() > 0:
out_degree = (self.A_com[idx, :] > threshold).sum().item()
if out_degree == 0:
effective_end = node_id
break
if effective_end is None:
effective_end = self.node_ids[-1]
if effective_start is None or effective_end is None:
return set()
# Intersection of nodes reachable from start and leading to end
reachable_from_start = self.get_reachable_from(effective_start, threshold)
reaching_end = self.get_nodes_reaching(effective_end, threshold)
return reachable_from_start & reaching_end
def get_isolated_nodes(
self,
start_node: str | None = None,
end_node: str | None = None,
threshold: float = EDGE_THRESHOLD,
) -> set[str]:
"""
Get isolated nodes that do not participate in the start->end path.
These nodes can be excluded from execution to save tokens.
Args:
start_node: ID of the start node.
end_node: ID of the end node.
threshold: Minimum edge weight.
Returns:
Set of isolated node IDs.
"""
relevant = self.get_relevant_nodes(start_node, end_node, threshold)
all_nodes = set(self.node_ids)
return all_nodes - relevant
def get_optimized_execution_order(
self,
start_node: str | None = None,
end_node: str | None = None,
threshold: float = EDGE_THRESHOLD,
) -> list[str]:
"""
Get the optimised execution order, excluding isolated nodes.
Args:
start_node: ID of the start node.
end_node: ID of the end node.
threshold: Minimum edge weight.
Returns:
List of node IDs in topological order (relevant nodes only).
"""
relevant = self.get_relevant_nodes(start_node, end_node, threshold)
# Topological sort of relevant nodes only
# Build in-degree for relevant nodes
in_degree: dict[str, int] = dict.fromkeys(relevant, 0)
for i, src in enumerate(self.node_ids):
if src not in relevant:
continue
for j, tgt in enumerate(self.node_ids):
if tgt not in relevant:
continue
if self.A_com.numel() > 0 and self.A_com[i, j].item() > threshold:
in_degree[tgt] += 1
# Kahn's algorithm
queue = deque([node_id for node_id in relevant if in_degree[node_id] == 0])
result: list[str] = []
while queue:
current = queue.popleft()
result.append(current)
current_idx = self.node_ids.index(current)
for j, tgt in enumerate(self.node_ids):
if tgt not in relevant or tgt in result:
continue
if self.A_com.numel() > 0 and self.A_com[current_idx, j].item() > threshold:
in_degree[tgt] -= 1
if in_degree[tgt] == 0:
queue.append(tgt)
# Add remaining nodes (in case of cycles)
for node_id in relevant:
if node_id not in result:
result.append(node_id)
return result
# =========================================================================
# DATA VALIDATION (input/output schema validation)
# =========================================================================
def get_agent_schema(self, agent_id: str) -> Any | None:
"""
Get the agent schema from the node data.
Args:
agent_id: Agent ID.
Returns:
AgentNodeSchema or None if not found.
"""
idx = self.get_node_index(agent_id)
if idx is None:
return None
data = self.graph.get_node_data(idx)
if not isinstance(data, dict):
return None
schema_dict = data.get("schema")
if schema_dict is None:
return None
# Restore AgentNodeSchema
from core.schema import AgentNodeSchema, NodeType
if schema_dict.get("type") == NodeType.AGENT.value or schema_dict.get("type") == "agent":
return AgentNodeSchema.model_validate(schema_dict)
return None
def validate_agent_input(
self,
agent_id: str,
data: dict[str, Any] | str,
) -> Any:
"""
Validate input data for an agent against its input_schema.
Args:
agent_id: Agent ID.
data: Data to validate (dict or JSON string).
Returns:
SchemaValidationResult with the validation result.
Example:
result = graph.validate_agent_input("solver", {"question": "2+2=?"})
if not result.valid:
print(f"Validation failed: {result.errors}")
"""
from core.schema import SchemaValidationResult
schema = self.get_agent_schema(agent_id)
if schema is None:
return SchemaValidationResult(
valid=True,
schema_type="input",
message=f"No schema found for agent '{agent_id}'",
)
return schema.validate_input(data)
def validate_agent_output(
self,
agent_id: str,
data: dict[str, Any] | str,
) -> Any:
"""
Validate output data for an agent against its output_schema.
Args:
agent_id: Agent ID.
data: Data to validate (dict or JSON string).
Returns:
SchemaValidationResult with the validation result.
Example:
result = graph.validate_agent_output("solver", response)
if result.valid:
parsed = result.validated_data
"""
from core.schema import SchemaValidationResult
schema = self.get_agent_schema(agent_id)
if schema is None:
return SchemaValidationResult(
valid=True,
schema_type="output",
message=f"No schema found for agent '{agent_id}'",
)
return schema.validate_output(data)
def has_input_schema(self, agent_id: str) -> bool:
"""Check whether the agent has an input_schema."""
schema = self.get_agent_schema(agent_id)
return schema is not None and schema.has_input_schema()
def has_output_schema(self, agent_id: str) -> bool:
"""Check whether the agent has an output_schema."""
schema = self.get_agent_schema(agent_id)
return schema is not None and schema.has_output_schema()
def get_input_schema_json(self, agent_id: str) -> dict[str, Any] | None:
"""
Get the JSON Schema for the agent's input data.
Useful for generating prompts describing the expected format.
"""
schema = self.get_agent_schema(agent_id)
if schema is None:
return None
return schema.input_schema_json
def get_output_schema_json(self, agent_id: str) -> dict[str, Any] | None:
"""
Get the JSON Schema for the agent's output data.
Useful for generating prompts describing the expected response format.
"""
schema = self.get_agent_schema(agent_id)
if schema is None:
return None
return schema.output_schema_json