Spaces:
Running
Running
| """RoleGraph on rustworkx with dynamic topology support.""" | |
| from collections import deque | |
| from collections.abc import Iterable, Mapping, Sequence | |
| from enum import Enum | |
| from typing import Any, Protocol, runtime_checkable | |
| import rustworkx as rx | |
| import torch | |
| from pydantic import BaseModel, ConfigDict, Field | |
| # Constants for magic values | |
| EDGE_THRESHOLD = 0.5 | |
| __all__ = [ | |
| "GraphIntegrityError", | |
| "RoleGraph", | |
| "StateMigrationPolicy", | |
| "StateStorage", | |
| ] | |
class StateMigrationPolicy(str, Enum):
    """What to do with an agent's state when its node is removed or replaced."""

    DISCARD = "discard"  # drop the state entirely
    COPY = "copy"  # transfer state onto the replacement agent
    ARCHIVE = "archive"  # persist state to the external StateStorage
class StateStorage(Protocol):
    """Structural interface for external per-node state persistence."""

    # Persist a serializable state payload under the node's ID.
    def save(self, node_id: str, state: dict[str, Any]) -> None: ...

    # Return the stored payload, or None when nothing was saved.
    def load(self, node_id: str) -> dict[str, Any] | None: ...

    # Remove any stored payload for the node.
    def delete(self, node_id: str) -> None: ...
class GraphIntegrityError(Exception):
    """Raised by RoleGraph.verify_integrity when internal views disagree."""

    pass
| def _get_agent_id(agent: Any) -> str | None: | |
| """Safely get agent_id from an agent (object or dict).""" | |
| if hasattr(agent, "agent_id"): | |
| return agent.agent_id | |
| if isinstance(agent, dict): | |
| return agent.get("id") or agent.get("agent_id") | |
| return None | |
class RoleGraph(BaseModel):
    """
    Role graph on rustworkx with adjacency matrices and auxiliary data.
    Supports conditional routing via edge_conditions.
    Supports explicit start_node and end_node for execution optimisation.
    """

    # Required so pydantic accepts rx.PyDiGraph / torch.Tensor fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # agents, node_ids and A_com are aligned positionally: index i in each
    # refers to the same node.
    agents: list[Any] = Field(default_factory=list)
    node_ids: list[str] = Field(default_factory=list)
    # Logical connection map: source ID -> list of target IDs.
    role_connections: dict[str, list[str]] = Field(default_factory=dict)
    task_node: str | None = None
    query: str | None = None
    answer: str | None = None
    # Underlying directed graph; node payloads are {"id": ..., "type": ...} dicts.
    graph: rx.PyDiGraph = Field(default_factory=rx.PyDiGraph)
    # Communication adjacency matrix (n x n, float32).
    A_com: torch.Tensor = Field(default_factory=lambda: torch.zeros((0, 0), dtype=torch.float32))
    # Optional auxiliary matrices kept in sync with A_com's shape.
    S_tilde: torch.Tensor | None = Field(default=None)
    p_matrix: torch.Tensor | None = Field(default=None)
    # External state storage (StateStorage protocol); excluded from serialization.
    state_storage: Any | None = Field(default=None, exclude=True)
    # Explicit start/end nodes for execution path optimisation
    start_node: str | None = Field(default=None)
    end_node: str | None = Field(default=None)
    # Inactive nodes — present in the graph but not executed
    # Saves tokens without removing nodes from the structure
    disabled_nodes: set[str] = Field(default_factory=set)
    # Routing conditions: {(source, target): condition}
    # Callable conditions (not serialized)
    edge_conditions: dict[tuple[str, str], Any] = Field(default_factory=dict, exclude=True)
    # String conditions from the schema
    edge_condition_names: dict[tuple[str, str], str] = Field(default_factory=dict)
| def role_sequence(self) -> list[str]: | |
| """Order of roles (agent identifiers).""" | |
| result = [] | |
| for a in self.agents: | |
| if hasattr(a, "agent_id"): | |
| result.append(a.agent_id) | |
| elif isinstance(a, dict): | |
| result.append(a.get("id", a.get("agent_id", str(a)))) | |
| else: | |
| result.append(str(a)) | |
| return result | |
| def embeddings(self) -> torch.Tensor: | |
| """Stack of agent embeddings or an empty tensor.""" | |
| embs = [] | |
| for a in self.agents: | |
| emb = getattr(a, "embedding", None) if hasattr(a, "embedding") else None | |
| if emb is not None: | |
| embs.append(emb) | |
| return torch.stack(embs) if embs else torch.zeros((0, 0), dtype=torch.float32) | |
| def num_nodes(self) -> int: | |
| """Number of nodes in the graph.""" | |
| return self.graph.num_nodes() | |
| def num_edges(self) -> int: | |
| """Number of edges in the graph.""" | |
| return self.graph.num_edges() | |
| def edges(self) -> list[dict[str, Any]]: | |
| """List of edges with data (source, target, attr, weight...).""" | |
| result = [] | |
| for i in self.graph.edge_indices(): | |
| s, t = self.graph.get_edge_endpoints_by_index(i) | |
| d = self.graph.get_edge_data_by_index(i) | |
| edge = {"source": self._nid(s), "target": self._nid(t)} | |
| if isinstance(d, dict): | |
| for k, v in d.items(): | |
| if isinstance(v, torch.Tensor): | |
| edge[k] = v.tolist() | |
| else: | |
| edge[k] = v | |
| result.append(edge) | |
| return result | |
| def edge_index(self) -> torch.Tensor: | |
| """Edge index in PyG format (2 x E).""" | |
| if not self.graph.num_edges(): | |
| return torch.zeros((2, 0), dtype=torch.long) | |
| src, tgt = [], [] | |
| for i in self.graph.edge_indices(): | |
| s, t = self.graph.get_edge_endpoints_by_index(i) | |
| src.append(s) | |
| tgt.append(t) | |
| return torch.tensor([src, tgt], dtype=torch.long) | |
| def edge_attr(self) -> torch.Tensor: | |
| """Edge feature matrix (default: weight + attr fields).""" | |
| if not self.graph.num_edges(): | |
| return torch.zeros((0, 4), dtype=torch.float32) | |
| attrs = [] | |
| default_attr = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) | |
| for i in self.graph.edge_indices(): | |
| d = self.graph.get_edge_data_by_index(i) | |
| attr = d.get("attr", default_attr) if isinstance(d, dict) else default_attr | |
| if isinstance(attr, torch.Tensor): | |
| attrs.append(attr) | |
| else: | |
| attrs.append(torch.tensor(attr, dtype=torch.float32)) | |
| return torch.vstack(attrs).to(torch.float32) | |
| def has_conditional_edges(self) -> bool: | |
| """Whether the graph has conditional edges.""" | |
| return bool(self.edge_conditions) or bool(self.edge_condition_names) | |
| def conditional_edges(self) -> list[tuple[str, str]]: | |
| """List of conditional edges (source, target).""" | |
| edges = set(self.edge_conditions.keys()) | |
| edges.update(self.edge_condition_names.keys()) | |
| return list(edges) | |
| def get_edge_condition(self, source: str, target: str) -> Any | str | None: | |
| """ | |
| Get the condition for an edge (callable or string). | |
| Returns the callable if present, otherwise the string condition, otherwise None. | |
| """ | |
| # First check callable | |
| if (source, target) in self.edge_conditions: | |
| return self.edge_conditions[(source, target)] | |
| # Then string | |
| if (source, target) in self.edge_condition_names: | |
| return self.edge_condition_names[(source, target)] | |
| return None | |
| def get_all_edge_conditions(self) -> dict[tuple[str, str], Any]: | |
| """Get all edge conditions (union of callable and string conditions).""" | |
| result: dict[tuple[str, str], Any] = {} | |
| # First string conditions | |
| result.update(self.edge_condition_names) | |
| # Then callable (overwrite string ones if present) | |
| result.update(self.edge_conditions) | |
| return result | |
| def set_edge_condition( | |
| self, | |
| source: str, | |
| target: str, | |
| condition: Any, | |
| ) -> bool: | |
| """ | |
| Set the condition for an edge. | |
| Args: | |
| source: Source ID. | |
| target: Target ID. | |
| condition: Callable or string condition. | |
| Returns: | |
| True if the edge exists and the condition was set. | |
| """ | |
| # Check that the edge exists | |
| src_idx = self.get_node_index(source) | |
| tgt_idx = self.get_node_index(target) | |
| if src_idx is None or tgt_idx is None: | |
| return False | |
| if callable(condition): | |
| self.edge_conditions[(source, target)] = condition | |
| elif isinstance(condition, str): | |
| self.edge_condition_names[(source, target)] = condition | |
| return True | |
| def remove_edge_condition(self, source: str, target: str) -> bool: | |
| """Remove the condition from an edge.""" | |
| removed = False | |
| if (source, target) in self.edge_conditions: | |
| del self.edge_conditions[(source, target)] | |
| removed = True | |
| if (source, target) in self.edge_condition_names: | |
| del self.edge_condition_names[(source, target)] | |
| removed = True | |
| return removed | |
| def _nid(self, idx: int) -> str: | |
| """Return the node identifier by rustworkx index.""" | |
| d = self.graph.get_node_data(idx) | |
| return d.get("id", str(idx)) if isinstance(d, dict) else str(idx) | |
| def get_node_index(self, node_id: str) -> int | None: | |
| """Find the rustworkx index of a node by its ID.""" | |
| for i in self.graph.node_indices(): | |
| d = self.graph.get_node_data(i) | |
| if isinstance(d, dict) and d.get("id") == node_id: | |
| return i | |
| return None | |
| def get_agent_by_id(self, agent_id: str) -> Any | None: | |
| """Return the agent object by its identifier.""" | |
| for agent in self.agents: | |
| aid = getattr(agent, "agent_id", None) | |
| if aid is None and isinstance(agent, dict): | |
| aid = agent.get("id", agent.get("agent_id")) | |
| if aid == agent_id: | |
| return agent | |
| return None | |
| def add_node( | |
| self, | |
| agent: Any, | |
| connections_from: Sequence[str] | None = None, | |
| connections_to: Sequence[str] | None = None, | |
| ) -> bool: | |
| """Add a node/agent and optionally connect it to neighbours.""" | |
| node_id = getattr(agent, "agent_id", None) | |
| if node_id is None and isinstance(agent, dict): | |
| node_id = agent.get("id", agent.get("agent_id")) | |
| if node_id is None or node_id in self.node_ids: | |
| return False | |
| node_type = "task" if getattr(agent, "type", None) == "task" else "agent" | |
| self.graph.add_node({"id": node_id, "type": node_type}) | |
| self.agents.append(agent) | |
| self.node_ids.append(node_id) | |
| self._expand_adjacency(1) | |
| self.role_connections[node_id] = [] | |
| for src_id in connections_from or []: | |
| if src_id in self.node_ids: | |
| self.add_edge(src_id, node_id) | |
| for tgt_id in connections_to or []: | |
| if tgt_id in self.node_ids: | |
| self.add_edge(node_id, tgt_id) | |
| return True | |
    def remove_node(
        self,
        node_id: str,
        policy: StateMigrationPolicy = StateMigrationPolicy.DISCARD,
    ) -> Any | None:
        """Remove a node, with optional state migration/archiving.

        Args:
            node_id: ID of the node to remove.
            policy: ARCHIVE persists the agent's state to ``state_storage``
                before removal; other policies drop it.

        Returns:
            The removed agent, or None if ``node_id`` is unknown.
        """
        if node_id not in self.node_ids:
            return None
        agent_idx = self.node_ids.index(node_id)
        agent = self.agents[agent_idx]
        rx_idx = self.get_node_index(node_id)
        if policy == StateMigrationPolicy.ARCHIVE:
            self._archive_state(agent)
        if rx_idx is not None:
            self.graph.remove_node(rx_idx)
        # agents / node_ids / matrices share the same positional index, so
        # all three are shrunk at agent_idx.
        self.agents.pop(agent_idx)
        self.node_ids.pop(agent_idx)
        self._shrink_adjacency(agent_idx)
        # Drop the node from the logical connection map, both as a source
        # key and as a target inside other nodes' lists.
        self.role_connections.pop(node_id, None)
        for conns in self.role_connections.values():
            if node_id in conns:
                conns.remove(node_id)
        if self.task_node == node_id:
            # object.__setattr__ sidesteps pydantic's assignment machinery.
            object.__setattr__(self, "task_node", None)
        return agent
    def replace_node(
        self,
        node_id: str,
        new_agent: Any,
        policy: StateMigrationPolicy = StateMigrationPolicy.COPY,
    ) -> Any | None:
        """Replace a node with a new agent using the selected state migration policy.

        Args:
            node_id: ID of the node being replaced.
            new_agent: Replacement agent object (or dict).
            policy: COPY transfers state onto the new agent; ARCHIVE persists
                the old agent's state; DISCARD drops it.

        Returns:
            The old agent, or None if ``node_id`` is unknown.
        """
        if node_id not in self.node_ids:
            return None
        agent_idx = self.node_ids.index(node_id)
        old_agent = self.agents[agent_idx]
        rx_idx = self.get_node_index(node_id)
        if policy == StateMigrationPolicy.COPY:
            new_agent = self._copy_state(old_agent, new_agent)
        elif policy == StateMigrationPolicy.ARCHIVE:
            self._archive_state(old_agent)
        new_id = _get_agent_id(new_agent)
        if new_id is None:
            # Fall back to the object's identity so the node still gets a unique ID.
            new_id = str(id(new_agent))
        node_type = "task" if getattr(new_agent, "type", None) == "task" else "agent"
        if rx_idx is not None:
            # Overwrite the rustworkx node payload in place.
            self.graph[rx_idx] = {"id": new_id, "type": node_type}
        self.agents[agent_idx] = new_agent
        self.node_ids[agent_idx] = new_id
        if node_id != new_id:
            # Rename the node in the logical connection map, both as a key
            # and wherever it appears as a target.
            if node_id in self.role_connections:
                self.role_connections[new_id] = self.role_connections.pop(node_id)
            for conns in self.role_connections.values():
                for i, c in enumerate(conns):
                    if c == node_id:
                        conns[i] = new_id
        if self.task_node == node_id:
            # object.__setattr__ sidesteps pydantic's assignment machinery.
            object.__setattr__(self, "task_node", new_id)
        return old_agent
    def _copy_state(self, old_agent: Any, new_agent: Any) -> Any:
        """Copy state/hidden_state/embedding from the old agent to the new one.

        Relies on the agents' builder-style API (``with_state`` etc., which
        appears to return a new agent — TODO confirm against the agent class);
        each transfer only happens when both sides support it.
        """
        if hasattr(old_agent, "state") and hasattr(new_agent, "with_state"):
            # list(...) makes a defensive copy of the state sequence.
            new_agent = new_agent.with_state(list(old_agent.state))
        if (
            hasattr(old_agent, "hidden_state")
            and old_agent.hidden_state is not None
            and hasattr(new_agent, "with_hidden_state")
        ):
            new_agent = new_agent.with_hidden_state(old_agent.hidden_state)
        if hasattr(old_agent, "embedding") and old_agent.embedding is not None and hasattr(new_agent, "with_embedding"):
            new_agent = new_agent.with_embedding(old_agent.embedding)
        return new_agent
    def _archive_state(self, agent: Any) -> None:
        """Save the agent state to external storage if it is configured.

        Tensors are moved to CPU and converted to plain lists so the payload
        is storage-agnostic.
        """
        if self.state_storage is None:
            return
        state_data = {
            "state": list(getattr(agent, "state", [])),
            "hidden_state": (
                agent.hidden_state.cpu().tolist()
                if hasattr(agent, "hidden_state") and agent.hidden_state is not None
                else None
            ),
            "embedding": (
                agent.embedding.cpu().tolist() if hasattr(agent, "embedding") and agent.embedding is not None else None
            ),
        }
        # Agents without a resolvable ID are stored under the empty-string key.
        self.state_storage.save(_get_agent_id(agent) or "", state_data)
| def _expand_adjacency(self, count: int = 1) -> None: | |
| """Expand the adjacency/probability matrices when adding nodes.""" | |
| n = self.A_com.shape[0] if self.A_com.numel() > 0 else 0 | |
| new_n = n + count | |
| new_a = torch.zeros((new_n, new_n), dtype=torch.float32) | |
| if n > 0: | |
| new_a[:n, :n] = self.A_com | |
| object.__setattr__(self, "A_com", new_a) | |
| if self.S_tilde is not None: | |
| new_s = torch.zeros((new_n, new_n), dtype=torch.float32) | |
| new_s[:n, :n] = self.S_tilde | |
| object.__setattr__(self, "S_tilde", new_s) | |
| if self.p_matrix is not None: | |
| new_p = torch.zeros((new_n, new_n), dtype=torch.float32) | |
| new_p[:n, :n] = self.p_matrix | |
| object.__setattr__(self, "p_matrix", new_p) | |
| def _shrink_adjacency(self, idx: int) -> None: | |
| """Remove a row/column from the matrices when removing a node.""" | |
| if self.A_com.numel() == 0: | |
| return | |
| mask = torch.ones(self.A_com.shape[0], dtype=torch.bool) | |
| mask[idx] = False | |
| object.__setattr__(self, "A_com", self.A_com[mask][:, mask]) | |
| if self.S_tilde is not None: | |
| object.__setattr__(self, "S_tilde", self.S_tilde[mask][:, mask]) | |
| if self.p_matrix is not None: | |
| object.__setattr__(self, "p_matrix", self.p_matrix[mask][:, mask]) | |
| def add_edge( | |
| self, | |
| source_id: str, | |
| target_id: str, | |
| weight: float = 1.0, | |
| **edge_attrs, | |
| ) -> bool: | |
| """Add a directed edge and update the adjacency matrix.""" | |
| src_idx = self.get_node_index(source_id) | |
| tgt_idx = self.get_node_index(target_id) | |
| if src_idx is None or tgt_idx is None: | |
| return False | |
| self.graph.add_edge(src_idx, tgt_idx, {"weight": weight, **edge_attrs}) | |
| src_list_idx = self.node_ids.index(source_id) | |
| tgt_list_idx = self.node_ids.index(target_id) | |
| if self.A_com.numel() > 0: | |
| self.A_com[src_list_idx, tgt_list_idx] = weight | |
| return True | |
| def remove_edge(self, source_id: str, target_id: str) -> bool: | |
| """Remove an edge and zero out the weight in the matrix.""" | |
| src_idx = self.get_node_index(source_id) | |
| tgt_idx = self.get_node_index(target_id) | |
| if src_idx is None or tgt_idx is None: | |
| return False | |
| for eid in self.graph.edge_indices(): | |
| s, t = self.graph.get_edge_endpoints_by_index(eid) | |
| if s == src_idx and t == tgt_idx: | |
| self.graph.remove_edge_from_index(eid) | |
| src_list_idx = self.node_ids.index(source_id) | |
| tgt_list_idx = self.node_ids.index(target_id) | |
| if self.A_com.numel() > 0: | |
| self.A_com[src_list_idx, tgt_list_idx] = 0.0 | |
| return True | |
| return False | |
| def get_neighbors(self, node_id: str, direction: str = "out") -> list[str]: | |
| """Return neighbouring nodes (out/in/both).""" | |
| idx = self.get_node_index(node_id) | |
| if idx is None: | |
| return [] | |
| neighbors = set() | |
| for eid in self.graph.edge_indices(): | |
| s, t = self.graph.get_edge_endpoints_by_index(eid) | |
| if direction in ("out", "both") and s == idx: | |
| neighbors.add(self._nid(t)) | |
| if direction in ("in", "both") and t == idx: | |
| neighbors.add(self._nid(s)) | |
| return list(neighbors) | |
| def update_communication( | |
| self, | |
| a_com: torch.Tensor, | |
| s_tilde: torch.Tensor | None = None, | |
| p_matrix: torch.Tensor | None = None, | |
| ) -> None: | |
| """Fully replace the communication matrix and graph edges.""" | |
| a_tensor = a_com.detach().cpu() if a_com.requires_grad else a_com.cpu() | |
| for eid in list(self.graph.edge_indices()): | |
| self.graph.remove_edge_from_index(eid) | |
| n_nodes = a_tensor.shape[0] | |
| node_indices = list(self.graph.node_indices()) | |
| for i in range(n_nodes): | |
| for j in range(n_nodes): | |
| if a_tensor[i, j].item() > EDGE_THRESHOLD and i < len(node_indices) and j < len(node_indices): | |
| edge_data = {"weight": float(a_tensor[i, j].item()), "from_update": True} | |
| if s_tilde is not None: | |
| s_tensor = s_tilde.detach().cpu() if s_tilde.requires_grad else s_tilde.cpu() | |
| edge_data["score"] = float(s_tensor[i, j].item()) | |
| if p_matrix is not None: | |
| p_tensor = p_matrix.detach().cpu() if p_matrix.requires_grad else p_matrix.cpu() | |
| edge_data["p_ij"] = float(p_tensor[i, j].item()) | |
| self.graph.add_edge(node_indices[i], node_indices[j], edge_data) | |
| object.__setattr__(self, "A_com", a_tensor.to(torch.float32)) | |
| if s_tilde is not None: | |
| s_tensor = s_tilde.detach().cpu() if s_tilde.requires_grad else s_tilde.cpu() | |
| object.__setattr__(self, "S_tilde", s_tensor.to(torch.float32)) | |
| if p_matrix is not None: | |
| p_tensor = p_matrix.detach().cpu() if p_matrix.requires_grad else p_matrix.cpu() | |
| object.__setattr__(self, "p_matrix", p_tensor.to(torch.float32)) | |
    def verify_integrity(self, raise_on_error: bool = True) -> list[str]:
        """Check consistency of the agent list, nodes, and matrices.

        Args:
            raise_on_error: When True (default), raise instead of returning
                a non-empty error list.

        Returns:
            Human-readable error descriptions (empty when consistent).

        Raises:
            GraphIntegrityError: When errors exist and ``raise_on_error``.
        """
        errors: list[str] = []
        # All four views of "node count" must agree.
        n_agents = len(self.agents)
        n_ids = len(self.node_ids)
        n_rx = self.graph.num_nodes()
        n_matrix = self.A_com.shape[0] if self.A_com.numel() > 0 else 0
        if n_agents != n_ids:
            errors.append(f"agents ({n_agents}) != node_ids ({n_ids})")
        if n_agents != n_rx:
            errors.append(f"agents ({n_agents}) != rustworkx nodes ({n_rx})")
        if n_agents != n_matrix:
            errors.append(f"agents ({n_agents}) != matrix size ({n_matrix})")
        # role_sequence is read via attribute access — it is expected to be a
        # property on this class.
        role_seq = set(self.role_sequence)
        node_ids_set = set(self.node_ids)
        if role_seq != node_ids_set:
            diff = role_seq.symmetric_difference(node_ids_set)
            errors.append(f"role_sequence != node_ids, diff: {diff}")
        # IDs stored in rustworkx node payloads must mirror node_ids.
        rx_ids = set()
        for i in self.graph.node_indices():
            data = self.graph.get_node_data(i)
            if isinstance(data, dict) and "id" in data:
                rx_ids.add(data["id"])
        if rx_ids != node_ids_set:
            diff = rx_ids.symmetric_difference(node_ids_set)
            errors.append(f"rustworkx IDs != node_ids, diff: {diff}")
        # Every logical connection endpoint must be a known node.
        for src, targets in self.role_connections.items():
            if src not in node_ids_set:
                errors.append(f"connection source '{src}' not in nodes")
            errors.extend(f"connection target '{t}' not in nodes" for t in targets if t not in node_ids_set)
        if self.task_node is not None and self.task_node not in node_ids_set:
            errors.append(f"task_node '{self.task_node}' not in nodes")
        if errors and raise_on_error:
            raise GraphIntegrityError("; ".join(errors))
        return errors
| def is_consistent(self) -> bool: | |
| """Quick size consistency check without a detailed report.""" | |
| n = len(self.agents) | |
| return ( | |
| len(self.node_ids) == n | |
| and self.graph.num_nodes() == n | |
| and (self.A_com.shape[0] if self.A_com.numel() > 0 else 0) == n | |
| ) | |
| def to_dict(self) -> dict[str, Any]: | |
| """Serialize the graph to a dict (for saving or debugging).""" | |
| emb = self.embeddings | |
| return { | |
| "role_sequence": list(self.role_sequence), | |
| "node_ids": list(self.node_ids), | |
| "role_connections": {k: list(v) for k, v in self.role_connections.items()}, | |
| "task_node": self.task_node, | |
| "query": self.query, | |
| "answer": self.answer, | |
| "agents": [ | |
| { | |
| "agent_id": _get_agent_id(a), | |
| "display_name": getattr(a, "display_name", None), | |
| "persona": getattr(a, "persona", ""), | |
| "description": getattr(a, "description", ""), | |
| "llm_backbone": getattr(a, "llm_backbone", None), | |
| "tools": list(getattr(a, "tools", [])), | |
| "embedding": a.embedding.cpu().tolist() if a.embedding is not None else None, | |
| "state": list(getattr(a, "state", [])), | |
| } | |
| for a in self.agents | |
| ], | |
| "edges": self.edges, | |
| "embeddings": emb.cpu().tolist() if emb.numel() > 0 else [], | |
| "edge_index": self.edge_index.tolist() if self.edge_index.numel() > 0 else [[], []], | |
| "edge_attr": self.edge_attr.tolist() if self.edge_attr.numel() > 0 else [], | |
| "adjacency": self.A_com.tolist() if self.A_com.numel() > 0 else [], | |
| "num_nodes": self.num_nodes, | |
| "num_edges": self.num_edges, | |
| } | |
| def from_dict( | |
| cls, | |
| data: dict[str, Any], | |
| agent_factory: Any = None, | |
| verify: bool = True, | |
| ) -> "RoleGraph": | |
| """Create a RoleGraph from a dict with agents and edges.""" | |
| from core.agent import AgentProfile | |
| factory = agent_factory or AgentProfile | |
| agents = [] | |
| for a_data in data.get("agents", []): | |
| emb = a_data.get("embedding") | |
| embedding = torch.tensor(emb, dtype=torch.float32) if emb else None | |
| aid = a_data.get("agent_id") | |
| agent = factory( | |
| agent_id=aid, | |
| display_name=a_data.get("display_name", aid), | |
| persona=a_data.get("persona", ""), | |
| description=a_data.get("description", ""), | |
| llm_backbone=a_data.get("llm_backbone"), | |
| tools=a_data.get("tools", []), | |
| state=a_data.get("state", []), | |
| embedding=embedding, | |
| ) | |
| agents.append(agent) | |
| graph = rx.PyDiGraph() | |
| idx_map = {} | |
| for agent in agents: | |
| aid = _get_agent_id(agent) | |
| idx_map[aid] = graph.add_node( | |
| { | |
| "id": aid, | |
| "type": "agent", | |
| } | |
| ) | |
| for edge in data.get("edges", []): | |
| src_id = edge.get("source") | |
| tgt_id = edge.get("target") | |
| if src_id in idx_map and tgt_id in idx_map: | |
| edge_data = {k: v for k, v in edge.items() if k not in ("source", "target")} | |
| graph.add_edge(idx_map[src_id], idx_map[tgt_id], edge_data) | |
| adj = data.get("adjacency", []) | |
| a_com = ( | |
| torch.tensor(adj, dtype=torch.float32) | |
| if adj | |
| else torch.zeros((len(agents), len(agents)), dtype=torch.float32) | |
| ) | |
| rg = cls( | |
| agents=agents, | |
| node_ids=data.get("node_ids", [_get_agent_id(a) for a in agents]), | |
| role_connections=data.get("role_connections", {}), | |
| task_node=data.get("task_node"), | |
| query=data.get("query"), | |
| answer=data.get("answer"), | |
| graph=graph, | |
| A_com=a_com, | |
| ) | |
| if verify: | |
| rg.verify_integrity() | |
| return rg | |
| def from_graph( | |
| cls, | |
| agents: Sequence[Any], | |
| graph: rx.PyDiGraph, | |
| a_com: torch.Tensor, | |
| connections: Mapping[str, Iterable[str]], | |
| task_node: str | None = None, | |
| query: str | None = None, | |
| answer: str | None = None, | |
| verify: bool = True, | |
| ) -> "RoleGraph": | |
| """Create a RoleGraph from an existing PyDiGraph and adjacency matrix.""" | |
| agents_list = list(agents) | |
| a_tensor = a_com if isinstance(a_com, torch.Tensor) else torch.tensor(a_com, dtype=torch.float32) | |
| node_ids_raw = [_get_agent_id(a) for a in agents_list] | |
| node_ids_filtered = [nid for nid in node_ids_raw if nid is not None] | |
| rg = cls( | |
| agents=agents_list, | |
| node_ids=node_ids_filtered, | |
| role_connections={k: list(v) for k, v in connections.items()}, | |
| task_node=task_node, | |
| query=query, | |
| answer=answer, | |
| graph=graph, | |
| A_com=a_tensor.to(torch.float32), | |
| ) | |
| if verify: | |
| rg.verify_integrity() | |
| return rg | |
    def to_pyg_data(
        self,
        node_features: dict[str, torch.Tensor] | None = None,
        edge_features: dict[str, torch.Tensor] | None = None,
        include_embeddings: bool = True,
        include_default_edge_attr: bool = True,
    ) -> Any:
        """Convert the graph to torch_geometric.data.Data with features.

        Args:
            node_features: Extra per-node tensors; only those whose first
                dimension matches the node count are concatenated into ``x``.
            edge_features: Extra per-edge tensors; only those whose first
                dimension matches the edge count are concatenated into
                ``edge_attr``.
            include_embeddings: Prepend agent embeddings to ``x``.
            include_default_edge_attr: Prepend the default edge attributes.

        Returns:
            A ``Data`` object with ``node_ids``, ``role_sequence`` and
            (optionally) ``p_matrix`` attached as extra attributes.
        """
        # Local import: torch_geometric is an optional, heavy dependency.
        from torch_geometric.data import Data

        # role_sequence / embeddings / edge_index / edge_attr are read via
        # attribute access — they are expected to be properties on this class.
        n = len(self.role_sequence)
        num_edges = self.num_edges
        x_parts = []
        if include_embeddings:
            emb = self.embeddings
            if emb.numel() > 0:
                x_parts.append(emb)
        if node_features:
            for node_feat in node_features.values():
                # Features with a mismatched node count are silently skipped.
                if node_feat.shape[0] == n:
                    feat_to_add = node_feat.unsqueeze(1) if node_feat.dim() == 1 else node_feat
                    x_parts.append(feat_to_add)
        x = torch.cat(x_parts, dim=1) if x_parts else torch.zeros((n, 0), dtype=torch.float32)
        ei = self.edge_index if self.edge_index.numel() > 0 else torch.zeros((2, 0), dtype=torch.long)
        ea_parts = []
        if include_default_edge_attr:
            default_ea = self.edge_attr if self.edge_attr.numel() > 0 else None
            if default_ea is not None and default_ea.numel() > 0:
                ea_parts.append(default_ea)
        if edge_features:
            for edge_feat in edge_features.values():
                # Features with a mismatched edge count are silently skipped.
                if edge_feat.shape[0] == num_edges:
                    feat_to_add = edge_feat.unsqueeze(1) if edge_feat.dim() == 1 else edge_feat
                    ea_parts.append(feat_to_add)
        ea = torch.cat(ea_parts, dim=1) if ea_parts else torch.zeros((ei.shape[1], 0), dtype=torch.float32)
        data = Data(x=x, edge_index=ei, edge_attr=ea, num_nodes=n)
        # Carry graph metadata alongside the tensors.
        data.node_ids = self.node_ids
        data.role_sequence = self.role_sequence
        if self.p_matrix is not None:
            data.p_matrix = self.p_matrix.clone()
        return data
| def get_edge_features_from_schema(self) -> dict[str, torch.Tensor]: | |
| """Extract edge feature tensors from the saved schema.""" | |
| features = { | |
| "weight": [], | |
| "probability": [], | |
| "trust": [], | |
| } | |
| for eid in self.graph.edge_indices(): | |
| data = self.graph.get_edge_data_by_index(eid) | |
| if isinstance(data, dict): | |
| features["weight"].append(data.get("weight", 1.0)) | |
| features["probability"].append(data.get("probability", 1.0)) | |
| schema = data.get("schema", {}) | |
| cost = schema.get("cost", {}) | |
| features["trust"].append(cost.get("trust", 1.0)) | |
| else: | |
| features["weight"].append(1.0) | |
| features["probability"].append(1.0) | |
| features["trust"].append(1.0) | |
| return {name: torch.tensor(values, dtype=torch.float32) for name, values in features.items()} | |
| def get_node_features_from_schema(self) -> dict[str, torch.Tensor]: | |
| """Extract node feature tensors from the rustworkx data schema.""" | |
| features = { | |
| "trust_score": [], | |
| "quality_score": [], | |
| } | |
| for node_id in self.node_ids: | |
| idx = self.get_node_index(node_id) | |
| if idx is not None: | |
| data = self.graph.get_node_data(idx) | |
| if isinstance(data, dict): | |
| schema = data.get("schema", {}) | |
| features["trust_score"].append(schema.get("trust_score", 1.0)) | |
| features["quality_score"].append(schema.get("quality_score", 1.0)) | |
| else: | |
| features["trust_score"].append(1.0) | |
| features["quality_score"].append(1.0) | |
| else: | |
| features["trust_score"].append(1.0) | |
| features["quality_score"].append(1.0) | |
| return {name: torch.tensor(values, dtype=torch.float32) for name, values in features.items()} | |
| def subgraph(self, node_ids: list[str]) -> "RoleGraph": | |
| """Build a subgraph containing only the selected nodes and their connections.""" | |
| agents = [a for a in self.agents if _get_agent_id(a) in node_ids] | |
| id_set = set(node_ids) | |
| new_graph = rx.PyDiGraph() | |
| idx_map = {} | |
| for agent in agents: | |
| agent_id = _get_agent_id(agent) | |
| if agent_id is None: | |
| continue | |
| old_idx = self.get_node_index(agent_id) | |
| if old_idx is not None: | |
| node_data = self.graph.get_node_data(old_idx) | |
| new_idx = new_graph.add_node(node_data) | |
| idx_map[old_idx] = new_idx | |
| for eid in self.graph.edge_indices(): | |
| s, t = self.graph.get_edge_endpoints_by_index(eid) | |
| if s in idx_map and t in idx_map: | |
| edge_data = self.graph.get_edge_data_by_index(eid) | |
| new_graph.add_edge(idx_map[s], idx_map[t], edge_data) | |
| indices = [self.node_ids.index(nid) for nid in node_ids if nid in self.node_ids] | |
| if indices and self.A_com.numel() > 0: | |
| indices_tensor = torch.tensor(indices) | |
| new_a = self.A_com[indices_tensor][:, indices_tensor] | |
| else: | |
| new_a = torch.zeros((len(agents), len(agents)), dtype=torch.float32) | |
| new_connections = {k: [v for v in vs if v in id_set] for k, vs in self.role_connections.items() if k in id_set} | |
| node_ids_raw = [_get_agent_id(a) for a in agents] | |
| node_ids_filtered = [nid for nid in node_ids_raw if nid is not None] | |
| return RoleGraph( | |
| agents=agents, | |
| node_ids=node_ids_filtered, | |
| role_connections=new_connections, | |
| task_node=self.task_node if self.task_node in id_set else None, | |
| query=self.query, | |
| answer=self.answer, | |
| graph=new_graph, | |
| A_com=new_a, | |
| start_node=self.start_node if self.start_node in id_set else None, | |
| end_node=self.end_node if self.end_node in id_set else None, | |
| ) | |
| def set_start_node(self, node_id: str) -> bool: | |
| """ | |
| Set the start node for execution. | |
| Args: | |
| node_id: ID of the node from which execution starts. | |
| Returns: | |
| True if the node exists and was set. | |
| """ | |
| if node_id not in self.node_ids: | |
| return False | |
| object.__setattr__(self, "start_node", node_id) | |
| return True | |
| def set_end_node(self, node_id: str) -> bool: | |
| """ | |
| Set the end node for execution. | |
| Args: | |
| node_id: ID of the node at which execution ends. | |
| Returns: | |
| True if the node exists and was set. | |
| """ | |
| if node_id not in self.node_ids: | |
| return False | |
| object.__setattr__(self, "end_node", node_id) | |
| return True | |
| def set_execution_bounds(self, start_node: str | None, end_node: str | None) -> bool: | |
| """ | |
| Set start and end nodes simultaneously. | |
| Args: | |
| start_node: ID of the start node (None for auto-detection). | |
| end_node: ID of the end node (None for auto-detection). | |
| Returns: | |
| True if both nodes are valid (or None). | |
| """ | |
| if start_node is not None and start_node not in self.node_ids: | |
| return False | |
| if end_node is not None and end_node not in self.node_ids: | |
| return False | |
| object.__setattr__(self, "start_node", start_node) | |
| object.__setattr__(self, "end_node", end_node) | |
| return True | |
| # ========================================================================= | |
| # INACTIVE NODES (disabled nodes) | |
| # ========================================================================= | |
| def disable(self, node_ids: str | list[str]) -> int: | |
| """ | |
| Deactivate nodes — they remain in the graph but will not be executed. | |
| Args: | |
| node_ids: Node ID or list of node IDs to deactivate. | |
| Returns: | |
| Number of successfully deactivated nodes. | |
| Example: | |
| graph.disable("agent1") # Single node | |
| graph.disable(["a1", "a2", "a3"]) # Multiple nodes | |
| """ | |
| if isinstance(node_ids, str): | |
| node_ids = [node_ids] | |
| count = 0 | |
| for node_id in node_ids: | |
| if node_id in self.node_ids: | |
| self.disabled_nodes.add(node_id) | |
| count += 1 | |
| return count | |
| def enable(self, node_ids: str | list[str] | None = None) -> int: | |
| """ | |
| Activate nodes. | |
| Args: | |
| node_ids: Node ID, list of node IDs, or None to activate all. | |
| Returns: | |
| Number of activated nodes. | |
| Example: | |
| graph.enable("agent1") # Single node | |
| graph.enable(["a1", "a2"]) # Multiple nodes | |
| graph.enable() # All nodes | |
| """ | |
| if node_ids is None: | |
| count = len(self.disabled_nodes) | |
| self.disabled_nodes.clear() | |
| return count | |
| if isinstance(node_ids, str): | |
| node_ids = [node_ids] | |
| count = 0 | |
| for node_id in node_ids: | |
| if node_id in self.disabled_nodes: | |
| self.disabled_nodes.remove(node_id) | |
| count += 1 | |
| return count | |
| def is_enabled(self, node_id: str) -> bool: | |
| """Check whether a node is active.""" | |
| return node_id in self.node_ids and node_id not in self.disabled_nodes | |
| def get_enabled(self) -> list[str]: | |
| """Get the list of active nodes.""" | |
| return [nid for nid in self.node_ids if nid not in self.disabled_nodes] | |
| def get_disabled(self) -> list[str]: | |
| """Get the list of deactivated nodes.""" | |
| return list(self.disabled_nodes) | |
| def get_reachable_from(self, source_id: str, threshold: float = EDGE_THRESHOLD) -> set[str]: | |
| """ | |
| Get all nodes reachable from source_id (forward BFS). | |
| Args: | |
| source_id: ID of the start node. | |
| threshold: Minimum edge weight to consider a connection. | |
| Returns: | |
| Set of reachable node IDs (including source_id). | |
| """ | |
| if source_id not in self.node_ids: | |
| return set() | |
| reachable = {source_id} | |
| queue = deque([source_id]) | |
| while queue: | |
| current = queue.popleft() | |
| current_idx = self.node_ids.index(current) | |
| for j, node_id in enumerate(self.node_ids): | |
| if node_id in reachable: | |
| continue | |
| if self.A_com.numel() > 0 and self.A_com[current_idx, j].item() > threshold: | |
| reachable.add(node_id) | |
| queue.append(node_id) | |
| return reachable | |
| def get_nodes_reaching(self, target_id: str, threshold: float = EDGE_THRESHOLD) -> set[str]: | |
| """ | |
| Get all nodes from which target_id is reachable (backward BFS). | |
| Args: | |
| target_id: ID of the target node. | |
| threshold: Minimum edge weight to consider a connection. | |
| Returns: | |
| Set of node IDs from which target_id is reachable (including target_id itself). | |
| """ | |
| if target_id not in self.node_ids: | |
| return set() | |
| reaching = {target_id} | |
| queue = deque([target_id]) | |
| while queue: | |
| current = queue.popleft() | |
| current_idx = self.node_ids.index(current) | |
| for i, node_id in enumerate(self.node_ids): | |
| if node_id in reaching: | |
| continue | |
| if self.A_com.numel() > 0 and self.A_com[i, current_idx].item() > threshold: | |
| reaching.add(node_id) | |
| queue.append(node_id) | |
| return reaching | |
| def get_relevant_nodes( | |
| self, | |
| start_node: str | None = None, | |
| end_node: str | None = None, | |
| threshold: float = EDGE_THRESHOLD, | |
| ) -> set[str]: | |
| """ | |
| Get nodes that lie on paths from start to end. | |
| This is the intersection of: | |
| - Nodes reachable from start_node | |
| - Nodes from which end_node is reachable | |
| Nodes not in this set are isolated and not needed for execution. | |
| Args: | |
| start_node: ID of the start node (or self.start_node, or the first by order). | |
| end_node: ID of the end node (or self.end_node, or the last by order). | |
| threshold: Minimum edge weight. | |
| Returns: | |
| Set of relevant node IDs. | |
| """ | |
| # Determine start | |
| effective_start = start_node or self.start_node | |
| if effective_start is None and self.node_ids: | |
| # First node with no incoming edges | |
| for node_id in self.node_ids: | |
| idx = self.node_ids.index(node_id) | |
| if self.A_com.numel() > 0: | |
| in_degree = (self.A_com[:, idx] > threshold).sum().item() | |
| if in_degree == 0: | |
| effective_start = node_id | |
| break | |
| if effective_start is None: | |
| effective_start = self.node_ids[0] | |
| # Determine end | |
| effective_end = end_node or self.end_node | |
| if effective_end is None and self.node_ids: | |
| # Last node with no outgoing edges | |
| for node_id in reversed(self.node_ids): | |
| idx = self.node_ids.index(node_id) | |
| if self.A_com.numel() > 0: | |
| out_degree = (self.A_com[idx, :] > threshold).sum().item() | |
| if out_degree == 0: | |
| effective_end = node_id | |
| break | |
| if effective_end is None: | |
| effective_end = self.node_ids[-1] | |
| if effective_start is None or effective_end is None: | |
| return set() | |
| # Intersection of nodes reachable from start and leading to end | |
| reachable_from_start = self.get_reachable_from(effective_start, threshold) | |
| reaching_end = self.get_nodes_reaching(effective_end, threshold) | |
| return reachable_from_start & reaching_end | |
| def get_isolated_nodes( | |
| self, | |
| start_node: str | None = None, | |
| end_node: str | None = None, | |
| threshold: float = EDGE_THRESHOLD, | |
| ) -> set[str]: | |
| """ | |
| Get isolated nodes that do not participate in the start->end path. | |
| These nodes can be excluded from execution to save tokens. | |
| Args: | |
| start_node: ID of the start node. | |
| end_node: ID of the end node. | |
| threshold: Minimum edge weight. | |
| Returns: | |
| Set of isolated node IDs. | |
| """ | |
| relevant = self.get_relevant_nodes(start_node, end_node, threshold) | |
| all_nodes = set(self.node_ids) | |
| return all_nodes - relevant | |
| def get_optimized_execution_order( | |
| self, | |
| start_node: str | None = None, | |
| end_node: str | None = None, | |
| threshold: float = EDGE_THRESHOLD, | |
| ) -> list[str]: | |
| """ | |
| Get the optimised execution order, excluding isolated nodes. | |
| Args: | |
| start_node: ID of the start node. | |
| end_node: ID of the end node. | |
| threshold: Minimum edge weight. | |
| Returns: | |
| List of node IDs in topological order (relevant nodes only). | |
| """ | |
| relevant = self.get_relevant_nodes(start_node, end_node, threshold) | |
| # Topological sort of relevant nodes only | |
| # Build in-degree for relevant nodes | |
| in_degree: dict[str, int] = dict.fromkeys(relevant, 0) | |
| for i, src in enumerate(self.node_ids): | |
| if src not in relevant: | |
| continue | |
| for j, tgt in enumerate(self.node_ids): | |
| if tgt not in relevant: | |
| continue | |
| if self.A_com.numel() > 0 and self.A_com[i, j].item() > threshold: | |
| in_degree[tgt] += 1 | |
| # Kahn's algorithm | |
| queue = deque([node_id for node_id in relevant if in_degree[node_id] == 0]) | |
| result: list[str] = [] | |
| while queue: | |
| current = queue.popleft() | |
| result.append(current) | |
| current_idx = self.node_ids.index(current) | |
| for j, tgt in enumerate(self.node_ids): | |
| if tgt not in relevant or tgt in result: | |
| continue | |
| if self.A_com.numel() > 0 and self.A_com[current_idx, j].item() > threshold: | |
| in_degree[tgt] -= 1 | |
| if in_degree[tgt] == 0: | |
| queue.append(tgt) | |
| # Add remaining nodes (in case of cycles) | |
| for node_id in relevant: | |
| if node_id not in result: | |
| result.append(node_id) | |
| return result | |
| # ========================================================================= | |
| # DATA VALIDATION (input/output schema validation) | |
| # ========================================================================= | |
| def get_agent_schema(self, agent_id: str) -> Any | None: | |
| """ | |
| Get the agent schema from the node data. | |
| Args: | |
| agent_id: Agent ID. | |
| Returns: | |
| AgentNodeSchema or None if not found. | |
| """ | |
| idx = self.get_node_index(agent_id) | |
| if idx is None: | |
| return None | |
| data = self.graph.get_node_data(idx) | |
| if not isinstance(data, dict): | |
| return None | |
| schema_dict = data.get("schema") | |
| if schema_dict is None: | |
| return None | |
| # Restore AgentNodeSchema | |
| from core.schema import AgentNodeSchema, NodeType | |
| if schema_dict.get("type") == NodeType.AGENT.value or schema_dict.get("type") == "agent": | |
| return AgentNodeSchema.model_validate(schema_dict) | |
| return None | |
| def validate_agent_input( | |
| self, | |
| agent_id: str, | |
| data: dict[str, Any] | str, | |
| ) -> Any: | |
| """ | |
| Validate input data for an agent against its input_schema. | |
| Args: | |
| agent_id: Agent ID. | |
| data: Data to validate (dict or JSON string). | |
| Returns: | |
| SchemaValidationResult with the validation result. | |
| Example: | |
| result = graph.validate_agent_input("solver", {"question": "2+2=?"}) | |
| if not result.valid: | |
| print(f"Validation failed: {result.errors}") | |
| """ | |
| from core.schema import SchemaValidationResult | |
| schema = self.get_agent_schema(agent_id) | |
| if schema is None: | |
| return SchemaValidationResult( | |
| valid=True, | |
| schema_type="input", | |
| message=f"No schema found for agent '{agent_id}'", | |
| ) | |
| return schema.validate_input(data) | |
| def validate_agent_output( | |
| self, | |
| agent_id: str, | |
| data: dict[str, Any] | str, | |
| ) -> Any: | |
| """ | |
| Validate output data for an agent against its output_schema. | |
| Args: | |
| agent_id: Agent ID. | |
| data: Data to validate (dict or JSON string). | |
| Returns: | |
| SchemaValidationResult with the validation result. | |
| Example: | |
| result = graph.validate_agent_output("solver", response) | |
| if result.valid: | |
| parsed = result.validated_data | |
| """ | |
| from core.schema import SchemaValidationResult | |
| schema = self.get_agent_schema(agent_id) | |
| if schema is None: | |
| return SchemaValidationResult( | |
| valid=True, | |
| schema_type="output", | |
| message=f"No schema found for agent '{agent_id}'", | |
| ) | |
| return schema.validate_output(data) | |
| def has_input_schema(self, agent_id: str) -> bool: | |
| """Check whether the agent has an input_schema.""" | |
| schema = self.get_agent_schema(agent_id) | |
| return schema is not None and schema.has_input_schema() | |
| def has_output_schema(self, agent_id: str) -> bool: | |
| """Check whether the agent has an output_schema.""" | |
| schema = self.get_agent_schema(agent_id) | |
| return schema is not None and schema.has_output_schema() | |
| def get_input_schema_json(self, agent_id: str) -> dict[str, Any] | None: | |
| """ | |
| Get the JSON Schema for the agent's input data. | |
| Useful for generating prompts describing the expected format. | |
| """ | |
| schema = self.get_agent_schema(agent_id) | |
| if schema is None: | |
| return None | |
| return schema.input_schema_json | |
| def get_output_schema_json(self, agent_id: str) -> dict[str, Any] | None: | |
| """ | |
| Get the JSON Schema for the agent's output data. | |
| Useful for generating prompts describing the expected response format. | |
| """ | |
| schema = self.get_agent_schema(agent_id) | |
| if schema is None: | |
| return None | |
| return schema.output_schema_json | |