File size: 12,081 Bytes
7b2787b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
"""
Graph Definition for Workflow Engine.

The Graph is the core structure that defines the workflow - nodes, edges,
conditional routing, and execution flow.
"""

from typing import Any, Callable, Dict, List, Optional, Set, Union
from dataclasses import dataclass, field
from enum import Enum
import uuid

from app.engine.node import Node, NodeType, get_registered_node, create_node_from_function


# Special node names.
# END is a sentinel target: a direct edge or a conditional route may point at
# END instead of a real node to mark where the workflow stops. Graph accepts it
# as an edge target without requiring a node of that name, and reachability
# analysis stops when it reaches END.
# START is the matching reserved name; presumably it marks the beginning of a
# workflow for callers/executors — TODO confirm its consumers.
END = "__END__"
START = "__START__"


class EdgeType(str, Enum):
    """Kinds of connection a graph edge can represent; members compare equal to plain strings."""

    # Unconditional: the target is always the next node to run.
    DIRECT = "direct"
    # Branching: the next node is chosen by a condition evaluated on state.
    CONDITIONAL = "conditional"


@dataclass
class Edge:
    """
    A plain connection from one named node to another.

    Attributes:
        source: Name of the node the edge leaves from.
        target: Name of the node the edge leads to.
        edge_type: Kind of edge; defaults to ``EdgeType.DIRECT``.
    """
    source: str
    target: str
    edge_type: EdgeType = EdgeType.DIRECT

    def to_dict(self) -> Dict[str, str]:
        """Return a plain-dict form with keys ``source``, ``target`` and ``type`` (the enum's string value)."""
        return dict(
            source=self.source,
            target=self.target,
            type=self.edge_type.value,
        )


@dataclass
class ConditionalEdge:
    """
    A conditional edge that routes to different nodes based on a condition.

    The condition function receives the current state data and returns a
    route key; ``routes`` maps each route key to a target node name.

    Attributes:
        source: Name of the node this edge leaves from.
        condition: Callable taking the state dict and returning a route key.
        routes: Mapping of route key -> target node name. A private copy is
            taken at construction time, so later changes to the caller's dict
            do not alter this edge's routing.
    """
    source: str
    condition: Callable[[Dict[str, Any]], str]
    routes: Dict[str, str]  # route_key -> target_node_name

    def __post_init__(self) -> None:
        # Own the routing table. Graph validates route targets once when the
        # edge is added; aliasing the caller's dict would let later edits to
        # that dict slip past that validation.
        self.routes = dict(self.routes)

    def evaluate(self, state_data: Dict[str, Any]) -> str:
        """
        Evaluate the condition and return the target node name.

        Args:
            state_data: Current workflow state, passed unchanged to ``condition``.

        Returns:
            The target node name mapped to the route key returned by ``condition``.

        Raises:
            ValueError: If ``condition`` returns a key that is not in ``routes``.
            Any exception raised by ``condition`` itself propagates unchanged.
        """
        route_key = self.condition(state_data)
        try:
            return self.routes[route_key]
        except KeyError:
            raise ValueError(
                f"Condition returned unknown route '{route_key}'. "
                f"Available routes: {list(self.routes.keys())}"
            ) from None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the edge to a dictionary.

        The condition callable is represented by its ``__name__`` when it has
        one, otherwise by ``str(condition)``. ``routes`` is returned as a copy
        so mutating the result cannot change the live edge.
        """
        condition = self.condition
        if hasattr(condition, "__name__"):
            condition_name = condition.__name__
        else:
            condition_name = str(condition)
        return {
            "source": self.source,
            "condition": condition_name,
            "routes": dict(self.routes),
        }


@dataclass
class Graph:
    """
    A workflow graph consisting of nodes and edges.

    The graph defines the structure of a workflow:
    - Nodes: Processing units that transform state
    - Edges: Connections between nodes
    - Conditional Edges: Branching logic based on state

    Each node has at most one outgoing connection: either one direct edge
    (in ``edges``) or one conditional edge (in ``conditional_edges``), never
    both. A node with neither is an implicit end of the workflow (see
    ``get_next_node``).

    Attributes:
        graph_id: Unique identifier for this graph (a random UUID4 string by default)
        name: Human-readable name
        nodes: Dict of node_name -> Node
        edges: Dict of source node name -> target node name (or END) for direct edges
        conditional_edges: Dict of source_node -> ConditionalEdge
        entry_point: Name of the first node to execute
        max_iterations: Maximum loop iterations allowed; stored and serialized
            here, enforcement presumably happens in the executor
        description: Human-readable description
        metadata: Arbitrary extra data, included in ``to_dict``
    """

    graph_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Unnamed Workflow"
    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: Dict[str, str] = field(default_factory=dict)  # source -> target for direct edges
    conditional_edges: Dict[str, ConditionalEdge] = field(default_factory=dict)
    entry_point: Optional[str] = None
    max_iterations: int = 100
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_node(
        self,
        name: str,
        handler: Optional[Callable] = None,
        node_type: NodeType = NodeType.STANDARD,
        description: str = ""
    ) -> "Graph":
        """
        Add a node to the graph.

        If handler is not provided, the handler registered under the same
        name (via ``get_registered_node``) is used.

        The first node added becomes the entry point; a later node added with
        ``NodeType.ENTRY`` replaces it.

        Args:
            name: Unique name for the node; must not be END or START
            handler: Function to execute (optional if registered)
            node_type: Type of node
            description: Human-readable description

        Returns:
            Self for chaining

        Raises:
            ValueError: If ``name`` is reserved or already used, or if no
                handler was given and none is registered under ``name``.
        """
        # END is the termination sentinel in edges/routes and is skipped by
        # reachability analysis, so a real node with that name could never be
        # routed to distinctly. START is reserved alongside it.
        if name in (END, START):
            raise ValueError(f"Node name '{name}' is reserved and cannot be used")

        # Check for duplicates before the registry lookup so the error reports
        # the actual problem.
        if name in self.nodes:
            raise ValueError(f"Node '{name}' already exists in the graph")

        if handler is None:
            # Try to find a registered handler
            handler = get_registered_node(name)
            if handler is None:
                raise ValueError(
                    f"No handler provided for node '{name}' and no registered "
                    f"node found with that name"
                )

        node = create_node_from_function(handler, name, node_type, description)
        self.nodes[name] = node

        # Set as entry point if it's the first node or marked as entry
        if self.entry_point is None or node_type == NodeType.ENTRY:
            self.entry_point = name

        return self

    def add_edge(self, source: str, target: str) -> "Graph":
        """
        Add a direct edge from source to target.

        A later direct edge from the same source replaces the earlier one.

        Args:
            source: Source node name
            target: Target node name (or END)

        Returns:
            Self for chaining

        Raises:
            ValueError: If either node is unknown, or ``source`` already has a
                conditional edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")
        if target != END and target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found in graph")

        # A node may have a direct edge or a conditional edge, not both.
        if source in self.conditional_edges:
            raise ValueError(
                f"Node '{source}' already has a conditional edge. "
                f"Cannot add a direct edge."
            )

        self.edges[source] = target
        return self

    def add_conditional_edge(
        self,
        source: str,
        condition: Callable[[Dict[str, Any]], str],
        routes: Dict[str, str]
    ) -> "Graph":
        """
        Add a conditional edge from source node.

        The condition function receives state and returns a route key, which
        is looked up in ``routes`` to find the next node.

        Args:
            source: Source node name
            condition: Function that returns route key
            routes: Dict mapping route keys to target nodes (or END); must be
                non-empty. The dict is copied, so later changes by the caller
                do not affect the graph.

        Returns:
            Self for chaining

        Raises:
            ValueError: If ``source`` or any route target is unknown, ``routes``
                is empty, or ``source`` already has a direct edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")

        # An empty routing table can never be satisfied at evaluation time.
        if not routes:
            raise ValueError(
                f"Conditional edge from '{source}' must define at least one route"
            )

        # Validate all targets
        for route_key, target in routes.items():
            if target != END and target not in self.nodes:
                raise ValueError(
                    f"Target node '{target}' for route '{route_key}' not found in graph"
                )

        # A node may have a direct edge or a conditional edge, not both.
        if source in self.edges:
            raise ValueError(
                f"Node '{source}' already has a direct edge. "
                f"Cannot add a conditional edge."
            )

        # Copy so the validated routing table cannot be mutated behind our back.
        self.conditional_edges[source] = ConditionalEdge(
            source=source,
            condition=condition,
            routes=dict(routes)
        )
        return self

    def set_entry_point(self, node_name: str) -> "Graph":
        """
        Set the entry point of the graph.

        Raises:
            ValueError: If ``node_name`` is not a node in the graph.
        """
        if node_name not in self.nodes:
            raise ValueError(f"Node '{node_name}' not found in graph")
        self.entry_point = node_name
        return self

    def get_next_node(self, current_node: str, state_data: Dict[str, Any]) -> Optional[str]:
        """
        Get the next node to execute based on edges and state.

        Args:
            current_node: Current node name
            state_data: Current state data

        Returns:
            Next node name, END, or None if no edge defined

        Raises:
            ValueError: If a conditional edge's condition returns an unknown
                route key.
        """
        # Conditional edges take precedence over direct edges.
        if current_node in self.conditional_edges:
            return self.conditional_edges[current_node].evaluate(state_data)

        if current_node in self.edges:
            return self.edges[current_node]

        # No edge defined - implicit end
        return None

    def validate(self) -> List[str]:
        """
        Validate the graph structure.

        Checks that the graph has nodes and a valid entry point, that every
        edge references existing nodes (END is allowed as a target), that no
        node has both kinds of edge, and that every node is reachable from
        the entry point.

        Returns:
            List of validation errors (empty if valid)
        """
        errors: List[str] = []

        if not self.nodes:
            errors.append("Graph must have at least one node")
            return errors

        entry_ok = False
        if not self.entry_point:
            errors.append("Graph must have an entry point")
        elif self.entry_point not in self.nodes:
            errors.append(f"Entry point '{self.entry_point}' not found in nodes")
        else:
            entry_ok = True

        # add_edge/add_conditional_edge validate on insertion, but these
        # dataclass fields are public and may be populated directly, so
        # re-check for dangling references here.
        for source, target in self.edges.items():
            if source not in self.nodes:
                errors.append(f"Edge source '{source}' not found in nodes")
            if target != END and target not in self.nodes:
                errors.append(
                    f"Edge target '{target}' (from '{source}') not found in nodes"
                )

        for source, cond in self.conditional_edges.items():
            if source not in self.nodes:
                errors.append(f"Conditional edge source '{source}' not found in nodes")
            if not cond.routes:
                errors.append(f"Conditional edge from '{source}' has no routes")
            for route_key, target in cond.routes.items():
                if target != END and target not in self.nodes:
                    errors.append(
                        f"Target '{target}' for route '{route_key}' "
                        f"(from '{source}') not found in nodes"
                    )
            if source in self.edges:
                errors.append(
                    f"Node '{source}' has both a direct and a conditional edge"
                )

        # Reachability is only meaningful from a valid entry point; without
        # one every node would be reported as an orphan, burying the real error.
        if entry_ok:
            orphans = set(self.nodes) - self._get_reachable_nodes()
            if orphans:
                errors.append(f"Orphan nodes (not reachable): {orphans}")

        # Nodes without outgoing edges are implicit end nodes and are valid.
        return errors

    def _get_reachable_nodes(self) -> Set[str]:
        """Return all node names reachable from the entry point (END excluded)."""
        if not self.entry_point:
            return set()

        reachable: Set[str] = set()
        to_visit = [self.entry_point]

        # Iterative depth-first traversal over direct and conditional edges.
        while to_visit:
            node = to_visit.pop()
            if node in reachable or node == END:
                continue

            reachable.add(node)

            if node in self.edges:
                to_visit.append(self.edges[node])

            if node in self.conditional_edges:
                to_visit.extend(self.conditional_edges[node].routes.values())

        return reachable

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the graph to a dictionary.

        ``edges`` and ``metadata`` are shallow copies, so mutating the result
        does not change the live graph.
        """
        return {
            "graph_id": self.graph_id,
            "name": self.name,
            "description": self.description,
            "nodes": {name: node.to_dict() for name, node in self.nodes.items()},
            "edges": dict(self.edges),
            "conditional_edges": {
                name: edge.to_dict()
                for name, edge in self.conditional_edges.items()
            },
            "entry_point": self.entry_point,
            "max_iterations": self.max_iterations,
            "metadata": dict(self.metadata),
        }

    def to_mermaid(self) -> str:
        """
        Generate a Mermaid flowchart (``graph TD``) of the graph.

        ENTRY/EXIT node types get a rocket/chequered-flag emoji in their
        label, the END sentinel is drawn as a circle when any edge targets it,
        and conditional edges are labelled with their route key.
        """
        # Emoji are written as escapes so the source stays ASCII-only and
        # cannot be mangled by a wrong-codec read/write.
        entry_marker = "\U0001F680"  # rocket
        exit_marker = "\U0001F3C1"   # chequered flag
        lines = ["graph TD"]

        for name, node in self.nodes.items():
            label = name.replace("_", " ").title()
            if node.node_type == NodeType.ENTRY:
                label = f"{label} {entry_marker}"
            elif node.node_type == NodeType.EXIT:
                label = f"{label} {exit_marker}"
            lines.append(f'    {name}["{label}"]')

        # Only draw the END node if some edge or route actually targets it.
        has_end = END in self.edges.values() or any(
            END in cond.routes.values() for cond in self.conditional_edges.values()
        )
        if has_end:
            lines.append(f'    {END}(("END"))')

        for source, target in self.edges.items():
            lines.append(f"    {source} --> {target}")

        for source, cond in self.conditional_edges.items():
            for route_key, target in cond.routes.items():
                lines.append(f"    {source} -->|{route_key}| {target}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return (
            f"Graph(name='{self.name}', nodes={list(self.nodes.keys())}, "
            f"entry='{self.entry_point}')"
        )