# VoxDoc — app/task_queue.py
"""Background task queue for model inference offloading.
Phase 8: Provides async task queue for long-running inference tasks.
Uses ARQ (async Redis queue) when Redis is available, otherwise falls back
to asyncio.create_task for single-process operation.
Set ``TASK_QUEUE_ENABLED=true`` and ``TASK_QUEUE_BROKER_URL`` to enable.
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Optional
from app.config import settings
logger = logging.getLogger(__name__)
class TaskStatus(str, Enum):
    """Lifecycle states of a queued task.

    Inherits ``str`` so values serialize directly (e.g. in JSON responses).
    """
    PENDING = "pending"      # accepted, not yet started
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully; result is populated
    FAILED = "failed"        # raised an exception; error is populated
@dataclass
class TaskResult:
    """Tracks the state and outcome of a single background task."""
    task_id: str  # unique id: caller-supplied or generated uuid4
    status: TaskStatus
    result: Any = None  # return value of the task coroutine on success
    error: Optional[str] = None  # stringified exception on failure
    created_at: float = field(default_factory=time.time)  # epoch seconds at enqueue
    completed_at: Optional[float] = None  # epoch seconds when finished (success or failure)
    duration_seconds: Optional[float] = None  # wall-clock run time (excludes queue wait)
class InMemoryTaskQueue:
    """Asyncio-based task queue for single-process deployments.

    Results are tracked in a bounded in-memory dict; finished entries are
    pruned when the table fills.  Concurrency is capped by a semaphore.
    """

    def __init__(self, max_concurrent: int = 2, max_queue_size: int = 20):
        """
        Args:
            max_concurrent: Maximum tasks executing at once (semaphore permits).
            max_queue_size: Maximum tracked results before ``enqueue`` prunes
                finished entries (and rejects with RuntimeError if still full).
        """
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._results: Dict[str, TaskResult] = {}
        self._max_queue_size = max_queue_size
        self._active_count = 0
        # Strong references to in-flight tasks: the event loop holds only
        # weak refs, so without this a scheduled task may be garbage
        # collected before it runs (see asyncio.create_task docs).
        self._tasks: set = set()

    async def enqueue(
        self,
        func: Callable[..., Coroutine],
        *args: Any,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Submit a coroutine for background execution.

        Args:
            func: Coroutine function to execute.
            task_id: Optional explicit id; a uuid4 string is generated otherwise.

        Returns:
            The task id, usable with :meth:`get_result`.

        Raises:
            RuntimeError: If the tracking table is full even after pruning.
        """
        task_id = task_id or str(uuid.uuid4())
        if len(self._results) >= self._max_queue_size:
            # Prune finished tasks to make room.  Round up so that a single
            # finished entry can still be reclaimed (len // 2 of 1 is 0,
            # which previously caused a spurious "queue is full").
            completed = [
                k for k, v in self._results.items()
                if v.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)
            ]
            for k in completed[:(len(completed) + 1) // 2]:
                del self._results[k]
            if len(self._results) >= self._max_queue_size:
                raise RuntimeError("Task queue is full")
        self._results[task_id] = TaskResult(
            task_id=task_id, status=TaskStatus.PENDING
        )
        # Hold a strong reference until the task completes, then drop it.
        task = asyncio.create_task(self._run(task_id, func, *args, **kwargs))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task_id

    async def _run(
        self,
        task_id: str,
        func: Callable[..., Coroutine],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Execute one task under the concurrency semaphore, recording
        status, result/error, and timing on its TaskResult."""
        async with self._semaphore:
            self._active_count += 1
            result = self._results[task_id]
            result.status = TaskStatus.RUNNING
            start = time.time()
            try:
                result.result = await func(*args, **kwargs)
                result.status = TaskStatus.COMPLETED
            except Exception as e:
                result.status = TaskStatus.FAILED
                result.error = str(e)
                logger.error("Task %s failed: %s", task_id, e)
            finally:
                result.completed_at = time.time()
                result.duration_seconds = result.completed_at - start
                self._active_count -= 1

    def get_result(self, task_id: str) -> Optional[TaskResult]:
        """Return the TaskResult for *task_id*, or None if unknown/pruned."""
        return self._results.get(task_id)

    @property
    def active_tasks(self) -> int:
        """Number of tasks currently executing."""
        return self._active_count

    @property
    def pending_tasks(self) -> int:
        """Number of tracked tasks that have not started yet."""
        return sum(
            1 for r in self._results.values() if r.status == TaskStatus.PENDING
        )

    def stats(self) -> Dict[str, Any]:
        """Snapshot of queue counters for monitoring endpoints."""
        return {
            "active": self._active_count,
            "pending": self.pending_tasks,
            "total_tracked": len(self._results),
            "backend": "asyncio",
        }
class ARQTaskQueue:
    """ARQ (async Redis queue) backend for distributed workers.

    Currently executes tasks in-process with tracking; the Redis connection
    is established for future worker dispatch (see ``enqueue``).
    """

    def __init__(self, broker_url: str, max_concurrent: int = 2):
        """
        Args:
            broker_url: Redis connection URL for the broker.
            max_concurrent: Maximum tasks executing at once in this process.
        """
        self._broker_url = broker_url
        self._max_concurrent = max_concurrent
        # Enforce the cap locally, matching InMemoryTaskQueue; previously
        # max_concurrent was stored but never applied.
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._pool = None
        self._results: Dict[str, TaskResult] = {}
        # Strong refs to in-flight tasks (event loop keeps only weak refs).
        self._tasks: set = set()

    async def _connect(self):
        """Lazily connect to Redis; on failure, stay in in-process mode."""
        if self._pool is not None:
            return
        try:
            import redis.asyncio as aioredis
            self._pool = aioredis.from_url(self._broker_url)
            await self._pool.ping()
            # Log only the part after '@' so URL credentials aren't leaked.
            logger.info("ARQ task queue connected: %s", self._broker_url.split("@")[-1])
        except Exception as e:
            logger.warning("ARQ connection failed, tasks will run in-process: %s", e)
            self._pool = None

    async def enqueue(
        self,
        func: Callable[..., Coroutine],
        *args: Any,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Enqueue task. Falls back to in-process if Redis unavailable.

        Returns:
            The task id, usable with :meth:`get_result`.
        """
        task_id = task_id or str(uuid.uuid4())
        await self._connect()
        # For now, run in-process with tracking (ARQ worker integration
        # would serialize func name and dispatch to worker process)
        self._results[task_id] = TaskResult(
            task_id=task_id, status=TaskStatus.PENDING
        )
        # Hold a strong reference until the task completes, then drop it.
        task = asyncio.create_task(self._run_local(task_id, func, *args, **kwargs))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task_id

    async def _run_local(
        self,
        task_id: str,
        func: Callable[..., Coroutine],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Execute one task in-process, recording status, result/error,
        and timing on its TaskResult."""
        async with self._semaphore:
            result = self._results[task_id]
            result.status = TaskStatus.RUNNING
            start = time.time()
            try:
                result.result = await func(*args, **kwargs)
                result.status = TaskStatus.COMPLETED
            except Exception as e:
                result.status = TaskStatus.FAILED
                result.error = str(e)
                # Log failures, consistent with InMemoryTaskQueue._run.
                logger.error("Task %s failed: %s", task_id, e)
            finally:
                result.completed_at = time.time()
                result.duration_seconds = result.completed_at - start

    def get_result(self, task_id: str) -> Optional[TaskResult]:
        """Return the TaskResult for *task_id*, or None if unknown."""
        return self._results.get(task_id)

    def stats(self) -> Dict[str, Any]:
        """Snapshot of queue counters for monitoring endpoints."""
        return {
            "active": sum(1 for r in self._results.values() if r.status == TaskStatus.RUNNING),
            "pending": sum(1 for r in self._results.values() if r.status == TaskStatus.PENDING),
            "total_tracked": len(self._results),
            "backend": "arq" if self._pool else "asyncio-fallback",
        }
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
# Module-level singleton; created lazily by get_task_queue().
_task_queue = None


def get_task_queue() -> "InMemoryTaskQueue | ARQTaskQueue":
    """Get or create the global task queue singleton.

    Returns:
        An ``ARQTaskQueue`` when ``TASK_QUEUE_ENABLED`` and a broker URL are
        configured, otherwise an ``InMemoryTaskQueue``.  (The original
        annotation claimed ``InMemoryTaskQueue`` only, which was wrong for
        the Redis-backed branch.)
    """
    global _task_queue
    if _task_queue is None:
        if settings.task_queue_enabled and settings.task_queue_broker_url:
            _task_queue = ARQTaskQueue(
                broker_url=settings.task_queue_broker_url,
                max_concurrent=settings.queue_max_concurrent_inferences,
            )
        else:
            _task_queue = InMemoryTaskQueue(
                max_concurrent=settings.queue_max_concurrent_inferences,
                max_queue_size=settings.queue_max_size,
            )
    return _task_queue