Spaces:
Sleeping
Sleeping
| """ | |
| Workflow schema definitions using Pydantic | |
| """ | |
| from pydantic import BaseModel, Field, field_validator | |
| from typing import List, Dict, Any, Optional | |
class WorkflowTask(BaseModel):
    """
    Single task in a workflow.

    Attributes:
        id: Unique task identifier (letters, digits, underscores, hyphens)
        tool: Tool name to execute
        args: Arguments to pass to tool (defaults to an empty dict)
        depends_on: List of task IDs this task depends on (defaults to empty)
        retry_on_failure: Whether to retry if task fails (default True)
        max_retries: Maximum retry attempts, 0-10 (default 3)
        timeout_seconds: Task timeout in seconds, 1-600 (default 60)
    """

    id: str = Field(..., description="Unique task ID")
    tool: str = Field(..., description="Tool name to execute")
    args: Dict[str, Any] = Field(default_factory=dict, description="Tool arguments")
    depends_on: List[str] = Field(default_factory=list, description="Task dependencies")
    retry_on_failure: bool = Field(default=True, description="Retry on failure")
    max_retries: int = Field(default=3, ge=0, le=10, description="Max retries")
    timeout_seconds: int = Field(default=60, ge=1, le=600, description="Task timeout")

    # Without the field_validator registration this method is just an unused
    # function on the class and task IDs would never be checked.
    @field_validator("id")
    @classmethod
    def validate_id(cls, v: str) -> str:
        """
        Validate task ID is alphanumeric with underscores/hyphens.

        Underscores and hyphens are stripped first and the remainder must
        satisfy ``str.isalnum()``. An empty ID, or one made up only of
        separators, is therefore rejected because nothing alphanumeric is left.

        Args:
            v: Candidate task ID.

        Returns:
            The ID unchanged when it is valid.

        Raises:
            ValueError: If the ID contains any other character or has no
                alphanumeric character at all.
        """
        if not v.replace("_", "").replace("-", "").isalnum():
            raise ValueError("Task ID must be alphanumeric with underscores/hyphens")
        return v
class WorkflowDefinition(BaseModel):
    """
    Complete workflow definition with tasks and execution config.

    Attributes:
        name: Workflow name
        description: Optional workflow description
        tasks: List of tasks to execute (at least one)
        final_task: ID of final task (workflow result)
        max_parallel: Maximum parallel task execution, 1-10 (default 3)
        timeout_seconds: Total workflow timeout in seconds, 1-3600 (default 600)
    """

    name: str = Field(..., description="Workflow name")
    description: Optional[str] = Field(default=None, description="Workflow description")
    # NOTE: `tasks` must stay declared before `final_task`; the final_task
    # validator reads the already-validated task list from info.data.
    tasks: List[WorkflowTask] = Field(..., min_length=1, description="Workflow tasks")
    final_task: str = Field(..., description="Final task ID for result")
    max_parallel: int = Field(default=3, ge=1, le=10, description="Max parallel tasks")
    timeout_seconds: int = Field(default=600, ge=1, le=3600, description="Workflow timeout")

    @field_validator("final_task")
    @classmethod
    def validate_final_task(cls, v: str, info) -> str:
        """
        Validate final_task exists in tasks.

        Args:
            v: Candidate final task ID.
            info: Validation info; ``info.data`` contains previously
                validated fields.

        Returns:
            The ID unchanged when it names one of the workflow's tasks.

        Raises:
            ValueError: If no task with that ID exists.
        """
        tasks = info.data.get("tasks")
        if tasks is None:
            # `tasks` failed its own validation and is absent from info.data.
            # That failure is already reported, so do not add a misleading
            # "not found in tasks" error on top of it.
            return v
        task_ids = {task.id for task in tasks}
        if v not in task_ids:
            raise ValueError(f"final_task '{v}' not found in tasks")
        return v

    @field_validator("tasks")
    @classmethod
    def validate_dependencies(cls, v: List[WorkflowTask]) -> List[WorkflowTask]:
        """
        Validate task dependencies exist and form DAG (no cycles).

        Checks, in order: task IDs are unique, every ``depends_on`` entry
        names an existing task, and the dependency graph has no cycle
        (a task depending on itself counts as a cycle).

        Args:
            v: The workflow's tasks.

        Returns:
            The task list unchanged when all checks pass.

        Raises:
            ValueError: On a duplicate task ID, a dependency on a
                non-existent task, or a dependency cycle.
        """
        # Index tasks by ID once so dependency lookups are O(1), and reject
        # duplicates: with repeated IDs a dependency would be ambiguous.
        tasks_by_id: Dict[str, WorkflowTask] = {}
        for task in v:
            if task.id in tasks_by_id:
                raise ValueError(f"Duplicate task ID '{task.id}'")
            tasks_by_id[task.id] = task

        # Check all dependencies exist
        for task in v:
            for dep in task.depends_on:
                if dep not in tasks_by_id:
                    raise ValueError(
                        f"Task '{task.id}' depends on non-existent task '{dep}'"
                    )

        # Check for cycles with an iterative depth-first search. An explicit
        # stack is used instead of recursion so that a long dependency chain
        # cannot overflow the interpreter's recursion limit.
        finished: set = set()  # tasks whose dependencies are fully explored
        for root in v:
            if root.id in finished:
                continue
            on_path = {root.id}  # tasks on the current DFS path
            stack = [(root.id, iter(root.depends_on))]
            while stack:
                current_id, pending = stack[-1]
                for dep in pending:
                    if dep in on_path:
                        # `dep` is reachable from itself, so it is on the cycle.
                        raise ValueError(
                            f"Workflow contains cycle involving task '{dep}'"
                        )
                    if dep not in finished:
                        # Descend into the unexplored dependency; the partially
                        # consumed iterator resumes when we come back.
                        on_path.add(dep)
                        stack.append((dep, iter(tasks_by_id[dep].depends_on)))
                        break
                else:
                    # Every dependency of current_id has been explored.
                    stack.pop()
                    on_path.discard(current_id)
                    finished.add(current_id)
        return v

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (via ``model_dump``)."""
        return self.model_dump()

    # Must be a classmethod: it is an alternate constructor that is called on
    # the class itself, e.g. WorkflowDefinition.from_dict(data).
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowDefinition":
        """
        Create from dictionary.

        Args:
            data: Mapping of field names to values, passed to the
                constructor as keyword arguments.

        Returns:
            A new instance built from ``data``; the constructor validates it.
        """
        return cls(**data)