Spaces:
Sleeping
Sleeping
| """ | |
| Task Manager for handling processing task cancellation and progress tracking. | |
| Provides thread-safe task tracking, cancellation flags, and status management. | |
| """ | |
| import threading | |
| import time | |
| from typing import Dict, Optional, Any, List | |
| from datetime import datetime | |
| import uuid | |
| import logging | |
# Module-level logger shared by every TaskManager method.
logger = logging.getLogger(__name__)


class TaskManager:
    """Manages processing tasks with cancellation support and progress tracking.

    Each task is a plain dict of metadata stored in ``active_tasks`` and is
    paired with a ``threading.Event`` in ``cancel_flags`` that workers poll via
    ``is_cancelled``. Every method that mutates either mapping does so while
    holding ``self.lock``.

    ``self.lock`` is a plain (non-reentrant) ``threading.Lock``: public methods
    must never call another lock-taking public method while holding it. Shared
    work lives in the ``*_locked`` helpers, which expect the caller to already
    hold the lock.
    """

    # Statuses after which a task is finished; only these are eligible for
    # cleanup_old_tasks, and cancel_task leaves them untouched.
    _TERMINAL_STATUSES = ('completed', 'failed', 'cancelled')
    # Statuses that cancel_user_tasks treats as still cancellable.
    _ACTIVE_STATUSES = ('pending', 'processing')

    def __init__(self):
        """Initialize task manager with empty task/flag tables and a lock."""
        self.active_tasks: Dict[str, Dict[str, Any]] = {}
        self.cancel_flags: Dict[str, threading.Event] = {}
        self.lock = threading.Lock()
        logger.info("TaskManager initialized")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _timestamp() -> str:
        """Return the current local time as an ISO-8601 string."""
        return datetime.now().isoformat()

    def _request_cancel_locked(self, task_id: str) -> bool:
        """Set the cancel flag and mark the task 'cancelling'.

        The caller must hold ``self.lock``.

        Returns:
            bool: True if a cancel flag existed for ``task_id``, else False.
        """
        flag = self.cancel_flags.get(task_id)
        if flag is None:
            return False
        flag.set()  # Signal cancellation to whoever polls is_cancelled()
        task = self.active_tasks.get(task_id)
        if task is not None:
            task['status'] = 'cancelling'
            task['updated_at'] = self._timestamp()
            task['message'] = 'Cancellation requested'
        return True

    def _remove_locked(self, task_id: str) -> bool:
        """Drop a task and its cancel flag.

        The caller must hold ``self.lock``.

        Returns:
            bool: True if the task was being tracked, else False.
        """
        removed = self.active_tasks.pop(task_id, None)
        self.cancel_flags.pop(task_id, None)
        return removed is not None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def create_task(self, user_id: str = "default") -> str:
        """
        Create a new task and return task ID.

        Args:
            user_id: User identifier for multi-user support

        Returns:
            task_id: Unique task identifier (a UUID4 string)
        """
        task_id = str(uuid.uuid4())
        # One clock reading so created_at and updated_at start out identical.
        now = self._timestamp()
        with self.lock:
            self.active_tasks[task_id] = {
                'user_id': user_id,
                'status': 'pending',
                'created_at': now,
                'updated_at': now,
                'progress': 0,
                'stage': 'initializing',
                'message': 'Task created',
            }
            self.cancel_flags[task_id] = threading.Event()
        logger.info("Task created: %s for user: %s", task_id, user_id)
        return task_id

    def update_task(self, task_id: str, **kwargs):
        """
        Update task status and metadata. Unknown task IDs are ignored.

        Args:
            task_id: Task identifier
            **kwargs: Fields to update (status, progress, stage, message, etc.)
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            task.update(kwargs)
            task['updated_at'] = self._timestamp()
            # Read the stored value so a stage-only update logs the task's
            # real progress instead of a misleading 0%.
            progress = task.get('progress', 0)
        # Log significant status changes
        if 'stage' in kwargs:
            logger.info("Task %s stage: %s (%s%%)", task_id, kwargs['stage'], progress)

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task by setting its cancellation flag.

        Tasks that already reached a terminal status (completed, failed,
        cancelled) are left untouched; otherwise they would be stuck in
        'cancelling' and never become eligible for cleanup_old_tasks.

        Args:
            task_id: Task identifier to cancel

        Returns:
            bool: True if task was found and cancellation was requested,
                False if it is unknown or already finished
        """
        with self.lock:
            known = task_id in self.cancel_flags
            status = self.active_tasks.get(task_id, {}).get('status')
            finished = status in self._TERMINAL_STATUSES
            if known and not finished:
                self._request_cancel_locked(task_id)
        if not known:
            logger.warning("Task not found for cancellation: %s", task_id)
            return False
        if finished:
            logger.info("Task already %s, nothing to cancel: %s", status, task_id)
            return False
        logger.info("Task cancelled: %s", task_id)
        return True

    def is_cancelled(self, task_id: str) -> bool:
        """
        Check if task has been cancelled.

        Intended to be polled by workers, so it does a single dict lookup
        without taking ``self.lock`` and without allocating anything.

        Args:
            task_id: Task identifier to check (None is accepted)

        Returns:
            bool: True if task is cancelled, False otherwise (including for
                None or unknown IDs)
        """
        if task_id is None:
            return False
        flag = self.cancel_flags.get(task_id)
        return flag is not None and flag.is_set()

    def complete_task(self, task_id: str, success: bool = True, message: Optional[str] = None):
        """
        Mark task as complete. Unknown task IDs are ignored.

        Args:
            task_id: Task identifier
            success: Whether task completed successfully; on success progress
                is forced to 100, on failure the last progress is kept
            message: Optional completion message (ignored when empty)
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = self._timestamp()
            task['status'] = 'completed' if success else 'failed'
            task['completed_at'] = now
            task['updated_at'] = now
            task['progress'] = 100 if success else task.get('progress', 0)
            if message:
                task['message'] = message
        logger.info("Task completed: %s (success=%s)", task_id, success)

    def mark_cancelled(self, task_id: str):
        """
        Mark task as fully cancelled (after cleanup). Unknown IDs are ignored.

        Args:
            task_id: Task identifier
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = self._timestamp()
            task['status'] = 'cancelled'
            task['cancelled_at'] = now
            task['updated_at'] = now
            task['message'] = 'Task cancelled successfully'
        logger.info("Task marked as cancelled: %s", task_id)

    def cleanup_task(self, task_id: str):
        """
        Remove task from tracking (for old/completed tasks).

        Args:
            task_id: Task identifier to remove
        """
        with self.lock:
            removed = self._remove_locked(task_id)
        if removed:
            logger.info("Task cleaned up: %s", task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """
        Get current task status and metadata.

        Args:
            task_id: Task identifier

        Returns:
            A shallow copy of the task's metadata dict, or None if not found.
            A copy is returned so callers never read or mutate the shared
            record outside the lock.
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            return None if task is None else dict(task)

    def get_user_tasks(self, user_id: str) -> List[Dict]:
        """
        Get all tasks for a specific user.

        Args:
            user_id: User identifier

        Returns:
            List of task dictionaries (copies), each with an added 'task_id' key
        """
        with self.lock:
            return [
                {'task_id': tid, **task}
                for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
            ]

    def cancel_user_tasks(self, user_id: str) -> int:
        """
        Cancel all active (pending or processing) tasks for a user.

        Selection and cancellation happen under one lock acquisition, so a
        task cannot finish or disappear between being selected and being
        flagged, and the lock is never re-acquired while held.

        Args:
            user_id: User identifier

        Returns:
            int: Number of tasks for which cancellation was requested
        """
        with self.lock:
            candidates = [
                tid for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
                and task.get('status') in self._ACTIVE_STATUSES
            ]
            cancelled = [tid for tid in candidates if self._request_cancel_locked(tid)]
        for tid in cancelled:
            logger.info("Task cancelled: %s", tid)
        logger.info("Cancelled %d tasks for user: %s", len(cancelled), user_id)
        return len(cancelled)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """
        Clean up finished tasks older than specified age.

        Only tasks in a terminal status (completed, failed, cancelled) are
        removed; age is measured from the task's 'created_at' timestamp.
        Selection and removal happen under one lock acquisition.

        Args:
            max_age_seconds: Maximum age in seconds (default: 1 hour)

        Raises:
            KeyError, ValueError: If a finished task's 'created_at' is missing
                or is not an ISO-8601 string.
        """
        current_time = datetime.now()
        with self.lock:
            expired = [
                tid for tid, task in self.active_tasks.items()
                if task.get('status') in self._TERMINAL_STATUSES
                and (current_time - datetime.fromisoformat(task['created_at'])).total_seconds()
                > max_age_seconds
            ]
            for tid in expired:
                self._remove_locked(tid)
        for tid in expired:
            logger.info("Task cleaned up: %s", tid)
        if expired:
            logger.info("Cleaned up %d old tasks", len(expired))

    def get_stats(self) -> Dict[str, Any]:
        """
        Get overall task statistics.

        Returns:
            Dict with 'total_tasks' plus per-status and per-user task counts
            ('by_status', 'by_user')
        """
        by_status: Dict[Any, int] = {}
        by_user: Dict[Any, int] = {}
        with self.lock:
            total = len(self.active_tasks)
            for task in self.active_tasks.values():
                status = task.get('status', 'unknown')
                owner = task.get('user_id', 'unknown')
                by_status[status] = by_status.get(status, 0) + 1
                by_user[owner] = by_user.get(owner, 0) + 1
        return {
            'total_tasks': total,
            'by_status': by_status,
            'by_user': by_user,
        }
# Global instance, created lazily on the first get_task_manager() call.
_task_manager_instance = None
# Serializes first-time creation so concurrent callers end up sharing one
# TaskManager instead of each building (and registering tasks on) their own.
_task_manager_lock = threading.Lock()


def get_task_manager() -> TaskManager:
    """
    Get global TaskManager instance (singleton pattern).

    The instance is created on first use. Creation is guarded by a module
    lock and re-checked inside it, so two threads racing on the first call
    cannot create two managers; later calls return without taking the lock.

    Returns:
        TaskManager instance
    """
    global _task_manager_instance
    if _task_manager_instance is None:
        with _task_manager_lock:
            # Re-check: another thread may have created it while we waited.
            if _task_manager_instance is None:
                _task_manager_instance = TaskManager()
    return _task_manager_instance