# Source path: flowgraph/app/engine/state.py
# Hosting-page residue kept for provenance (not part of the module):
#   "kbsss's picture" / "Upload folder using huggingface_hub" / "7b2787b verified"
"""
State Management for Workflow Engine.
This module provides the state management system that flows through the workflow.
State is immutable - each node receives state and returns a new modified state.
"""
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
class StateSnapshot(BaseModel):
    """Record of workflow data captured for one node at one moment.

    Attributes:
        timestamp: When the snapshot object was created; defaults to
            ``datetime.now()`` evaluated at construction time.
        node_name: Name of the node the snapshot is associated with.
        state_data: The state data held by the snapshot.
        iteration: Iteration counter stored with the snapshot (default 0).
    """

    timestamp: datetime = Field(default_factory=datetime.now)
    node_name: str
    state_data: Dict[str, Any]
    iteration: int = Field(default=0)
class WorkflowState(BaseModel):
    """
    The shared state that flows through the workflow.

    This is a flexible container that holds all data being processed by the
    workflow nodes, together with execution metadata. The modifying helpers
    (``set``, ``update``, ``mark_visited``, ``increment_iteration``) never
    change this instance; each one returns a new ``WorkflowState``.

    Attributes:
        data: The actual workflow data (flexible dictionary).
        current_node: Name of the node most recently passed to
            ``mark_visited``, or ``None``.
        iteration: Iteration counter, advanced by ``increment_iteration``.
        visited_nodes: Node names in the order they were marked visited.
        started_at: Start timestamp, or ``None`` if not set.
        completed_at: Completion timestamp, or ``None`` if not set.
    """

    # Pydantic v2 style configuration; the inner ``class Config`` form is
    # deprecated in the pydantic version that provides ``model_copy``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The actual data being processed
    data: Dict[str, Any] = Field(default_factory=dict)

    # Execution metadata
    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key`` in the state data.

        Args:
            key: Key to look up in ``data``.
            default: Value returned when ``key`` is absent.

        Returns:
            The stored value itself (not a copy), or ``default``.
        """
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """Return a new state with ``key`` set to ``value``.

        The whole ``data`` mapping is deep-copied first, so the new state
        shares no nested containers with this one.
        """
        new_data = deepcopy(self.data)
        new_data[key] = value
        return self.model_copy(update={"data": new_data})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """Return a new state with every item of ``updates`` merged into the data.

        Existing keys are overwritten. The current ``data`` is deep-copied
        before merging; the values in ``updates`` are stored as given.
        """
        new_data = deepcopy(self.data)
        new_data.update(updates)
        return self.model_copy(update={"data": new_data})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """Return a new state with ``node_name`` appended to the visit trail.

        ``current_node`` of the returned state is set to ``node_name``. A
        fresh list is built so this instance's ``visited_nodes`` is untouched.
        """
        return self.model_copy(update={
            "visited_nodes": [*self.visited_nodes, node_name],
            "current_node": node_name,
        })

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state whose iteration counter is one higher."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a plain dictionary.

        Timestamps are rendered with ``isoformat()`` (or ``None`` when unset).
        ``data`` and ``visited_nodes`` are returned as copies so that mutating
        the result cannot alter this state.
        """
        return {
            "data": deepcopy(self.data),
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": list(self.visited_nodes),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """Create a WorkflowState from a dictionary.

        A mapping containing a ``"data"`` key is treated as a full state
        mapping and its items are passed as constructor keyword arguments.
        Any other mapping is taken to be raw workflow data and becomes the
        ``data`` field of a fresh state.
        """
        if "data" in data:
            return cls(**data)
        # If it's just raw data, wrap it
        return cls(data=data)
class StateManager:
    """
    Manages state history and snapshots for a workflow run.

    Holds the current ``WorkflowState`` and appends one ``StateSnapshot`` to
    ``history`` on every ``update`` call, which gives a step-by-step record
    of state changes for debugging.

    Attributes:
        run_id: Identifier of this run.
        history: Snapshots in the order they were recorded.
    """

    def __init__(self, run_id: Optional[str] = None):
        """Create a manager with no current state and an empty history.

        Args:
            run_id: Identifier for the run. When omitted (or otherwise
                falsy, e.g. an empty string) a random UUID4 string is used.
        """
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List[StateSnapshot] = []
        self._current_state: Optional[WorkflowState] = None

    @property
    def current_state(self) -> Optional["WorkflowState"]:
        """Get the current state (``None`` until one has been set)."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> "WorkflowState":
        """Create the starting state from ``initial_data`` and make it current.

        ``started_at`` of the new state is set to ``datetime.now()``.

        Returns:
            The newly created state.
        """
        self._current_state = WorkflowState(
            data=initial_data,
            started_at=datetime.now(),
        )
        return self._current_state

    def update(self, new_state: "WorkflowState", node_name: str) -> None:
        """Record a snapshot of ``new_state`` and make it the current state.

        Args:
            new_state: State to store as current.
            node_name: Node name stored on the snapshot.
        """
        # Deep-copy the data so later changes to objects reachable from the
        # state cannot rewrite what this snapshot recorded.
        snapshot = StateSnapshot(
            node_name=node_name,
            state_data=deepcopy(new_state.data),
            iteration=new_state.iteration,
        )
        self.history.append(snapshot)
        self._current_state = new_state

    def finalize(self) -> Optional["WorkflowState"]:
        """Mark the workflow as complete.

        Replaces the current state with a copy whose ``completed_at`` is
        ``datetime.now()``.

        Returns:
            The finalized state, or ``None`` when no state has been set.
        """
        if self._current_state is not None:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """Get the state history as a list of dictionaries.

        Each entry has the keys ``timestamp`` (from ``isoformat()``),
        ``node``, ``iteration`` and ``state``. ``state`` is a deep copy of
        the snapshot's data, so editing the returned structure leaves the
        recorded history intact.
        """
        return [
            {
                "timestamp": s.timestamp.isoformat(),
                "node": s.node_name,
                "iteration": s.iteration,
                "state": deepcopy(s.state_data),
            }
            for s in self.history
        ]