Spaces:
Running
Running
| """Background task queue for model inference offloading. | |
| Phase 8: Provides async task queue for long-running inference tasks. | |
| Uses ARQ (async Redis queue) when Redis is available, otherwise falls back | |
| to asyncio.create_task for single-process operation. | |
| Set ``TASK_QUEUE_ENABLED=true`` and ``TASK_QUEUE_BROKER_URL`` to enable. | |
| """ | |
| import asyncio | |
| import logging | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from typing import Any, Callable, Coroutine, Dict, Optional | |
| from app.config import settings | |
| logger = logging.getLogger(__name__) | |
class TaskStatus(str, Enum):
    """Lifecycle states for a background task.

    Inherits from ``str`` so the values serialize directly (e.g. in JSON
    API responses) without an explicit ``.value`` access.
    """

    PENDING = "pending"      # enqueued, not yet started
    RUNNING = "running"      # coroutine is currently executing
    COMPLETED = "completed"  # finished without raising
    FAILED = "failed"        # raised; str(exception) stored on the result
@dataclass
class TaskResult:
    """Tracks the lifecycle and outcome of a single background task.

    BUG FIX: the original class declared annotated attributes with
    ``field(default_factory=time.time)`` and was instantiated with keyword
    arguments (``TaskResult(task_id=..., status=...)``), but was missing the
    ``@dataclass`` decorator. Without it, instantiation raised ``TypeError``
    (plain classes take no such keyword args) and ``created_at`` was a bare
    ``dataclasses.field`` object rather than a timestamp. Adding the
    decorator turns the declarations into real init parameters.
    """

    task_id: str                               # unique id assigned at enqueue time
    status: TaskStatus                         # current lifecycle state
    result: Any = None                         # return value once COMPLETED
    error: Optional[str] = None                # str(exception) when FAILED
    created_at: float = field(default_factory=time.time)  # set at enqueue
    completed_at: Optional[float] = None       # set in the runner's finally block
    duration_seconds: Optional[float] = None   # completed_at minus run start
class InMemoryTaskQueue:
    """Asyncio-based task queue for single-process deployments.

    Concurrency is capped by a semaphore; per-task results are tracked in
    an in-memory dict that is opportunistically pruned of finished entries
    once it reaches ``max_queue_size``.
    """

    def __init__(self, max_concurrent: int = 2, max_queue_size: int = 20):
        # Caps how many submitted coroutines execute simultaneously.
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._results: Dict[str, TaskResult] = {}
        self._max_queue_size = max_queue_size
        self._active_count = 0
        # BUG FIX: strong references to in-flight tasks. The event loop keeps
        # only weak references, so a task whose handle is discarded can be
        # garbage-collected mid-flight (documented asyncio.create_task pitfall).
        self._tasks: set = set()

    async def enqueue(
        self,
        func: Callable[..., Coroutine],
        *args: Any,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Submit a coroutine for background execution.

        Returns the task id (generated if not supplied).
        Raises ``RuntimeError`` when the tracking dict is full and cannot
        be pruned enough to make room.
        """
        task_id = task_id or str(uuid.uuid4())
        if len(self._results) >= self._max_queue_size:
            # Prune half of the finished entries to make room.
            finished = [
                key for key, res in self._results.items()
                if res.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)
            ]
            for key in finished[:len(finished) // 2]:
                del self._results[key]
            if len(self._results) >= self._max_queue_size:
                raise RuntimeError("Task queue is full")
        self._results[task_id] = TaskResult(
            task_id=task_id, status=TaskStatus.PENDING
        )
        task = asyncio.create_task(self._run(task_id, func, *args, **kwargs))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)  # drop the ref when done
        return task_id

    async def _run(
        self,
        task_id: str,
        func: Callable[..., Coroutine],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Execute ``func`` under the concurrency semaphore and record outcome."""
        async with self._semaphore:
            self._active_count += 1
            result = self._results[task_id]
            result.status = TaskStatus.RUNNING
            start = time.time()
            try:
                result.result = await func(*args, **kwargs)
                result.status = TaskStatus.COMPLETED
            except Exception as e:
                result.status = TaskStatus.FAILED
                result.error = str(e)
                logger.error("Task %s failed: %s", task_id, e)
            finally:
                result.completed_at = time.time()
                result.duration_seconds = result.completed_at - start
                self._active_count -= 1

    def get_result(self, task_id: str) -> Optional[TaskResult]:
        """Return the tracked result for ``task_id``, or None if unknown/pruned."""
        return self._results.get(task_id)

    def active_tasks(self) -> int:
        """Number of tasks currently holding the semaphore."""
        return self._active_count

    def pending_tasks(self) -> int:
        """Number of tracked tasks still waiting to start."""
        return sum(
            1 for r in self._results.values() if r.status == TaskStatus.PENDING
        )

    def stats(self) -> Dict[str, Any]:
        """Snapshot of queue health for monitoring endpoints."""
        return {
            "active": self._active_count,
            # BUG FIX: the original returned the bound method object
            # (``self.pending_tasks``) instead of calling it.
            "pending": self.pending_tasks(),
            "total_tracked": len(self._results),
            "backend": "asyncio",
        }
class ARQTaskQueue:
    """ARQ (async Redis queue) backend for distributed workers.

    Currently runs tasks in-process with result tracking; the Redis
    connection is established so worker dispatch can be layered on later.
    """

    def __init__(self, broker_url: str, max_concurrent: int = 2):
        self._broker_url = broker_url
        self._max_concurrent = max_concurrent
        # BUG FIX: ``max_concurrent`` was stored but never enforced; the
        # in-process fallback now caps concurrency with a semaphore, matching
        # InMemoryTaskQueue's behavior.
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._pool = None
        self._results: Dict[str, TaskResult] = {}
        # BUG FIX: strong refs so in-flight tasks cannot be garbage-collected
        # (asyncio.create_task keeps only weak references).
        self._tasks: set = set()

    async def _connect(self):
        """Lazily connect to Redis; on failure, degrade to in-process mode."""
        if self._pool is not None:
            return
        try:
            import redis.asyncio as aioredis

            self._pool = aioredis.from_url(self._broker_url)
            await self._pool.ping()
            # split("@")[-1] hides credentials embedded in the URL.
            logger.info("ARQ task queue connected: %s", self._broker_url.split("@")[-1])
        except Exception as e:
            logger.warning("ARQ connection failed, tasks will run in-process: %s", e)
            self._pool = None

    async def enqueue(
        self,
        func: Callable[..., Coroutine],
        *args: Any,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Enqueue task. Falls back to in-process if Redis unavailable."""
        task_id = task_id or str(uuid.uuid4())
        await self._connect()
        # For now, run in-process with tracking (ARQ worker integration
        # would serialize func name and dispatch to worker process)
        self._results[task_id] = TaskResult(
            task_id=task_id, status=TaskStatus.PENDING
        )
        task = asyncio.create_task(self._run_local(task_id, func, *args, **kwargs))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task_id

    async def _run_local(
        self,
        task_id: str,
        func: Callable[..., Coroutine],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Execute ``func`` in-process under the concurrency cap."""
        async with self._semaphore:
            result = self._results[task_id]
            result.status = TaskStatus.RUNNING
            start = time.time()
            try:
                result.result = await func(*args, **kwargs)
                result.status = TaskStatus.COMPLETED
            except Exception as e:
                result.status = TaskStatus.FAILED
                result.error = str(e)
                # Consistency fix: log failures like InMemoryTaskQueue does
                # instead of recording them silently.
                logger.error("Task %s failed: %s", task_id, e)
            finally:
                result.completed_at = time.time()
                result.duration_seconds = result.completed_at - start

    def get_result(self, task_id: str) -> Optional[TaskResult]:
        """Return the tracked result for ``task_id``, or None if unknown."""
        return self._results.get(task_id)

    def stats(self) -> Dict[str, Any]:
        """Snapshot of queue health for monitoring endpoints."""
        return {
            "active": sum(1 for r in self._results.values() if r.status == TaskStatus.RUNNING),
            "pending": sum(1 for r in self._results.values() if r.status == TaskStatus.PENDING),
            "total_tracked": len(self._results),
            "backend": "arq" if self._pool else "asyncio-fallback",
        }
| # --------------------------------------------------------------------------- | |
| # Factory | |
| # --------------------------------------------------------------------------- | |
_task_queue = None  # lazily-created module-level singleton


def get_task_queue() -> "InMemoryTaskQueue | ARQTaskQueue":
    """Get or create the global task queue.

    Chooses the Redis-backed ``ARQTaskQueue`` when both
    ``settings.task_queue_enabled`` and ``settings.task_queue_broker_url``
    are set, otherwise the single-process ``InMemoryTaskQueue``.

    BUG FIX: the return annotation claimed ``InMemoryTaskQueue`` even though
    the Redis branch returns an ``ARQTaskQueue``; the (lazy, string) union
    annotation now reflects both backends without changing runtime behavior.
    """
    global _task_queue
    if _task_queue is None:
        if settings.task_queue_enabled and settings.task_queue_broker_url:
            _task_queue = ARQTaskQueue(
                broker_url=settings.task_queue_broker_url,
                max_concurrent=settings.queue_max_concurrent_inferences,
            )
        else:
            _task_queue = InMemoryTaskQueue(
                max_concurrent=settings.queue_max_concurrent_inferences,
                max_queue_size=settings.queue_max_size,
            )
    return _task_queue