# tritan-api/core/engine.py
"""
Workflow Execution Engine
"""
import asyncio
from collections import deque
from datetime import datetime
from typing import Any, AsyncGenerator

from models.workflow import Workflow, Node, NodeType
from models.execution import ExecutionState, ExecutionResult, ExecutionStatus, NodeResult
from .nodes import NodeExecutor
from .guardian import Guardian
class WorkflowEngine:
    """
    Main workflow execution engine.

    Runs a workflow's nodes sequentially in dependency (topological)
    order. Each node's output is stored in a shared variable map keyed
    by node id so downstream nodes can reference it. Failures are
    captured as structured results rather than raised to callers.
    """

    def __init__(self):
        self.node_executor = NodeExecutor()
        self.guardian = Guardian()

    async def execute(self, workflow: Workflow) -> ExecutionResult:
        """Execute a workflow to completion and return the final result.

        Args:
            workflow: The workflow definition (nodes + edges) to run.

        Returns:
            An ExecutionResult with per-node results, the final output,
            overall status, and timing. Engine-level errors (e.g. a
            cyclic graph) are recorded as a synthetic "_error" node
            result instead of propagating.
        """
        start_time = datetime.now()

        state = ExecutionState(
            workflow_id=workflow.id,
            status=ExecutionStatus.RUNNING,
            started_at=start_time.isoformat(),
        )

        try:
            # Topological order guarantees a node runs after all of its inputs.
            execution_order = self._build_execution_order(workflow)

            for node_id in execution_order:
                node = self._get_node_by_id(workflow, node_id)
                if not node:
                    continue

                state.current_node = node_id
                node_result = await self._execute_node(node, state)
                state.node_results[node_id] = node_result

                # Stop at the first failure: downstream nodes would be
                # operating on missing input.
                if node_result.status == ExecutionStatus.FAILED:
                    state.status = ExecutionStatus.FAILED
                    break

            if state.status != ExecutionStatus.FAILED:
                state.status = ExecutionStatus.COMPLETED
        except Exception as e:
            # Engine-level failure: surface it as a synthetic node result
            # so callers still receive a structured error.
            state.status = ExecutionStatus.FAILED
            state.node_results["_error"] = NodeResult(
                node_id="_error",
                status=ExecutionStatus.FAILED,
                error=str(e),
            )

        end_time = datetime.now()
        state.completed_at = end_time.isoformat()

        return ExecutionResult(
            execution_id=state.execution_id,
            workflow_id=workflow.id,
            status=state.status,
            node_results=list(state.node_results.values()),
            final_output=self._get_final_output(state),
            total_duration_ms=(end_time - start_time).total_seconds() * 1000,
            started_at=state.started_at,
            completed_at=state.completed_at,
        )

    async def execute_stream(self, workflow: Workflow) -> AsyncGenerator[dict, None]:
        """Execute a workflow, yielding progress events as dicts.

        Event types, in order: "start", then per node "node_start" and
        "node_complete", "error" on engine-level failure, and a final
        "complete" carrying the overall status and duration.
        """
        start_time = datetime.now()
        yield {"type": "start", "workflow_id": workflow.id, "timestamp": start_time.isoformat()}

        state = ExecutionState(
            workflow_id=workflow.id,
            status=ExecutionStatus.RUNNING,
            started_at=start_time.isoformat(),
        )

        try:
            execution_order = self._build_execution_order(workflow)

            for node_id in execution_order:
                node = self._get_node_by_id(workflow, node_id)
                if not node:
                    continue

                # Track the active node, consistent with execute().
                state.current_node = node_id
                yield {"type": "node_start", "node_id": node_id, "node_type": node.type}

                node_result = await self._execute_node(node, state)
                state.node_results[node_id] = node_result

                yield {
                    "type": "node_complete",
                    "node_id": node_id,
                    "status": node_result.status,
                    "output": node_result.output,
                    "error": node_result.error,
                }

                if node_result.status == ExecutionStatus.FAILED:
                    state.status = ExecutionStatus.FAILED
                    break

            if state.status != ExecutionStatus.FAILED:
                state.status = ExecutionStatus.COMPLETED
        except Exception as e:
            state.status = ExecutionStatus.FAILED
            yield {"type": "error", "message": str(e)}

        end_time = datetime.now()
        yield {
            "type": "complete",
            "status": state.status,
            "duration_ms": (end_time - start_time).total_seconds() * 1000,
        }

    async def execute_node(self, node_type: str, config: dict) -> dict:
        """Execute a single node in isolation with empty variables (for testing)."""
        return await self.node_executor.execute(node_type, config, {})

    async def _execute_node(self, node: Node, state: ExecutionState) -> NodeResult:
        """Execute one node and return its NodeResult.

        On success the output is also stored in state.variables under the
        node's id for downstream nodes. Never raises: executor exceptions
        are converted into a FAILED NodeResult.
        """
        start_time = datetime.now()
        try:
            # node.type may be a NodeType enum or a plain string.
            node_type = node.type.value if hasattr(node.type, 'value') else str(node.type)

            output = await self.node_executor.execute(
                node_type,
                node.data.model_dump(),
                state.variables,
            )
            end_time = datetime.now()

            # Expose this node's output to downstream nodes.
            state.variables[node.id] = output

            return NodeResult(
                node_id=node.id,
                status=ExecutionStatus.COMPLETED,
                output=output,
                started_at=start_time.isoformat(),
                completed_at=end_time.isoformat(),
                duration_ms=(end_time - start_time).total_seconds() * 1000,
            )
        except Exception as e:
            end_time = datetime.now()
            return NodeResult(
                node_id=node.id,
                status=ExecutionStatus.FAILED,
                error=str(e),
                started_at=start_time.isoformat(),
                completed_at=end_time.isoformat(),
                duration_ms=(end_time - start_time).total_seconds() * 1000,
            )

    def _build_execution_order(self, workflow: Workflow) -> list[str]:
        """Return node ids in topological order (Kahn's algorithm).

        Raises:
            ValueError: if an edge references an unknown node id, or the
                graph contains a cycle. (Previously, nodes on a cycle
                were silently skipped and the workflow still "completed".)
        """
        graph: dict[str, list[str]] = {node.id: [] for node in workflow.nodes}
        in_degree: dict[str, int] = {node.id: 0 for node in workflow.nodes}

        for edge in workflow.edges:
            if edge.source not in graph or edge.target not in in_degree:
                raise ValueError(
                    f"Edge references unknown node: {edge.source} -> {edge.target}"
                )
            graph[edge.source].append(edge.target)
            in_degree[edge.target] += 1

        # Seed with trigger nodes (no incoming edges); deque gives O(1) pops
        # where list.pop(0) was O(n).
        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        order: list[str] = []

        while queue:
            node_id = queue.popleft()
            order.append(node_id)
            for neighbor in graph[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Any node never reaching in-degree 0 sits on a cycle.
        if len(order) != len(graph):
            raise ValueError("Workflow graph contains a cycle")
        return order

    def _get_node_by_id(self, workflow: Workflow, node_id: str) -> Node | None:
        """Return the node with the given id, or None if not present."""
        return next((node for node in workflow.nodes if node.id == node_id), None)

    def _get_final_output(self, state: ExecutionState) -> Any:
        """Return the output of the most recently completed node, or None.

        Uses `is not None` so valid falsy outputs (0, "", [], False) are
        not skipped — the previous truthiness check dropped them.
        """
        for node_id in reversed(list(state.node_results.keys())):
            result = state.node_results[node_id]
            if result.status == ExecutionStatus.COMPLETED and result.output is not None:
                return result.output
        return None