File size: 1,813 Bytes
1c167a4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import asyncio
import logging
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import AsyncSessionLocal
from app.models import Job
from app.schemas import job_to_status_response
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# In-process registry of subscriber queues: maps a job id to the list of
# queues registered for it via subscribe(). unsubscribe() deletes a job's
# entry once its list becomes empty.
_queues: dict[str, list[asyncio.Queue[dict[str, Any]]]] = {}
def _qs(job_id: str) -> list[asyncio.Queue[dict[str, Any]]]:
    """Return the subscriber list for ``job_id``, registering one if absent.

    The returned list is the live object stored in ``_queues``; mutating it
    mutates the registry. Note that a lookup for an unknown job id inserts a
    new empty list into ``_queues`` as a side effect.
    """
    try:
        return _queues[job_id]
    except KeyError:
        bucket: list[asyncio.Queue[dict[str, Any]]] = []
        _queues[job_id] = bucket
        return bucket
def subscribe(job_id: str) -> asyncio.Queue[dict[str, Any]]:
    """Register and return a new bounded queue for updates about ``job_id``.

    The queue holds at most 64 pending messages. The caller should hand the
    same queue back to ``unsubscribe`` when it is done listening.
    """
    subscriber_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=64)
    subscribers = _qs(job_id)
    subscribers.append(subscriber_queue)
    return subscriber_queue
def unsubscribe(job_id: str, q: asyncio.Queue[dict[str, Any]]) -> None:
    """Detach ``q`` from ``job_id`` and drop the job's entry once it is empty.

    Unsubscribing a queue that is not registered for the job is a no-op
    apart from the empty-entry cleanup.
    """
    subscribers = _qs(job_id)
    try:
        subscribers.remove(q)
    except ValueError:
        # q was never registered for this job (or was already removed).
        pass
    if subscribers:
        return
    # Last subscriber gone: forget the job so the registry does not grow.
    _queues.pop(job_id, None)
async def broadcast_job(job_id: str) -> None:
    """Push the current status of job ``job_id`` to all of its subscribers.

    Loads the ``Job`` row whose id equals ``job_id``, serialises it with
    ``job_to_status_response(job).model_dump(mode="json")`` and enqueues
    ``{"type": "job_status", "data": <payload>}`` on every queue currently
    registered for the job. Nothing is sent when no such job exists.

    When a subscriber's queue is full, its oldest pending message is dropped
    to make room, so a slow consumer loses stale updates rather than the
    newest one.

    Broadcasting is best-effort: any exception is logged and swallowed so a
    failure here never propagates to the caller.
    """
    try:
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if not job:
                return
            payload = job_to_status_response(job).model_dump(mode="json")

        message = {"type": "job_status", "data": payload}

        # Read-only lookup: going through _qs() here would insert an empty
        # list into _queues for every job that has no subscribers, and that
        # entry would never be removed. The snapshot is taken after the DB
        # await so queues subscribed in the meantime still get the message.
        subscribers = tuple(_queues.get(job_id, ()))

        for q in subscribers:
            try:
                q.put_nowait(message)
                continue
            except asyncio.QueueFull:
                pass

            # Queue is full: evict the oldest message, then retry once.
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
            try:
                q.put_nowait(message)
            except asyncio.QueueFull:
                logger.debug("Progress queue still full for job %s", job_id)
    except Exception:
        logger.exception("broadcast_job failed for %s", job_id)
|