File size: 5,302 Bytes
7b2787b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
State Management for Workflow Engine.

This module provides the state management system that flows through the workflow.
State is immutable - each node receives state and returns a new modified state.
"""

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
from copy import deepcopy
import uuid


class StateSnapshot(BaseModel):
    """
    Record of the workflow data captured for one node at one moment.

    Attributes:
        timestamp: When the snapshot object was created; filled in with
            ``datetime.now()`` if the caller does not pass a value.
        node_name: Name of the node the snapshot is associated with.
        state_data: The workflow data dictionary stored in the snapshot.
        iteration: Iteration counter value stored with the snapshot
            (defaults to 0).
    """

    # default_factory is called per instance, so each snapshot gets its own
    # creation time instead of one value fixed at class-definition time.
    timestamp: datetime = Field(default_factory=datetime.now)

    # Required fields: there is no sensible default for either of them.
    node_name: str
    state_data: Dict[str, Any]

    iteration: int = Field(default=0)


class WorkflowState(BaseModel):
    """
    The shared state that flows through the workflow.

    This is a flexible container that holds all data being processed
    by the workflow nodes. The mutating-looking helpers (``set``,
    ``update``, ``mark_visited``, ``increment_iteration``) never modify
    the instance; each returns a new ``WorkflowState``.

    Attributes:
        data: The actual workflow data (flexible dictionary).
        current_node: Name of the node most recently passed to
            ``mark_visited``, or ``None`` if none has been marked.
        iteration: Iteration counter, starting at 0.
        visited_nodes: Node names in the order they were marked visited.
        started_at: Start time of the run, if recorded.
        completed_at: Completion time of the run, if recorded.
    """

    # Pydantic v2 style configuration; replaces the deprecated inner
    # ``class Config`` while keeping the same setting.
    model_config = {"arbitrary_types_allowed": True}

    # The actual data being processed
    data: Dict[str, Any] = Field(default_factory=dict)

    # Execution metadata
    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value from the state data.

        Args:
            key: Key to look up in ``data``.
            default: Value returned when ``key`` is absent.

        Returns:
            The stored value (the same object held in ``data``, not a
            copy), or ``default`` when the key is missing.
        """
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """
        Return a new state with ``key`` set to ``value``.

        Equivalent to ``update({key: value})``; this instance is left
        unchanged.
        """
        return self.update({key: value})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """
        Return a new state with every entry of ``updates`` merged into data.

        Args:
            updates: Key/value pairs that overwrite or extend ``data``.

        Returns:
            A copy of this state whose ``data`` is a deep copy of the
            current data with ``updates`` applied on top.
        """
        # Deep-copy first so the new state shares no nested containers
        # with this one.
        merged = deepcopy(self.data)
        merged.update(updates)
        return self.model_copy(update={"data": merged})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """
        Return a new state with ``node_name`` appended to the visit trail.

        The returned state also has ``current_node`` set to ``node_name``.
        The same name may appear more than once in ``visited_nodes``.
        """
        # Build a fresh list rather than appending in place, so this
        # instance's trail is not modified.
        trail = [*self.visited_nodes, node_name]
        return self.model_copy(
            update={"visited_nodes": trail, "current_node": node_name}
        )

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state whose iteration counter is one higher."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert state to a plain dictionary.

        Returns:
            A dictionary with the keys ``data``, ``current_node``,
            ``iteration``, ``visited_nodes``, ``started_at`` and
            ``completed_at``. The two timestamps are ISO-8601 strings, or
            ``None`` when unset. ``data`` and ``visited_nodes`` are copies,
            so changing the returned dictionary cannot alter this state.
        """
        started = self.started_at.isoformat() if self.started_at else None
        completed = self.completed_at.isoformat() if self.completed_at else None
        return {
            # Copies, not the live containers: handing out self.data would
            # let a caller mutate the state through the exported dict.
            "data": deepcopy(self.data),
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": list(self.visited_nodes),
            "started_at": started,
            "completed_at": completed,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """
        Create a WorkflowState from a dictionary.

        Args:
            data: Either a mapping of field values that includes a
                top-level ``"data"`` key, or a raw data payload.

        Returns:
            A state built from the field values when ``"data"`` is
            present; otherwise a state whose ``data`` is the whole input.
        """
        # NOTE(review): the presence of a "data" key is the only signal
        # used to tell the two input shapes apart, so a raw payload that
        # happens to contain a "data" key is read as field values.
        if "data" in data:
            return cls(**data)
        # If it's just raw data, wrap it
        return cls(data=data)


class StateManager:
    """
    Manages state history and snapshots for a workflow run.

    Holds the current ``WorkflowState`` and appends one ``StateSnapshot``
    to ``history`` on every ``update`` call, giving a per-node trail of
    the state data for debugging.

    Attributes:
        run_id: Identifier of this run; a random UUID4 string when the
            caller does not supply one.
        history: Snapshots recorded by ``update``, oldest first.
    """

    def __init__(self, run_id: Optional[str] = None):
        """
        Create a manager with no current state and an empty history.

        Args:
            run_id: Optional identifier for the run. Any falsy value
                (``None`` or an empty string) is replaced by a generated
                UUID4 string.
        """
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List["StateSnapshot"] = []
        self._current_state: Optional["WorkflowState"] = None

    @property
    def current_state(self) -> Optional["WorkflowState"]:
        """Get the current state, or ``None`` before ``initialize``/``update``."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> "WorkflowState":
        """
        Initialize the state manager with initial data.

        Args:
            initial_data: Data dictionary for the first state.

        Returns:
            The newly created state, with ``started_at`` set to the
            current local time. No snapshot is recorded and existing
            history is left as it is.
        """
        self._current_state = WorkflowState(
            data=initial_data,
            started_at=datetime.now(),
        )
        return self._current_state

    def update(self, new_state: "WorkflowState", node_name: str) -> None:
        """
        Update the current state and record a snapshot.

        Args:
            new_state: State that becomes the current state.
            node_name: Name stored on the snapshot for this update.
        """
        # Deep-copy the data so later changes to the state's containers
        # cannot rewrite what the history recorded at this point.
        self.history.append(
            StateSnapshot(
                node_name=node_name,
                state_data=deepcopy(new_state.data),
                iteration=new_state.iteration,
            )
        )
        self._current_state = new_state

    def finalize(self) -> Optional["WorkflowState"]:
        """
        Mark the workflow as complete.

        Returns:
            A copy of the current state with ``completed_at`` set to the
            current local time, which also becomes the current state.
            ``None`` when no state has been set yet.
        """
        if self._current_state is not None:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """
        Get the state history as a list of dictionaries.

        Returns:
            One dictionary per snapshot, oldest first, with the keys
            ``timestamp`` (ISO-8601 string), ``node``, ``iteration`` and
            ``state``. Each ``state`` value is a deep copy, so callers may
            modify the result without affecting the recorded history.
        """
        return [
            {
                "timestamp": snap.timestamp.isoformat(),
                "node": snap.node_name,
                "iteration": snap.iteration,
                "state": deepcopy(snap.state_data),
            }
            for snap in self.history
        ]