|
|
""" |
|
|
State management for the workflow engine.

This module defines the state object that flows through the workflow, the
snapshots recorded from it, and the manager that tracks both.

State is treated as immutable: each node receives a state and returns a new,
modified state rather than changing the one it was given.
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, List, Optional |
|
|
from pydantic import BaseModel, Field |
|
|
from datetime import datetime |
|
|
from copy import deepcopy |
|
|
import uuid |
|
|
|
|
|
|
|
|
class StateSnapshot(BaseModel):
    """Record of the state data at one specific point in execution.

    Attributes:
        timestamp: When the snapshot object was created.
        node_name: Name of the node the snapshot is attributed to.
        state_data: The state data recorded for this snapshot.
        iteration: Iteration counter recorded with the snapshot.
    """

    # Defaults to the naive local time at which the snapshot is constructed.
    timestamp: datetime = Field(default_factory=datetime.now)

    # Required: there is no sensible default for the originating node.
    node_name: str

    # Required: the data payload captured by this snapshot.
    state_data: Dict[str, Any]

    # Optional: falls back to 0 when the caller does not supply a counter.
    iteration: int = Field(default=0)
|
|
|
|
|
|
|
|
class WorkflowState(BaseModel):
    """Shared state that flows through the workflow.

    This is a flexible container that holds all data being processed by the
    workflow nodes. The state is treated as immutable: ``set``, ``update``,
    ``mark_visited`` and ``increment_iteration`` leave this instance untouched
    and return a modified copy instead.

    Attributes:
        data: The actual workflow data (flexible dictionary).
        current_node: Name passed to the latest ``mark_visited`` call, if any.
        iteration: Iteration counter, advanced by ``increment_iteration``.
        visited_nodes: Node names in the order they were marked visited.
        started_at: Start timestamp, or ``None`` when not recorded.
        completed_at: Completion timestamp, or ``None`` when not recorded.
    """

    # Pydantic v2 configuration. This replaces the nested ``class Config``,
    # which pydantic v2 still accepts but reports as deprecated.
    model_config = {"arbitrary_types_allowed": True}

    data: Dict[str, Any] = Field(default_factory=dict)

    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key`` in the state data.

        Args:
            key: Key to look up in ``data``.
            default: Value returned when ``key`` is absent.

        Returns:
            The stored value (not a copy), or ``default`` if the key is missing.
        """
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """Return a new state with ``key`` set to ``value``.

        Args:
            key: Key to assign in the state data.
            value: Value to store under ``key``.

        Returns:
            A copy of this state whose ``data`` contains the new entry.
        """
        # Deep-copy first so the new state shares no nested containers with
        # this one; ``model_copy`` on its own is a shallow copy.
        new_data = deepcopy(self.data)
        new_data[key] = value
        return self.model_copy(update={"data": new_data})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """Return a new state with several values updated at once.

        Args:
            updates: Key/value pairs merged over the existing state data.

        Returns:
            A copy of this state whose ``data`` includes ``updates``.
        """
        new_data = deepcopy(self.data)
        new_data.update(updates)
        return self.model_copy(update={"data": new_data})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """Return a new state with ``node_name`` recorded as visited.

        Args:
            node_name: Name of the node to append to ``visited_nodes``.

        Returns:
            A copy of this state with ``node_name`` appended to
            ``visited_nodes`` and set as ``current_node``.
        """
        # Build a fresh list so this instance's ``visited_nodes`` is not mutated.
        new_visited = [*self.visited_nodes, node_name]
        return self.model_copy(
            update={
                "visited_nodes": new_visited,
                "current_node": node_name,
            }
        )

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state whose iteration counter is one higher."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """Convert the state to a plain dictionary.

        Returns:
            A dictionary with the keys ``data``, ``current_node``,
            ``iteration``, ``visited_nodes``, ``started_at`` and
            ``completed_at``. Timestamps are ISO-8601 strings, or ``None``
            when unset.
        """
        return {
            # Hand out copies rather than the live containers, so mutating the
            # returned dictionary cannot alter this state behind its back.
            "data": deepcopy(self.data),
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": list(self.visited_nodes),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": (
                self.completed_at.isoformat() if self.completed_at else None
            ),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """Create a WorkflowState from a dictionary.

        A dictionary that has a top-level ``"data"`` key is treated as a
        serialized state (the shape produced by ``to_dict``) and its entries
        are passed to the constructor as fields. Any other dictionary is
        wrapped as the raw workflow data of a fresh state.

        Args:
            data: Either a serialized state or a raw workflow data payload.

        Returns:
            The constructed state.
        """
        # NOTE(review): a raw payload that itself contains a "data" key is
        # indistinguishable from a serialized state and takes this branch.
        if "data" in data:
            return cls(**data)

        return cls(data=data)
|
|
|
|
|
|
|
|
class StateManager:
    """Manage the current state and snapshot history for a workflow run.

    Every call to ``update`` records a snapshot of the state data, which
    provides debugging capabilities by tracking state changes throughout the
    workflow execution.

    Attributes:
        run_id: Identifier of this run; a random UUID string when not given.
        history: Snapshots recorded so far, oldest first.
    """

    def __init__(self, run_id: Optional[str] = None):
        """Create a manager with an empty history and no current state.

        Args:
            run_id: Identifier for the run. When omitted (or empty), a random
                UUID4 string is generated.
        """
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List[StateSnapshot] = []
        self._current_state: Optional[WorkflowState] = None

    @property
    def current_state(self) -> Optional[WorkflowState]:
        """Return the current state, or ``None`` before ``initialize``."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> WorkflowState:
        """Create the initial state from ``initial_data``.

        Args:
            initial_data: Workflow data to seed the state with.

        Returns:
            The new current state, with ``started_at`` set to the current time.
        """
        self._current_state = WorkflowState(
            data=initial_data,
            started_at=datetime.now(),
        )
        return self._current_state

    def update(self, new_state: WorkflowState, node_name: str) -> None:
        """Record a snapshot of ``new_state`` and make it the current state.

        Args:
            new_state: State to snapshot and store as current.
            node_name: Name of the node the snapshot is attributed to.
        """
        # Deep-copy the data so later changes to the state's containers
        # cannot rewrite what the history recorded.
        snapshot = StateSnapshot(
            node_name=node_name,
            state_data=deepcopy(new_state.data),
            iteration=new_state.iteration,
        )
        self.history.append(snapshot)

        self._current_state = new_state

    def finalize(self) -> Optional[WorkflowState]:
        """Mark the workflow as complete.

        Returns:
            The current state with ``completed_at`` set to the current time,
            or ``None`` when the manager was never initialized.
        """
        if self._current_state is not None:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """Return the state history as a list of dictionaries.

        Returns:
            One dictionary per snapshot, oldest first, with the keys
            ``timestamp`` (ISO-8601 string), ``node``, ``iteration`` and
            ``state``.
        """
        return [
            {
                "timestamp": snap.timestamp.isoformat(),
                "node": snap.node_name,
                "iteration": snap.iteration,
                # Return a copy so callers inspecting the history cannot
                # modify the recorded snapshot.
                "state": deepcopy(snap.state_data),
            }
            for snap in self.history
        ]
|
|
|