# feat: add production refinements (Phase 1-3) — commit 4454066
"""
Workflow schema definitions using Pydantic
"""
from pydantic import BaseModel, Field, field_validator
from typing import List, Dict, Any, Optional
class WorkflowTask(BaseModel):
    """
    One unit of work within a workflow.

    Attributes:
        id: Unique task identifier
        tool: Tool name to execute
        args: Arguments to pass to the tool
        depends_on: IDs of tasks that must complete before this one
        retry_on_failure: Whether to retry if the task fails
        max_retries: Maximum retry attempts (0-10)
        timeout_seconds: Per-task timeout in seconds (1-600)
    """
    id: str = Field(..., description="Unique task ID")
    tool: str = Field(..., description="Tool name to execute")
    args: Dict[str, Any] = Field(default_factory=dict, description="Tool arguments")
    depends_on: List[str] = Field(default_factory=list, description="Task dependencies")
    retry_on_failure: bool = Field(default=True, description="Retry on failure")
    max_retries: int = Field(default=3, ge=0, le=10, description="Max retries")
    timeout_seconds: int = Field(default=60, ge=1, le=600, description="Task timeout")

    @field_validator("id")
    @classmethod
    def validate_id(cls, v: str) -> str:
        """Ensure the ID contains only alphanumerics, underscores, and hyphens."""
        # Strip the two allowed separators; whatever remains must be purely
        # alphanumeric (and non-empty, since "".isalnum() is False).
        remainder = v.replace("-", "").replace("_", "")
        if remainder.isalnum():
            return v
        raise ValueError("Task ID must be alphanumeric with underscores/hyphens")
class WorkflowDefinition(BaseModel):
    """
    Complete workflow definition with tasks and execution config.

    Attributes:
        name: Workflow name
        description: Optional workflow description
        tasks: List of tasks to execute (at least one; unique IDs; acyclic)
        final_task: ID of final task (workflow result)
        max_parallel: Maximum parallel task execution
        timeout_seconds: Total workflow timeout
    """
    name: str = Field(..., description="Workflow name")
    description: Optional[str] = Field(default=None, description="Workflow description")
    tasks: List[WorkflowTask] = Field(..., min_length=1, description="Workflow tasks")
    final_task: str = Field(..., description="Final task ID for result")
    max_parallel: int = Field(default=3, ge=1, le=10, description="Max parallel tasks")
    timeout_seconds: int = Field(default=600, ge=1, le=3600, description="Workflow timeout")

    @field_validator("final_task")
    @classmethod
    def validate_final_task(cls, v: str, info) -> str:
        """Validate final_task references an existing task ID."""
        # info.data only contains fields that already validated successfully.
        # If "tasks" is absent, its own validation failed — skip this check
        # rather than stacking a misleading "not found in tasks" error on top.
        if "tasks" not in info.data:
            return v
        task_ids = {task.id for task in info.data["tasks"]}
        if v not in task_ids:
            raise ValueError(f"final_task '{v}' not found in tasks")
        return v

    @field_validator("tasks")
    @classmethod
    def validate_dependencies(cls, v: List[WorkflowTask]) -> List[WorkflowTask]:
        """Validate task IDs are unique, all dependencies exist, and the graph is a DAG."""
        # Reject duplicate IDs up front: a duplicate would silently collapse
        # in the ID set and make dependency resolution ambiguous downstream.
        deps_by_id: Dict[str, List[str]] = {}
        for task in v:
            if task.id in deps_by_id:
                raise ValueError(f"Duplicate task ID '{task.id}'")
            deps_by_id[task.id] = task.depends_on

        # Every dependency must name a task that exists in this workflow.
        for task in v:
            for dep in task.depends_on:
                if dep not in deps_by_id:
                    raise ValueError(
                        f"Task '{task.id}' depends on non-existent task '{dep}'"
                    )

        # Cycle check via DFS with a recursion stack. deps_by_id gives O(1)
        # neighbor lookup (the previous linear scan per node was O(n^2)).
        def has_cycle(task_id: str, visited: set, rec_stack: set) -> bool:
            visited.add(task_id)
            rec_stack.add(task_id)
            for dep in deps_by_id[task_id]:
                if dep not in visited:
                    if has_cycle(dep, visited, rec_stack):
                        return True
                elif dep in rec_stack:
                    # Back edge to a node on the current DFS path -> cycle.
                    return True
            rec_stack.remove(task_id)
            return False

        visited: set = set()
        for task in v:
            if task.id not in visited:
                if has_cycle(task.id, visited, set()):
                    raise ValueError(
                        f"Workflow contains cycle involving task '{task.id}'"
                    )
        return v

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary"""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowDefinition":
        """Create from dictionary"""
        return cls(**data)