"""Progress event bus for in-flight generations.

Endpoints (`/api/generate`, `/api/generate/dialog`) wrap their work in
`bus.session(...)` which emits `start` and `done`/`error` events plus a
`tick` heartbeat (every 0.5s by default). Dialog mode also emits
`turn_complete` between adapter calls. A session that is cancelled emits
`error` with the message `cancelled` before the cancellation propagates.

Subscribers receive events via `subscribe()` (used by the SSE endpoint).
A subscriber that joins while a session is running immediately receives
one `tick` describing the session's current state.
"""

from __future__ import annotations

import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import AsyncIterator, Literal

EventType = Literal["start", "tick", "turn_complete", "done", "error"]

# Default heartbeat period, in seconds, between `tick` events.
_DEFAULT_TICK_INTERVAL_S = 0.5


@dataclass
class ProgressEvent:
    """A single progress notification delivered to subscribers."""

    type: EventType
    elapsed_s: float
    payload: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Flatten the event into one dict.

        `elapsed_s` is rounded to two decimals. Payload keys are merged
        last, so a payload key named `type` or `elapsed_s` would override
        the event's own value.
        """
        return {"type": self.type, "elapsed_s": round(self.elapsed_s, 2), **self.payload}


class ProgressBus:
    """Fan-out of `ProgressEvent`s to any number of subscriber queues."""

    def __init__(self, *, tick_interval_s: float = _DEFAULT_TICK_INTERVAL_S) -> None:
        """Create an empty bus.

        Args:
            tick_interval_s: Seconds between `tick` heartbeats of sessions
                opened on this bus.

        Raises:
            ValueError: If `tick_interval_s` is not positive (a zero or
                negative interval would turn the heartbeat into a busy loop).
        """
        if tick_interval_s <= 0:
            raise ValueError("tick_interval_s must be positive")
        self._tick_interval_s = tick_interval_s
        self._subscribers: list[asyncio.Queue[ProgressEvent]] = []
        self._lock = asyncio.Lock()
        self._current_session: "_Session | None" = None

    async def publish(self, event: ProgressEvent) -> None:
        """Deliver `event` to every queue subscribed at the time of the call."""
        # Copy under the lock so delivery never iterates a list that a
        # concurrent subscribe/unsubscribe is mutating.
        async with self._lock:
            subs = list(self._subscribers)
        for q in subs:
            await q.put(event)

    @asynccontextmanager
    async def subscribe(self) -> AsyncIterator[asyncio.Queue[ProgressEvent]]:
        """Register a new queue for the duration of the `async with` block.

        If a session is currently active, the queue is pre-loaded with a
        snapshot `tick` so the subscriber learns the current state right away.
        The queue is unregistered on exit.
        """
        q: asyncio.Queue[ProgressEvent] = asyncio.Queue()
        async with self._lock:
            self._subscribers.append(q)
            if self._current_session is not None:
                snapshot = self._current_session.snapshot_event()
                if snapshot is not None:
                    await q.put(snapshot)
        try:
            yield q
        finally:
            async with self._lock:
                if q in self._subscribers:
                    self._subscribers.remove(q)

    @asynccontextmanager
    async def session(
        self,
        kind: Literal["single", "dialog"],
        total_turns: int = 1,
    ) -> AsyncIterator["_Session"]:
        """Run one generation under progress reporting.

        Publishes `start` on entry and runs the heartbeat while the body
        executes. On exit it publishes `done` if the body finished normally,
        or `error` if it raised or was cancelled; the original exception is
        always re-raised.

        Args:
            kind: Generation mode, echoed in event payloads.
            total_turns: Number of turns expected, echoed in event payloads.

        Yields:
            The `_Session` used to report the seed and completed turns.
        """
        active = _Session(
            bus=self,
            kind=kind,
            total_turns=total_turns,
            tick_interval_s=self._tick_interval_s,
        )
        async with self._lock:
            self._current_session = active

        ticker: asyncio.Task[None] | None = None
        try:
            await self.publish(
                ProgressEvent(
                    type="start",
                    elapsed_s=0.0,
                    payload={"kind": kind, "total_turns": total_turns, "turn": 0},
                ),
            )
            ticker = asyncio.create_task(active._tick_loop())
            try:
                yield active
            except asyncio.CancelledError as exc:
                # CancelledError is a BaseException on Python 3.8+, so it
                # needs its own handler; without one subscribers would never
                # see a terminal event for a cancelled generation.
                await self._publish_error(active, str(exc) or "cancelled")
                raise
            except Exception as exc:
                await self._publish_error(active, str(exc))
                raise
            else:
                await self.publish(
                    ProgressEvent(
                        type="done",
                        elapsed_s=active.elapsed(),
                        payload={
                            "kind": kind,
                            "seed_used": active.seed_used,
                            "turn": active.turn,
                            "total_turns": total_turns,
                        },
                    ),
                )
        finally:
            try:
                if ticker is not None:
                    ticker.cancel()
                    # return_exceptions=True absorbs the ticker's own
                    # outcome (cancelled or failed) while still letting a
                    # cancellation of *this* task propagate.
                    await asyncio.gather(ticker, return_exceptions=True)
            finally:
                # Always clear the session, otherwise late subscribers would
                # keep receiving snapshots of a finished generation.
                async with self._lock:
                    if self._current_session is active:
                        self._current_session = None

    async def _publish_error(self, session: "_Session", message: str) -> None:
        """Publish the terminal `error` event for `session`."""
        await self.publish(
            ProgressEvent(
                type="error",
                elapsed_s=session.elapsed(),
                payload={"message": message},
            ),
        )


@dataclass
class _Session:
    """Mutable state of one in-flight generation."""

    bus: ProgressBus
    kind: Literal["single", "dialog"]
    total_turns: int
    started_at: float = field(default_factory=time.monotonic)
    turn: int = 0
    seed_used: int | None = None
    tick_interval_s: float = _DEFAULT_TICK_INTERVAL_S

    def elapsed(self) -> float:
        """Seconds since the session was created (monotonic clock)."""
        return time.monotonic() - self.started_at

    def set_seed(self, seed: int) -> None:
        """Record the seed to report in the final `done` event."""
        self.seed_used = seed

    async def turn_complete(self, turn_index: int) -> None:
        """Record `turn_index` as the current turn and publish `turn_complete`."""
        self.turn = turn_index
        await self.bus.publish(
            ProgressEvent(
                type="turn_complete",
                elapsed_s=self.elapsed(),
                payload={
                    "turn": turn_index,
                    "total_turns": self.total_turns,
                    "kind": self.kind,
                },
            ),
        )

    def _tick_event(self) -> ProgressEvent:
        """Build a `tick` event describing the session's current state."""
        return ProgressEvent(
            type="tick",
            elapsed_s=self.elapsed(),
            payload={
                "kind": self.kind,
                "turn": self.turn,
                "total_turns": self.total_turns,
            },
        )

    async def _tick_loop(self) -> None:
        """Publish a `tick` every `tick_interval_s` seconds until cancelled."""
        # Cancellation is deliberately not caught here: the owner of the
        # task awaits it and absorbs the resulting CancelledError.
        while True:
            await asyncio.sleep(self.tick_interval_s)
            await self.bus.publish(self._tick_event())

    def snapshot_event(self) -> ProgressEvent | None:
        """Return a `tick` event for a subscriber that joined mid-session."""
        return self._tick_event()


_BUS: ProgressBus | None = None


def get_bus() -> ProgressBus:
    """Return the process-wide `ProgressBus`, creating it on first use."""
    global _BUS
    if _BUS is None:
        _BUS = ProgressBus()
    return _BUS