Spaces:
Paused
Paused
| import asyncio | |
| import time | |
| from typing import Any, Callable, Dict, List | |
| from copy import deepcopy | |
| from starfish.common.logger import get_logger | |
| from starfish.data_factory.config import TASK_RUNNER_TIMEOUT | |
| from starfish.data_factory.constants import IDX | |
| from starfish.data_factory.utils.errors import TimeoutErrorAsyncio | |
| logger = get_logger(__name__) | |
| # from starfish.common.logger_new import logger | |
class TaskRunner:
    """A task runner that executes asynchronous tasks with retry logic and timeout handling.

    Attributes:
        max_retries: Maximum number of retry attempts for failed tasks
        timeout: Maximum execution time allowed for each task
        master_job_id: Optional identifier for the parent job
    """

    def __init__(self, max_retries: int = 1, timeout: int = TASK_RUNNER_TIMEOUT, master_job_id: str = None):
        """Initializes the TaskRunner with configuration parameters.

        Args:
            max_retries: Maximum number of retry attempts (default: 1)
            timeout: Timeout in seconds for task execution (default: TASK_RUNNER_TIMEOUT)
            master_job_id: Optional identifier for the parent job (default: None)
        """
        self.max_retries = max_retries
        self.timeout = timeout
        self.master_job_id = master_job_id

    async def run_task(self, func: Callable, input_data: Dict, input_data_idx: str) -> List[Any]:
        """Process a single task with asyncio, retrying on failure.

        ``func`` is awaited with the entries of ``input_data`` (minus the ``IDX``
        key) passed as keyword arguments. Each attempt is individually bounded
        by ``self.timeout`` seconds.

        Args:
            func: Async callable invoked as ``func(**kwargs)``.
            input_data: Keyword arguments for ``func``; the ``IDX`` entry, if
                present, is not forwarded.
            input_data_idx: Index of the input record, used only in log messages.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            TimeoutErrorAsyncio: If an attempt exceeds ``self.timeout``. A timeout
                is not retried.
            Exception: The last exception raised by ``func`` once more than
                ``self.max_retries`` attempts have failed.
        """
        retries = 0
        # Monotonic clock: the elapsed time cannot be skewed by wall-clock adjustments.
        start_time = time.monotonic()
        result = None
        # Create a copy of input_data without 'IDX' to prevent insertion of IDX due to race condition
        copy_input = deepcopy({k: v for k, v in input_data.items() if k != IDX})

        while retries <= self.max_retries:
            try:
                result = await asyncio.wait_for(func(**copy_input), timeout=self.timeout)
                logger.debug(f"Task execution completed in {time.monotonic() - start_time:.2f} seconds")
                break
            except asyncio.TimeoutError as timeout_error:
                logger.error(
                    f"Task execution timed out after {self.timeout} seconds, "
                    "please set the timeout in data_factory decorator like this: "
                    "task_runner_timeout=60"
                )
                raise TimeoutErrorAsyncio(f"Task execution timed out after {self.timeout} seconds") from timeout_error
            except Exception:
                retries += 1
                if retries > self.max_retries:
                    # Retry budget exhausted: propagate the original exception unchanged.
                    raise
                logger.debug(f"Retry attempt {retries}/{self.max_retries} for input data index {input_data_idx}")
                # Exponential backoff: 1s, 2s, 4s, ... (the previous `1**retries` was always 1).
                await asyncio.sleep(2 ** (retries - 1))

        return result