File size: 1,813 Bytes
1c167a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import asyncio
import logging
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import AsyncSessionLocal
from app.models import Job
from app.schemas import job_to_status_response

logger = logging.getLogger(__name__)

# Module-level pub/sub registry: job id -> list of subscriber queues that
# receive {"type": "job_status", ...} messages for that job. Lists are created
# on demand via _qs(); unsubscribe() deletes a job's key once its list is empty.
# No locking is used — presumably only touched from the event-loop thread
# (TODO confirm).
_queues: dict[str, list[asyncio.Queue[dict[str, Any]]]] = {}


def _qs(job_id: str) -> list[asyncio.Queue[dict[str, Any]]]:
    """Return the live subscriber list for *job_id*, registering an empty one if absent.

    The returned list is the registry entry itself, so mutating it (appending
    or removing queues) changes who receives broadcasts for the job.
    """
    if job_id not in _queues:
        _queues[job_id] = []
    return _queues[job_id]


def subscribe(job_id: str) -> asyncio.Queue[dict[str, Any]]:
    """Create a bounded (64-message) queue and register it for *job_id*.

    Status messages broadcast for the job are delivered to the returned queue.
    Pass the same queue to ``unsubscribe`` when finished; otherwise it stays
    registered.
    """
    inbox: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=64)
    subscribers = _qs(job_id)
    subscribers.append(inbox)
    return inbox


def unsubscribe(job_id: str, q: asyncio.Queue[dict[str, Any]]) -> None:
    """Detach *q* from *job_id*'s subscribers and drop the job's entry once empty.

    Unsubscribing a queue that is not registered (or a job with no
    subscribers) is a harmless no-op.
    """
    subscribers = _qs(job_id)
    try:
        subscribers.remove(q)
    except ValueError:
        pass  # q was never subscribed, or was already removed
    if not subscribers:
        _queues.pop(job_id, None)


def _offer(
    q: asyncio.Queue[dict[str, Any]], message: dict[str, Any], job_id: str
) -> None:
    """Enqueue *message* on *q* without blocking, evicting the oldest entry if full.

    Dropping the oldest message means a slow consumer loses stale updates
    rather than the newest status. *job_id* is used only for logging.
    """
    try:
        q.put_nowait(message)
        return
    except asyncio.QueueFull:
        pass
    try:
        q.get_nowait()
    except asyncio.QueueEmpty:
        pass
    try:
        q.put_nowait(message)
    except asyncio.QueueFull:
        logger.debug("Progress queue still full for job %s", job_id)


async def broadcast_job(job_id: str) -> None:
    """Push the current status of job *job_id* to every subscribed queue.

    Loads the ``Job`` row in a fresh session and serialises it with
    ``job_to_status_response(...).model_dump(mode="json")``. It then delivers
    ``{"type": "job_status", "data": payload}`` to each queue registered for
    the job; full queues drop their oldest message to make room.

    Best-effort: a missing job or a job with no subscribers is a silent no-op,
    and any exception is logged rather than raised.
    """
    try:
        async with AsyncSessionLocal() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                return
            # Serialise while the session is open, in case the response
            # builder touches attributes that need the session.
            payload = job_to_status_response(job).model_dump(mode="json")

        # Session is closed here: fan-out needs no DB connection.
        # Read the registry *after* the awaits above so queues that subscribed
        # while the query was in flight still receive this update. Use .get()
        # instead of _qs(): _qs() would insert an empty list for a job nobody
        # watches, and nothing would ever remove it (unbounded growth).
        subscribers = _queues.get(job_id)
        if not subscribers:
            return
        message: dict[str, Any] = {"type": "job_status", "data": payload}
        for q in list(subscribers):  # snapshot: delivery must not see list edits
            _offer(q, message, job_id)
    except Exception:
        logger.exception("broadcast_job failed for %s", job_id)