|
|
""" |
|
|
Graph Definition for Workflow Engine. |
|
|
|
|
|
The Graph is the core structure that defines a workflow: its nodes, edges,
conditional routing, and execution flow.
|
|
""" |
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Set, Union |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
import uuid |
|
|
|
|
|
from app.engine.node import Node, NodeType, get_registered_node, create_node_from_function |
|
|
|
|
|
|
|
|
|
|
|
# Reserved name; presumably marks the start of a workflow. It is not
# referenced elsewhere in this module.
START = "__START__"

# Reserved pseudo-node used as an edge target to terminate a workflow.
# add_edge / add_conditional_edge accept it in place of a real node name.
END = "__END__"
|
|
|
|
|
|
|
|
class EdgeType(str, Enum):
    """Kinds of connection that can link one node to another.

    Because the enum mixes in ``str``, every member compares equal to its
    raw string value (e.g. ``EdgeType.DIRECT == "direct"``).
    """

    # A fixed, unconditional link from a source node to one target.
    DIRECT = "direct"
    # A branching link whose target depends on workflow state.
    CONDITIONAL = "conditional"
|
|
|
|
|
|
|
|
@dataclass
class Edge:
    """A typed link from a source node to a target node.

    Attributes:
        source: Name of the node the edge leaves from.
        target: Name of the node the edge points to.
        edge_type: Kind of edge; defaults to ``EdgeType.DIRECT``.
    """

    source: str
    target: str
    edge_type: EdgeType = EdgeType.DIRECT

    def to_dict(self) -> Dict[str, str]:
        """Serialize the edge as ``{"source", "target", "type"}``.

        ``type`` holds the enum's string value rather than the member itself.
        """
        return dict(
            source=self.source,
            target=self.target,
            type=self.edge_type.value,
        )
|
|
|
|
|
|
|
|
@dataclass
class ConditionalEdge:
    """
    A conditional edge that routes to different nodes based on a condition.

    The condition function receives the current state and returns a route key.
    The routes dict maps route keys to target node names.

    Attributes:
        source: Name of the node this edge leaves from.
        condition: Callable taking the state dict and returning a route key.
        routes: Mapping of route key -> target node name.
    """

    source: str
    condition: Callable[[Dict[str, Any]], str]
    routes: Dict[str, str]

    def evaluate(self, state_data: Dict[str, Any]) -> str:
        """Evaluate the condition and return the target node name.

        Args:
            state_data: Current workflow state passed to ``condition``.

        Returns:
            The target node name mapped to the returned route key.

        Raises:
            ValueError: If the condition returns a key not present in ``routes``.
        """
        route_key = self.condition(state_data)
        if route_key not in self.routes:
            raise ValueError(
                f"Condition returned unknown route '{route_key}'. "
                f"Available routes: {list(self.routes.keys())}"
            )
        return self.routes[route_key]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the edge for display/export.

        The condition is represented by its ``__name__`` when it has one
        (lambdas give ``"<lambda>"``), otherwise by ``str(condition)``.
        ``routes`` is returned as a copy so callers mutating the serialized
        output cannot silently rewire this edge.
        """
        return {
            "source": self.source,
            "condition": self.condition.__name__ if hasattr(self.condition, '__name__') else str(self.condition),
            "routes": dict(self.routes),
        }
|
|
|
|
|
|
|
|
@dataclass
class Graph:
    """
    A workflow graph consisting of nodes and edges.

    The graph defines the structure of a workflow:
    - Nodes: Processing units that transform state
    - Edges: Direct connections between nodes
    - Conditional Edges: Branching logic based on state

    Each node has at most one outgoing transition: either one direct edge or
    one conditional edge, never both. A node with neither has no successor,
    and ``get_next_node`` returns ``None`` for it.

    Attributes:
        graph_id: Unique identifier for this graph (random UUID4 by default)
        name: Human-readable name
        nodes: Dict of node_name -> Node
        edges: Dict of source_node -> target_node for direct edges
        conditional_edges: Dict of source_node -> ConditionalEdge
        entry_point: Name of the first node to execute
        max_iterations: Maximum loop iterations allowed (stored here; not
            enforced by this class itself)
        description: Free-form description of the workflow
        metadata: Arbitrary extra data carried along with the graph
    """

    graph_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Unnamed Workflow"
    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: Dict[str, str] = field(default_factory=dict)
    conditional_edges: Dict[str, ConditionalEdge] = field(default_factory=dict)
    entry_point: Optional[str] = None
    max_iterations: int = 100
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_node(
        self,
        name: str,
        handler: Optional[Callable] = None,
        node_type: NodeType = NodeType.STANDARD,
        description: str = ""
    ) -> "Graph":
        """
        Add a node to the graph.

        If handler is not provided, attempts to find a registered node
        with the given name.

        The first node added becomes the entry point. A node added with
        ``NodeType.ENTRY`` always becomes the entry point, so the last ENTRY
        node added wins.

        Args:
            name: Unique name for the node
            handler: Function to execute (optional if registered)
            node_type: Type of node
            description: Human-readable description

        Returns:
            Self for chaining

        Raises:
            ValueError: If the name is already used, or no handler was given
                and none is registered under ``name``.
        """
        # Check duplicates first so a clashing name reports the real problem
        # instead of a misleading "no registered node" error.
        if name in self.nodes:
            raise ValueError(f"Node '{name}' already exists in the graph")

        if handler is None:
            # Fall back to the node registry when no handler is supplied.
            handler = get_registered_node(name)
            if handler is None:
                raise ValueError(
                    f"No handler provided for node '{name}' and no registered "
                    f"node found with that name"
                )

        self.nodes[name] = create_node_from_function(handler, name, node_type, description)

        if self.entry_point is None or node_type == NodeType.ENTRY:
            self.entry_point = name

        return self

    def add_edge(self, source: str, target: str) -> "Graph":
        """
        Add a direct edge from source to target.

        Adding a second direct edge from the same source replaces the first.

        Args:
            source: Source node name
            target: Target node name (or END)

        Returns:
            Self for chaining

        Raises:
            ValueError: If either node is unknown, or ``source`` already has
                a conditional edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")
        if target != END and target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found in graph")

        # A node may have one direct edge or one conditional edge, not both.
        if source in self.conditional_edges:
            raise ValueError(
                f"Node '{source}' already has a conditional edge. "
                f"Cannot add a direct edge."
            )

        self.edges[source] = target
        return self

    def add_conditional_edge(
        self,
        source: str,
        condition: Callable[[Dict[str, Any]], str],
        routes: Dict[str, str]
    ) -> "Graph":
        """
        Add a conditional edge from source node.

        The condition function receives state and returns a route key.
        Adding a second conditional edge from the same source replaces the
        first.

        Args:
            source: Source node name
            condition: Function that returns route key
            routes: Dict mapping route keys to target nodes (or END). A copy is
                stored, so later changes to the caller's dict do not affect
                the graph.

        Returns:
            Self for chaining

        Raises:
            ValueError: If the source or any route target is unknown, or
                ``source`` already has a direct edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")

        for route_key, target in routes.items():
            if target != END and target not in self.nodes:
                raise ValueError(
                    f"Target node '{target}' for route '{route_key}' not found in graph"
                )

        # A node may have one direct edge or one conditional edge, not both.
        if source in self.edges:
            raise ValueError(
                f"Node '{source}' already has a direct edge. "
                f"Cannot add a conditional edge."
            )

        self.conditional_edges[source] = ConditionalEdge(
            source=source,
            condition=condition,
            routes=dict(routes),
        )
        return self

    def set_entry_point(self, node_name: str) -> "Graph":
        """Set the entry point of the graph.

        Raises:
            ValueError: If ``node_name`` is not a node in the graph.
        """
        if node_name not in self.nodes:
            raise ValueError(f"Node '{node_name}' not found in graph")
        self.entry_point = node_name
        return self

    def get_next_node(self, current_node: str, state_data: Dict[str, Any]) -> Optional[str]:
        """
        Get the next node to execute based on edges and state.

        Args:
            current_node: Current node name
            state_data: Current state data

        Returns:
            Next node name, END, or None if no edge defined

        Raises:
            ValueError: If a conditional edge's condition returns an unknown
                route key.
        """
        if current_node in self.conditional_edges:
            return self.conditional_edges[current_node].evaluate(state_data)

        if current_node in self.edges:
            return self.edges[current_node]

        return None

    def validate(self) -> List[str]:
        """
        Validate the graph structure.

        Checks that the graph has nodes, has a valid entry point, and that
        every node is reachable from the entry point. A node with no outgoing
        edge is allowed; it has no successor.

        Returns:
            List of validation errors (empty if valid)
        """
        errors: List[str] = []

        if not self.nodes:
            errors.append("Graph must have at least one node")
            return errors

        if not self.entry_point:
            errors.append("Graph must have an entry point")
        elif self.entry_point not in self.nodes:
            errors.append(f"Entry point '{self.entry_point}' not found in nodes")

        orphans = set(self.nodes) - self._get_reachable_nodes()
        if orphans:
            # Sort so the message is stable across runs: set iteration order of
            # str varies with hash randomization. The output keeps the same
            # "{'a', 'b'}" shape a set repr would produce.
            listed = ", ".join(repr(n) for n in sorted(orphans))
            errors.append(f"Orphan nodes (not reachable): {{{listed}}}")

        return errors

    def _get_reachable_nodes(self) -> Set[str]:
        """Get all nodes reachable from the entry point (END excluded)."""
        if not self.entry_point:
            return set()

        reachable: Set[str] = set()
        to_visit = [self.entry_point]

        # Iterative depth-first traversal over both direct and conditional edges.
        while to_visit:
            node = to_visit.pop()
            if node in reachable or node == END:
                continue
            reachable.add(node)

            if node in self.edges:
                to_visit.append(self.edges[node])
            if node in self.conditional_edges:
                to_visit.extend(self.conditional_edges[node].routes.values())

        return reachable

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the graph to a dictionary.

        ``edges`` and ``metadata`` are shallow copies, so mutating the result
        does not alter the graph's own state.
        """
        return {
            "graph_id": self.graph_id,
            "name": self.name,
            "description": self.description,
            "nodes": {name: node.to_dict() for name, node in self.nodes.items()},
            "edges": dict(self.edges),
            "conditional_edges": {
                name: edge.to_dict()
                for name, edge in self.conditional_edges.items()
            },
            "entry_point": self.entry_point,
            "max_iterations": self.max_iterations,
            "metadata": dict(self.metadata),
        }

    def to_mermaid(self) -> str:
        """Generate a Mermaid flowchart (``graph TD``) of the graph.

        ENTRY and EXIT nodes get distinct emoji markers. An END terminal node
        is drawn only when some edge targets it.
        """
        # Emoji are written as escapes so the source stays ASCII-only and
        # cannot be garbled by a wrong file encoding.
        entry_marker = "\U0001F680"  # rocket
        exit_marker = "\U0001F3C1"  # chequered flag

        lines = ["graph TD"]

        for name, node in self.nodes.items():
            label = name.replace("_", " ").title()
            if node.node_type == NodeType.ENTRY:
                lines.append(f' {name}["{label} {entry_marker}"]')
            elif node.node_type == NodeType.EXIT:
                lines.append(f' {name}["{label} {exit_marker}"]')
            else:
                lines.append(f' {name}["{label}"]')

        has_end = END in self.edges.values() or any(
            END in cond.routes.values() for cond in self.conditional_edges.values()
        )
        if has_end:
            lines.append(f' {END}(("END"))')

        for source, target in self.edges.items():
            lines.append(f" {source} --> {target}")

        for source, cond in self.conditional_edges.items():
            for route_key, target in cond.routes.items():
                lines.append(f" {source} -->|{route_key}| {target}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return (
            f"Graph(name='{self.name}', nodes={list(self.nodes.keys())}, "
            f"entry='{self.entry_point}')"
        )
|
|
|