| import asyncio |
| import logging |
| from typing import Any |
|
|
| from sqlalchemy import select |
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from app.database import AsyncSessionLocal |
| from app.models import Job |
| from app.schemas import job_to_status_response |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# In-process registry of status subscribers: job id -> the queues registered
# for it via subscribe(). A plain module-global dict, so state is per-process.
_queues: dict[str, list[asyncio.Queue[dict[str, Any]]]] = dict()
|
|
|
|
def _qs(job_id: str) -> list[asyncio.Queue[dict[str, Any]]]:
    """Get-or-create the subscriber list for *job_id* in ``_queues``.

    If the job has no entry yet, an empty list is inserted into the registry
    and returned, so callers can append to the result directly.
    """
    subscribers = _queues.get(job_id)
    if subscribers is None:
        # Registry values are always lists, never None, so None means "absent".
        subscribers = _queues[job_id] = []
    return subscribers
|
|
|
|
def subscribe(
    job_id: str, *, maxsize: int = 64
) -> asyncio.Queue[dict[str, Any]]:
    """Create, register and return a new update queue for *job_id*.

    Args:
        job_id: Job whose status broadcasts the queue should receive.
        maxsize: Capacity of the returned queue (default 64). As with
            ``asyncio.Queue``, a value <= 0 makes the queue unbounded.

    Returns:
        The newly registered queue. It stays registered until it is passed
        to ``unsubscribe``; nothing else removes it.
    """
    q: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=maxsize)
    _qs(job_id).append(q)
    return q
|
|
|
|
def unsubscribe(job_id: str, q: asyncio.Queue[dict[str, Any]]) -> None:
    """Detach queue *q* from the subscribers of *job_id*.

    Unknown jobs and queues that are not (or no longer) subscribed are
    ignored, so calling this more than once is harmless. When the job's
    subscriber list becomes empty, its entry is removed from ``_queues``.
    """
    # Plain lookup instead of the get-or-create _qs(): an unknown job should
    # not get an entry inserted only to have it deleted again.
    subscribers = _queues.get(job_id)
    if subscribers is None:
        return
    try:
        subscribers.remove(q)
    except ValueError:
        pass  # q was not subscribed to this job; nothing to remove.
    if not subscribers:
        del _queues[job_id]
|
|
|
|
def _put_latest(
    q: asyncio.Queue[dict[str, Any]], message: dict[str, Any], job_id: str
) -> None:
    """Enqueue *message* on *q* without blocking, evicting the oldest item if full.

    Evicting the oldest entry (rather than dropping *message*) keeps the
    newest status snapshot in the queue when a consumer falls behind.
    """
    try:
        q.put_nowait(message)
        return
    except asyncio.QueueFull:
        pass
    try:
        q.get_nowait()
    except asyncio.QueueEmpty:
        pass
    try:
        q.put_nowait(message)
    except asyncio.QueueFull:
        logger.debug("Progress queue still full for job %s", job_id)


async def broadcast_job(job_id: str) -> None:
    """Load job *job_id* and push its current status to all its subscriber queues.

    Each subscriber queue receives ``{"type": "job_status", "data": payload}``,
    where ``payload`` is ``job_to_status_response(job)`` dumped with
    ``mode="json"``. Nothing is sent if the job has no subscribers or no job
    row matches *job_id*. Any ``Exception`` is logged via
    ``logger.exception`` rather than propagated.
    """
    try:
        # Peek with .get() instead of the get-or-create _qs(): otherwise every
        # job broadcast without listeners leaves a permanent empty entry in
        # _queues. With no listeners there is also no need to query the DB.
        if not _queues.get(job_id):
            return
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                return
            # Serialize while the session is still open, in case
            # job_to_status_response touches lazily loaded attributes.
            payload = job_to_status_response(job).model_dump(mode="json")
        message = {"type": "job_status", "data": payload}
        # Re-read after the await: subscribers may have joined or left while
        # the query ran. Iterate over a copy so the list may change safely.
        for q in list(_queues.get(job_id, ())):
            _put_latest(q, message, job_id)
    except Exception:
        logger.exception("broadcast_job failed for %s", job_id)
|
|