"""
Task Manager for handling processing task cancellation and progress tracking.
Provides thread-safe task tracking, cancellation flags, and status management.
"""

import threading
import time
from collections import Counter
from typing import Dict, Optional, Any, List
from datetime import datetime
import uuid
import logging

logger = logging.getLogger(__name__)

# Statuses after which a task is finished. A cancellation request must not
# overwrite these, and only these are eligible for age-based cleanup.
_TERMINAL_STATUSES = frozenset({'completed', 'failed', 'cancelled'})
# Statuses that cancel_user_tasks() treats as still running.
_ACTIVE_STATUSES = frozenset({'pending', 'processing'})


class TaskManager:
    """Manages processing tasks with cancellation support and progress tracking.

    All reads and writes of ``active_tasks`` / ``cancel_flags`` (except the
    lock-free flag lookup in ``is_cancelled``) happen while holding
    ``self.lock``, a non-reentrant ``threading.Lock``. Methods that already
    hold the lock use the ``_*_locked`` helpers rather than calling public
    methods, which would try to re-acquire the lock and deadlock.
    """

    def __init__(self):
        """Initialize task manager with thread-safe data structures."""
        self.active_tasks: Dict[str, Dict[str, Any]] = {}
        self.cancel_flags: Dict[str, threading.Event] = {}
        self.lock = threading.Lock()
        logger.info("TaskManager initialized")

    def create_task(self, user_id: str = "default") -> str:
        """
        Create a new task and return task ID.

        Args:
            user_id: User identifier for multi-user support

        Returns:
            task_id: Unique task identifier
        """
        task_id = str(uuid.uuid4())
        # One timestamp so created_at and updated_at start out identical.
        now = datetime.now().isoformat()
        with self.lock:
            self.active_tasks[task_id] = {
                'user_id': user_id,
                'status': 'pending',
                'created_at': now,
                'updated_at': now,
                'progress': 0,
                'stage': 'initializing',
                'message': 'Task created',
            }
            self.cancel_flags[task_id] = threading.Event()
        logger.info("Task created: %s for user: %s", task_id, user_id)
        return task_id

    def update_task(self, task_id: str, **kwargs: Any) -> None:
        """
        Update task status and metadata.

        Unknown task IDs are ignored.

        Args:
            task_id: Task identifier
            **kwargs: Fields to update (status, progress, stage, message, etc.)
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            task.update(kwargs)
            task['updated_at'] = datetime.now().isoformat()
            # Report the stored progress, not 0, when the caller only changed the stage.
            progress = task.get('progress', 0)
        # Log significant status changes
        if 'stage' in kwargs:
            logger.info("Task %s stage: %s (%s%%)", task_id, kwargs['stage'], progress)

    def _cancel_task_locked(self, task_id: str) -> bool:
        """Signal cancellation for *task_id*; the caller must hold ``self.lock``.

        Returns True when the task has a cancellation flag. The status is moved
        to 'cancelling' only if the task has not already reached a terminal
        status, so finished tasks remain eligible for cleanup.
        """
        flag = self.cancel_flags.get(task_id)
        if flag is None:
            return False
        flag.set()  # Signal cancellation
        task = self.active_tasks.get(task_id)
        if task is not None and task.get('status') not in _TERMINAL_STATUSES:
            task['status'] = 'cancelling'
            task['updated_at'] = datetime.now().isoformat()
            task['message'] = 'Cancellation requested'
        return True

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task by setting its cancellation flag.

        A task that has already completed, failed or been cancelled keeps its
        status; only its flag is set.

        Args:
            task_id: Task identifier to cancel

        Returns:
            bool: True if task was found and its flag set, False otherwise
        """
        with self.lock:
            found = self._cancel_task_locked(task_id)
        if found:
            logger.info("Task cancelled: %s", task_id)
        else:
            logger.warning("Task not found for cancellation: %s", task_id)
        return found

    def is_cancelled(self, task_id: str) -> bool:
        """
        Check if task has been cancelled.

        Args:
            task_id: Task identifier to check (None is treated as "not cancelled")

        Returns:
            bool: True if task is cancelled, False otherwise
        """
        if task_id is None:
            return False
        # Lock-free on purpose: a single dict.get() plus Event.is_set() is safe
        # to call from worker loops without contending on self.lock.
        flag = self.cancel_flags.get(task_id)
        return flag is not None and flag.is_set()

    def complete_task(self, task_id: str, success: bool = True,
                      message: Optional[str] = None) -> None:
        """
        Mark task as complete.

        Args:
            task_id: Task identifier
            success: Whether task completed successfully
            message: Optional completion message
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = datetime.now().isoformat()
            task['status'] = 'completed' if success else 'failed'
            task['completed_at'] = now
            task['updated_at'] = now
            # A failed task keeps whatever progress it had reached.
            task['progress'] = 100 if success else task.get('progress', 0)
            if message:
                task['message'] = message
        logger.info("Task completed: %s (success=%s)", task_id, success)

    def mark_cancelled(self, task_id: str) -> None:
        """
        Mark task as fully cancelled (after cleanup).

        Args:
            task_id: Task identifier
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = datetime.now().isoformat()
            task['status'] = 'cancelled'
            task['cancelled_at'] = now
            task['updated_at'] = now
            task['message'] = 'Task cancelled successfully'
        logger.info("Task marked as cancelled: %s", task_id)

    def _remove_task_locked(self, task_id: str) -> bool:
        """Drop *task_id* from both registries; the caller must hold ``self.lock``.

        Returns True if a task record was removed.
        """
        removed_task = self.active_tasks.pop(task_id, None)
        self.cancel_flags.pop(task_id, None)
        return removed_task is not None

    def cleanup_task(self, task_id: str) -> None:
        """
        Remove task from tracking (for old/completed tasks).

        Args:
            task_id: Task identifier to remove
        """
        with self.lock:
            removed = self._remove_task_locked(task_id)
        if removed:
            logger.info("Task cleaned up: %s", task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Get current task status and metadata.

        Args:
            task_id: Task identifier

        Returns:
            A shallow copy of the task record, or None if not found. A copy is
            returned so callers cannot mutate shared state outside the lock.
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            return dict(task) if task is not None else None

    def get_user_tasks(self, user_id: str) -> List[Dict[str, Any]]:
        """
        Get all tasks for a specific user.

        Args:
            user_id: User identifier

        Returns:
            List of task dictionaries (copies), each with an added 'task_id' key
        """
        with self.lock:
            return [
                {'task_id': tid, **task}
                for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
            ]

    def cancel_user_tasks(self, user_id: str) -> int:
        """
        Cancel all active ('pending' or 'processing') tasks for a user.

        Args:
            user_id: User identifier

        Returns:
            int: Number of tasks cancelled
        """
        with self.lock:
            task_ids = [
                tid for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
                and task.get('status') in _ACTIVE_STATUSES
            ]
            # Cancel while still holding the lock via the *_locked helper;
            # calling cancel_task() here would re-acquire self.lock and deadlock.
            for tid in task_ids:
                self._cancel_task_locked(tid)
        for tid in task_ids:
            logger.info("Task cancelled: %s", tid)
        logger.info("Cancelled %d tasks for user: %s", len(task_ids), user_id)
        return len(task_ids)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600) -> None:
        """
        Clean up finished tasks created more than ``max_age_seconds`` ago.

        Only tasks in a terminal status (completed, failed, cancelled) are removed.

        Args:
            max_age_seconds: Maximum age in seconds (default: 1 hour)
        """
        current_time = datetime.now()
        with self.lock:
            tasks_to_remove = [
                task_id for task_id, task in self.active_tasks.items()
                if task.get('status') in _TERMINAL_STATUSES
                and (current_time - datetime.fromisoformat(task['created_at'])).total_seconds()
                > max_age_seconds
            ]
            # Remove in the same critical section instead of calling
            # cleanup_task(), which would re-acquire self.lock and deadlock.
            for task_id in tasks_to_remove:
                self._remove_task_locked(task_id)
        for task_id in tasks_to_remove:
            logger.info("Task cleaned up: %s", task_id)
        if tasks_to_remove:
            logger.info("Cleaned up %d old tasks", len(tasks_to_remove))

    def get_stats(self) -> Dict[str, Any]:
        """
        Get overall task statistics.

        Returns:
            Dict with 'total_tasks' plus per-status and per-user counts
        """
        with self.lock:
            tasks = list(self.active_tasks.values())
        return {
            'total_tasks': len(tasks),
            'by_status': dict(Counter(task.get('status', 'unknown') for task in tasks)),
            'by_user': dict(Counter(task.get('user_id', 'unknown') for task in tasks)),
        }


# Global instance
_task_manager_instance: Optional[TaskManager] = None
# Guards lazy creation of the singleton so concurrent first calls share one instance.
_task_manager_lock = threading.Lock()


def get_task_manager() -> TaskManager:
    """
    Get global TaskManager instance (singleton pattern).

    Returns:
        TaskManager instance
    """
    global _task_manager_instance
    if _task_manager_instance is None:
        with _task_manager_lock:
            # Re-check under the lock: another thread may have created it meanwhile.
            if _task_manager_instance is None:
                _task_manager_instance = TaskManager()
    return _task_manager_instance