File size: 7,067 Bytes
06ba7ea | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | from __future__ import annotations
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set
from langchain_core.tools.structured import StructuredTool
from src.open_storyline.storage.agent_memory import ArtifactStore
class NodeManager:
    """Registry of tool nodes, indexed by id, kind, priority and prerequisites.

    Each node is a ``StructuredTool`` whose ``tool.metadata['_meta']`` dict
    configures it:

    * ``node_id`` (required) -- unique id; tools without one are ignored.
    * ``node_kind`` -- the kind/feature the node belongs to (defaults to
      ``node_id``).
    * ``priority`` -- ordering inside a kind, higher first (defaults to 0).
    * ``next_available_node`` -- ids of nodes executable after this one.
    * ``require_prior_kind`` -- kinds required before executing the auto
      method.
    * ``default_require_prior_kind`` -- kinds required before executing the
      default method.
    """

    def __init__(self, tools: Optional[List[StructuredTool]] = None):
        # node_kind -> node_ids of that kind, sorted by descending priority
        self.kind_to_node_ids: Dict[str, List[str]] = defaultdict(list)
        # node_id -> StructuredTool
        self.id_to_tool: Dict[str, StructuredTool] = {}
        # node_id -> ids of nodes executable after it
        self.id_to_next: Dict[str, List[str]] = {}
        # node_id -> priority
        self.id_to_priority: Dict[str, int] = {}
        # node_id -> node_kind
        self.id_to_kind: Dict[str, str] = {}
        # node_id -> kinds required before executing the auto method
        self.id_to_require_prior_kind: Dict[str, List[str]] = {}
        # node_id -> kinds required before executing the default method
        self.id_to_default_require_prior_kind: Dict[str, List[str]] = {}
        # Reverse indexes: kind -> ids of nodes that depend on that kind
        # (for the auto method and the default method respectively).
        self.kind_to_dependent_nodes: Dict[str, Set[str]] = defaultdict(set)
        self.kind_to_default_dependent_nodes: Dict[str, Set[str]] = defaultdict(set)
        if tools:
            self._build(tools)

    def _build(self, tools: List[StructuredTool]) -> None:
        """Register every tool that carries a ``node_id``; others are skipped."""
        for tool in tools:
            # add_node performs the metadata / node_id validation itself and
            # returns False (without side effects) for unusable tools.
            self.add_node(tool)

    @staticmethod
    def _node_meta(tool: StructuredTool) -> Dict[str, Any]:
        """Return ``tool.metadata['_meta']``, or ``{}`` when missing or empty."""
        metadata = tool.metadata
        if not metadata:
            return {}
        # ``or {}`` also guards against an explicit ``'_meta': None``.
        return metadata.get("_meta") or {}

    @staticmethod
    def _discard_dependent(index: Dict[str, Set[str]], kind: str, node_id: str) -> None:
        """Remove ``node_id`` from ``index[kind]``, deleting the kind once empty."""
        dependents = index.get(kind)
        if dependents is None:
            return
        dependents.discard(node_id)
        if not dependents:
            del index[kind]

    def add_node(self, tool: StructuredTool) -> bool:
        """Register ``tool`` as a node, replacing any node with the same id.

        Args:
            tool: Tool whose ``metadata['_meta']`` describes the node.

        Returns:
            True if the node was registered; False if the tool has no
            metadata or no ``node_id``.
        """
        meta = self._node_meta(tool)
        node_id = meta.get("node_id")
        if not node_id:
            return False

        if node_id in self.id_to_tool:
            # Replacing an existing node: drop its own index entries, but keep
            # other nodes' ``next`` references to it -- the id stays valid.
            self.remove_node(node_id, clean_references=False)

        node_kind = meta.get("node_kind", node_id)
        priority = meta.get("priority", 0)
        # Copy the lists so later in-place edits (e.g. remove_node cleaning
        # references) never write back into the tool's own metadata.
        next_nodes = list(meta.get("next_available_node") or [])
        require_prior_kind = list(meta.get("require_prior_kind") or [])
        default_require_prior_kind = list(meta.get("default_require_prior_kind") or [])

        self.id_to_tool[node_id] = tool
        self.id_to_priority[node_id] = priority
        self.id_to_next[node_id] = next_nodes
        self.id_to_kind[node_id] = node_kind
        self.id_to_require_prior_kind[node_id] = require_prior_kind
        self.id_to_default_require_prior_kind[node_id] = default_require_prior_kind

        # Add to its kind group and keep the group priority-sorted.
        self.kind_to_node_ids[node_kind].append(node_id)
        self._sort_kind(node_kind)

        # Update reverse dependency indexes.
        for kind in require_prior_kind:
            self.kind_to_dependent_nodes[kind].add(node_id)
        for kind in default_require_prior_kind:
            self.kind_to_default_dependent_nodes[kind].add(node_id)
        return True

    def remove_node(self, node_id: str, clean_references: bool = True) -> bool:
        """Delete a node and every index entry that mentions it.

        Also called by ``add_node`` (with ``clean_references=False``) when a
        node id is re-registered.

        Args:
            node_id: ID of the node to delete.
            clean_references: Whether to also drop ``node_id`` from the
                ``next`` lists of the remaining nodes.

        Returns:
            True if the node existed and was removed, False otherwise.
        """
        if node_id not in self.id_to_tool:
            return False

        node_kind = self.id_to_kind.pop(node_id)
        del self.id_to_tool[node_id]
        del self.id_to_priority[node_id]
        del self.id_to_next[node_id]

        # Clean up reverse dependency indexes, dropping kinds left empty.
        for kind in self.id_to_require_prior_kind.pop(node_id, []):
            self._discard_dependent(self.kind_to_dependent_nodes, kind, node_id)
        for kind in self.id_to_default_require_prior_kind.pop(node_id, []):
            self._discard_dependent(self.kind_to_default_dependent_nodes, kind, node_id)

        # Remove from its kind group; drop the kind once it has no nodes.
        # ``.get`` avoids creating an entry in the defaultdict.
        kind_nodes = self.kind_to_node_ids.get(node_kind)
        if kind_nodes is not None:
            if node_id in kind_nodes:
                kind_nodes.remove(node_id)
            if not kind_nodes:
                del self.kind_to_node_ids[node_kind]

        if clean_references:
            for next_ids in self.id_to_next.values():
                if node_id in next_ids:
                    # Drop every occurrence, not just the first one.
                    next_ids[:] = [nid for nid in next_ids if nid != node_id]
        return True

    def _sort_kind(self, kind: str) -> None:
        """Sort the node list of ``kind`` by descending priority (stable)."""
        if kind in self.kind_to_node_ids:
            self.kind_to_node_ids[kind].sort(
                key=lambda nid: self.id_to_priority[nid],
                reverse=True,
            )

    def get_tool(self, node_id: str) -> Optional[StructuredTool]:
        """Return the tool registered under ``node_id``, or None."""
        return self.id_to_tool.get(node_id)

    def check_excutable(self, session_id: str, store: ArtifactStore, all_require_kind: List[str]) -> Dict[str, Any]:
        """Check whether every required kind already has an output in ``store``.

        For each distinct required kind, ``store.get_latest_meta`` is queried
        for every node of that kind, and the output with the greatest
        ``created_at`` is kept.

        Args:
            session_id: Session whose outputs are inspected.
            store: Artifact store providing ``get_latest_meta``.
            all_require_kind: Kinds that must have an output.

        Returns:
            dict with keys:
              "excutable": True when every required kind has an output.
              "collected_node": kind -> most recent output of that kind.
              "missing_kind": required kinds without any output, in the
                order they first appear in ``all_require_kind``.
        """
        collected_output: Dict[str, Any] = {}
        missing_kind: List[str] = []
        # dict.fromkeys de-duplicates while preserving first-seen order, so
        # repeated kinds neither re-query the store nor skew the result.
        for req_kind in dict.fromkeys(all_require_kind):
            valid_outputs = []
            # ``.get`` avoids inserting an empty entry for unknown kinds.
            for node_id in self.kind_to_node_ids.get(req_kind, ()):
                output = store.get_latest_meta(node_id=node_id, session_id=session_id)
                if output is not None:
                    valid_outputs.append(output)
            if valid_outputs:
                collected_output[req_kind] = max(valid_outputs, key=lambda out: out.created_at)
            else:
                missing_kind.append(req_kind)
        return {
            "excutable": not missing_kind,
            "collected_node": collected_output,
            "missing_kind": missing_kind,
        }
|