# ABSA/src/utils/task_manager.py
# Last commit: "Adding cancellation button" (102e87a) by parthnuwal7
"""
Task Manager for handling processing task cancellation and progress tracking.
Provides thread-safe task tracking, cancellation flags, and status management.
"""
import threading
import time
from typing import Dict, Optional, Any, List
from datetime import datetime
import uuid
import logging
logger = logging.getLogger(__name__)
class TaskManager:
    """Manages processing tasks with cancellation support and progress tracking.

    Each task is a plain metadata dict stored in ``active_tasks`` and paired
    with a ``threading.Event`` in ``cancel_flags``. Cancellation is
    cooperative: ``cancel_task`` only sets the event, and the code doing the
    work is expected to poll ``is_cancelled`` and stop on its own.

    Every method that mutates the registries does so while holding ``lock``.
    ``lock`` is a plain (non-reentrant) ``threading.Lock``, so public methods
    never call each other while holding it; shared logic lives in the private
    ``_..._locked`` helpers, which require the caller to hold the lock.
    """

    # Statuses after which a task is finished and eligible for cleanup.
    _TERMINAL_STATUSES = frozenset({'completed', 'failed', 'cancelled'})
    # Statuses that cancel_user_tasks treats as still cancellable.
    _ACTIVE_STATUSES = frozenset({'pending', 'processing'})

    def __init__(self):
        """Initialize task manager with thread-safe data structures."""
        self.active_tasks: Dict[str, Dict[str, Any]] = {}
        self.cancel_flags: Dict[str, threading.Event] = {}
        self.lock = threading.Lock()
        logger.info("TaskManager initialized")

    def create_task(self, user_id: str = "default") -> str:
        """
        Create a new task and return task ID.

        Args:
            user_id: User identifier for multi-user support

        Returns:
            task_id: Unique task identifier (a random UUID4 string)
        """
        task_id = str(uuid.uuid4())
        # One timestamp so created_at and updated_at start out identical.
        now = datetime.now().isoformat()
        with self.lock:
            self.active_tasks[task_id] = {
                'user_id': user_id,
                'status': 'pending',
                'created_at': now,
                'updated_at': now,
                'progress': 0,
                'stage': 'initializing',
                'message': 'Task created',
            }
            self.cancel_flags[task_id] = threading.Event()
        logger.info("Task created: %s for user: %s", task_id, user_id)
        return task_id

    def update_task(self, task_id: str, **kwargs):
        """
        Update task status and metadata.

        Unknown task IDs are ignored silently.

        Args:
            task_id: Task identifier
            **kwargs: Fields to update (status, progress, stage, message, etc.)
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            task.update(kwargs)
            task['updated_at'] = datetime.now().isoformat()
            # Log significant status changes
            if 'stage' in kwargs:
                # Read progress from the merged record so a stage-only update
                # reports the task's current progress instead of 0.
                logger.info(
                    "Task %s stage: %s (%s%%)",
                    task_id, kwargs['stage'], task.get('progress', 0),
                )

    def _cancel_locked(self, task_id: str) -> bool:
        """Request cancellation of one task; the caller must hold ``lock``."""
        flag = self.cancel_flags.get(task_id)
        if flag is None:
            logger.warning("Task not found for cancellation: %s", task_id)
            return False

        task = self.active_tasks.get(task_id)
        if task is not None and task.get('status') in self._TERMINAL_STATUSES:
            # A finished task must keep its terminal status; flipping it back
            # to 'cancelling' would leave it stuck and never cleaned up.
            logger.info(
                "Task already %s, cancellation ignored: %s",
                task['status'], task_id,
            )
            return False

        flag.set()  # Signal cancellation
        if task is not None:
            task['status'] = 'cancelling'
            task['updated_at'] = datetime.now().isoformat()
            task['message'] = 'Cancellation requested'
        logger.info("Task cancelled: %s", task_id)
        return True

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task by setting its cancellation flag.

        The task moves to status 'cancelling'; it is up to the worker to
        notice the flag and to call ``mark_cancelled`` once it has stopped.

        Args:
            task_id: Task identifier to cancel

        Returns:
            bool: True if cancellation was requested, False if the task is
            unknown or has already completed, failed or been cancelled
        """
        with self.lock:
            return self._cancel_locked(task_id)

    def is_cancelled(self, task_id: str) -> bool:
        """
        Check if task has been cancelled.

        Args:
            task_id: Task identifier to check (None is accepted)

        Returns:
            bool: True if cancellation was requested, False otherwise
            (including for None or unknown task IDs)
        """
        if task_id is None:
            return False
        flag = self.cancel_flags.get(task_id)
        return flag is not None and flag.is_set()

    def complete_task(self, task_id: str, success: bool = True, message: Optional[str] = None):
        """
        Mark task as complete.

        Sets status to 'completed' or 'failed' and stamps ``completed_at``
        and ``updated_at``. Progress is forced to 100 on success and left
        unchanged on failure. Unknown task IDs are ignored.

        Args:
            task_id: Task identifier
            success: Whether task completed successfully
            message: Optional completion message (ignored when empty)
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = datetime.now().isoformat()
            task['status'] = 'completed' if success else 'failed'
            task['completed_at'] = now
            task['updated_at'] = now
            if success:
                task['progress'] = 100
            else:
                task.setdefault('progress', 0)
            if message:
                task['message'] = message
            logger.info("Task completed: %s (success=%s)", task_id, success)

    def mark_cancelled(self, task_id: str):
        """
        Mark task as fully cancelled (after cleanup).

        Unknown task IDs are ignored.

        Args:
            task_id: Task identifier
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            if task is None:
                return
            now = datetime.now().isoformat()
            task['status'] = 'cancelled'
            task['cancelled_at'] = now
            task['updated_at'] = now
            task['message'] = 'Task cancelled successfully'
            logger.info("Task marked as cancelled: %s", task_id)

    def _remove_locked(self, task_id: str) -> None:
        """Drop a task and its cancel flag; the caller must hold ``lock``."""
        removed_task = self.active_tasks.pop(task_id, None)
        self.cancel_flags.pop(task_id, None)
        if removed_task:
            logger.info("Task cleaned up: %s", task_id)

    def cleanup_task(self, task_id: str):
        """
        Remove task from tracking (for old/completed tasks).

        Removes both the metadata and the cancellation flag. Unknown task
        IDs are ignored.

        Args:
            task_id: Task identifier to remove
        """
        with self.lock:
            self._remove_locked(task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """
        Get current task status and metadata.

        Args:
            task_id: Task identifier

        Returns:
            A copy of the task's metadata dict, or None if not found. The
            copy is a snapshot: later updates are not reflected in it and
            changing it does not affect the tracked task.
        """
        with self.lock:
            task = self.active_tasks.get(task_id)
            return dict(task) if task is not None else None

    def get_user_tasks(self, user_id: str) -> List[Dict]:
        """
        Get all tasks for a specific user.

        Args:
            user_id: User identifier

        Returns:
            List of task dictionaries, each a copy of the stored metadata
            with an added 'task_id' key
        """
        with self.lock:
            return [
                {'task_id': tid, **task}
                for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
            ]

    def cancel_user_tasks(self, user_id: str) -> int:
        """
        Cancel all active tasks for a user.

        Only tasks whose status is 'pending' or 'processing' are considered.
        Selection and cancellation happen under a single lock acquisition so
        a task cannot finish in between.

        Args:
            user_id: User identifier

        Returns:
            int: Number of tasks cancelled
        """
        with self.lock:
            task_ids = [
                tid for tid, task in self.active_tasks.items()
                if task.get('user_id') == user_id
                and task.get('status') in self._ACTIVE_STATUSES
            ]
            cancelled = sum(self._cancel_locked(tid) for tid in task_ids)
        logger.info("Cancelled %d tasks for user: %s", cancelled, user_id)
        return cancelled

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """
        Clean up tasks older than specified age.

        Only finished tasks (completed, failed or cancelled) are removed.
        Age is measured from the task's 'created_at' timestamp.

        Args:
            max_age_seconds: Maximum age in seconds (default: 1 hour)
        """
        current_time = datetime.now()
        with self.lock:
            tasks_to_remove = [
                task_id for task_id, task in self.active_tasks.items()
                if task.get('status') in self._TERMINAL_STATUSES
                and (
                    current_time - datetime.fromisoformat(task['created_at'])
                ).total_seconds() > max_age_seconds
            ]
            # Remove after the scan: the dict must not change size while
            # it is being iterated.
            for task_id in tasks_to_remove:
                self._remove_locked(task_id)
        if tasks_to_remove:
            logger.info("Cleaned up %d old tasks", len(tasks_to_remove))

    def get_stats(self) -> Dict[str, Any]:
        """
        Get overall task statistics.

        Returns:
            Dict with 'total_tasks' (int) plus 'by_status' and 'by_user',
            each mapping a status / user ID to its task count
        """
        with self.lock:
            total = len(self.active_tasks)
            by_status: Dict[str, int] = {}
            by_user: Dict[str, int] = {}
            for task in self.active_tasks.values():
                status = task.get('status', 'unknown')
                user_id = task.get('user_id', 'unknown')
                by_status[status] = by_status.get(status, 0) + 1
                by_user[user_id] = by_user.get(user_id, 0) + 1
        return {
            'total_tasks': total,
            'by_status': by_status,
            'by_user': by_user,
        }
# Global instance, created lazily on the first get_task_manager() call.
_task_manager_instance = None
# Serializes first-time construction so that concurrent first callers all end
# up sharing one TaskManager (and therefore one set of cancel flags).
_task_manager_lock = threading.Lock()


def get_task_manager() -> TaskManager:
    """
    Get global TaskManager instance (singleton pattern).

    The instance is built on first use. Construction is guarded by a
    module-level lock so that two threads calling this at the same time
    cannot each create their own manager.

    Returns:
        TaskManager instance
    """
    global _task_manager_instance
    if _task_manager_instance is None:
        with _task_manager_lock:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting to acquire it.
            if _task_manager_instance is None:
                _task_manager_instance = TaskManager()
    return _task_manager_instance