File size: 5,302 Bytes
7b2787b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""
State Management for Workflow Engine.
This module provides the state management system that flows through the workflow.
State is immutable - each node receives state and returns a new modified state.
"""
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
class StateSnapshot(BaseModel):
    """
    Point-in-time record of workflow state during execution.

    Attributes:
        timestamp: When the snapshot object was created (``datetime.now()``,
            i.e. a naive local time).
        node_name: Name of the node this snapshot is associated with.
        state_data: The state data captured at that point.
        iteration: Iteration counter value at capture time (defaults to 0).
    """

    # Populated automatically at construction time.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Required: the node this snapshot belongs to.
    node_name: str
    # Required: the captured data payload.
    state_data: Dict[str, Any]
    # Optional: iteration number, 0 when not supplied.
    iteration: int = Field(default=0)
class WorkflowState(BaseModel):
    """
    The shared state that flows through the workflow.

    This is a flexible container that holds all data being processed by the
    workflow nodes. Nodes read values with ``get`` and "write" with ``set`` /
    ``update``. Those methods, like the other modifiers here, return a new
    state built with ``model_copy`` rather than changing this instance.

    Attributes:
        data: The actual workflow data (flexible dictionary).
        current_node: Name of the node most recently passed to ``mark_visited``.
        iteration: Iteration counter, advanced by ``increment_iteration``.
        visited_nodes: Node names in the order they were marked visited.
        started_at: Start time of the run, if set.
        completed_at: Completion time of the run, if set.
    """

    # Pydantic v2 configuration (this class relies on the v2 ``model_copy``
    # API); replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The actual data being processed
    data: Dict[str, Any] = Field(default_factory=dict)

    # Execution metadata
    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``data[key]``, or *default* when the key is absent."""
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """
        Return a new state whose data has ``key`` set to ``value``.

        The existing data is deep-copied first, so this state is unchanged.
        """
        new_data = deepcopy(self.data)
        new_data[key] = value
        return self.model_copy(update={"data": new_data})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """
        Return a new state whose data is this state's data merged with *updates*.

        Keys in *updates* overwrite existing keys. The existing data is
        deep-copied first, so this state is unchanged.
        """
        new_data = deepcopy(self.data)
        new_data.update(updates)
        return self.model_copy(update={"data": new_data})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """
        Return a new state with *node_name* appended to ``visited_nodes``.

        The new state's ``current_node`` is also set to *node_name*.
        """
        # Build a fresh list rather than appending, so this state's list
        # is not modified.
        new_visited = self.visited_nodes + [node_name]
        return self.model_copy(update={
            "visited_nodes": new_visited,
            "current_node": node_name,
        })

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state with ``iteration`` increased by one."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert state to a plain dictionary.

        ``data`` is deep-copied and ``visited_nodes`` is copied, so mutating
        the returned dictionary cannot alter this state. Timestamps are
        rendered with ``isoformat()``, or None when unset.
        """
        return {
            "data": deepcopy(self.data),
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": list(self.visited_nodes),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """
        Create a WorkflowState from a dictionary.

        Accepts either a serialized state (as produced by ``to_dict``) or a
        plain dictionary of workflow data, which is wrapped as ``data``.

        A dictionary counts as a serialized state only when its ``"data"``
        entry is itself a dict and every key is a WorkflowState field.
        Otherwise it is treated as raw workflow data that merely contains a
        ``"data"`` key. It is wrapped whole, rather than having its other keys
        silently dropped or failing validation when ``"data"`` is not a dict.
        """
        looks_serialized = (
            isinstance(data.get("data"), dict)
            and set(data).issubset(cls.model_fields)
        )
        if looks_serialized:
            return cls(**data)
        # Raw workflow data: wrap it as the state's payload.
        return cls(data=data)
class StateManager:
    """
    Manages state history and snapshots for a workflow run.

    This provides debugging capabilities by tracking state changes
    throughout the workflow execution. Every call to ``update`` appends a
    ``StateSnapshot`` holding a deep copy of the new state's data.
    """

    def __init__(self, run_id: Optional[str] = None):
        """
        Args:
            run_id: Identifier for this run. When omitted (or empty), a random
                UUID4 string is generated.
        """
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List["StateSnapshot"] = []
        self._current_state: Optional["WorkflowState"] = None

    @property
    def current_state(self) -> Optional["WorkflowState"]:
        """The most recent state, or None before ``initialize`` is called."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> "WorkflowState":
        """
        Create the starting state from *initial_data*, stamped with a start time.

        Returns:
            The new state, which also becomes ``current_state``.
        """
        self._current_state = WorkflowState(
            data=initial_data,
            started_at=datetime.now(),
        )
        return self._current_state

    def update(self, new_state: "WorkflowState", node_name: str) -> None:
        """
        Make *new_state* current and record a snapshot of it.

        Args:
            new_state: The state to record and make current.
            node_name: Node name stored on the snapshot.
        """
        # Deep-copy so later changes to new_state.data cannot rewrite history.
        snapshot = StateSnapshot(
            node_name=node_name,
            state_data=deepcopy(new_state.data),
            iteration=new_state.iteration,
        )
        self.history.append(snapshot)
        self._current_state = new_state

    def finalize(self) -> Optional["WorkflowState"]:
        """
        Stamp the current state with a completion time and return it.

        Returns:
            The updated current state, or None if ``initialize`` was never called.
        """
        if self._current_state is not None:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """
        Return the recorded snapshots as plain dictionaries, in recording order.

        Each entry has ``timestamp`` (ISO-8601 string), ``node``, ``iteration``
        and ``state``. ``state`` is a deep copy, so callers may modify the
        result without corrupting the recorded history.
        """
        return [
            {
                "timestamp": snap.timestamp.isoformat(),
                "node": snap.node_name,
                "iteration": snap.iteration,
                "state": deepcopy(snap.state_data),
            }
            for snap in self.history
        ]
|