techfreakworm's picture
feat(progress): ProgressBus with sessions, ticks, and turn-complete events
422829d unverified
"""Progress event bus for in-flight generations.
Endpoints (`/api/generate`, `/api/generate/dialog`) wrap their work in
`bus.session(...)` which emits `start` and `done`/`error` events plus a
0.5s `tick` heartbeat. Dialog mode also emits `turn_complete` between
adapter calls. Subscribers receive events via `subscribe()` (used by
the SSE endpoint).
"""
from __future__ import annotations
import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import AsyncIterator, Literal
# Closed set of event names accepted by ``ProgressEvent.type``.
EventType = Literal["start", "tick", "turn_complete", "done", "error"]
@dataclass
class ProgressEvent:
    """A single progress notification delivered to subscribers."""

    # Event name; one of the literals listed in ``EventType``.
    type: EventType
    # Seconds elapsed when the event was created (rounded only on export).
    elapsed_s: float
    # Extra event-specific keys, flattened into the dict form by ``to_dict``.
    payload: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the event as one flat dict.

        ``type`` and ``elapsed_s`` (rounded to 2 decimals) come first, then
        the payload entries are merged in; a payload key with the same name
        as one of those two fields replaces its value.
        """
        flat = {"type": self.type, "elapsed_s": round(self.elapsed_s, 2)}
        flat.update(self.payload)
        return flat
class ProgressBus:
    """Fan-out bus that delivers progress events to every active subscriber.

    Each subscriber owns its own unbounded ``asyncio.Queue``. The bus also
    remembers the most recently opened session so that a subscriber joining
    mid-generation can be handed an immediate snapshot event.
    """

    def __init__(self) -> None:
        self._subscribers: list[asyncio.Queue[ProgressEvent]] = []
        # Guards ``_subscribers`` and ``_current_session``.
        self._lock = asyncio.Lock()
        self._current_session: "_Session | None" = None

    async def publish(self, event: ProgressEvent) -> None:
        """Put *event* on the queue of every currently registered subscriber."""
        # Copy under the lock, deliver outside it, so delivery never holds up
        # subscribers that are registering or unregistering.
        async with self._lock:
            subs = list(self._subscribers)
        for q in subs:
            await q.put(event)

    @asynccontextmanager
    async def subscribe(self) -> AsyncIterator[asyncio.Queue[ProgressEvent]]:
        """Register a new subscriber queue for the duration of the context.

        Yields:
            The subscriber's queue. If a session is in flight at registration
            time, the queue already holds that session's snapshot event.
        """
        q: asyncio.Queue[ProgressEvent] = asyncio.Queue()
        async with self._lock:
            self._subscribers.append(q)
            if self._current_session is not None:
                snapshot = self._current_session.snapshot_event()
                if snapshot is not None:
                    await q.put(snapshot)
        try:
            yield q
        finally:
            async with self._lock:
                if q in self._subscribers:
                    self._subscribers.remove(q)

    @asynccontextmanager
    async def session(
        self, kind: Literal["single", "dialog"], total_turns: int = 1,
    ) -> AsyncIterator["_Session"]:
        """Run one generation under progress reporting.

        Publishes ``start`` on entry, starts the session's tick task, and on
        exit publishes exactly one terminal event: ``done`` when the body
        finished normally, ``error`` when it raised or was cancelled. The
        original exception (including cancellation) is always re-raised.

        Args:
            kind: Generation mode reported in the events.
            total_turns: Number of turns reported in the events.

        Yields:
            The session object for the body to update (seed, completed turns).
        """
        session = _Session(bus=self, kind=kind, total_turns=total_turns)
        async with self._lock:
            self._current_session = session
        await self.publish(
            ProgressEvent(
                type="start",
                elapsed_s=0.0,
                payload={"kind": kind, "total_turns": total_turns, "turn": 0},
            ),
        )
        ticker = asyncio.create_task(session._tick_loop())
        try:
            yield session
        except asyncio.CancelledError:
            # CancelledError is a BaseException on Python 3.8+, so the generic
            # handler below never sees it. Without this branch a cancelled
            # generation would leave subscribers with no terminal event.
            await self.publish(
                ProgressEvent(
                    type="error",
                    elapsed_s=session.elapsed(),
                    payload={"message": "cancelled"},
                ),
            )
            raise
        except Exception as exc:
            await self.publish(
                ProgressEvent(
                    type="error",
                    elapsed_s=session.elapsed(),
                    # Exceptions raised without arguments stringify to "";
                    # fall back to the class name so the message is never blank.
                    payload={"message": str(exc) or type(exc).__name__},
                ),
            )
            raise
        else:
            # Success path kept out of the ``try`` body so that a failure while
            # publishing ``done`` is not additionally reported as ``error``.
            await self.publish(
                ProgressEvent(
                    type="done",
                    elapsed_s=session.elapsed(),
                    payload={
                        "kind": kind,
                        "seed_used": session.seed_used,
                        "turn": session.turn,
                        "total_turns": total_turns,
                    },
                ),
            )
        finally:
            ticker.cancel()
            try:
                await ticker
            except asyncio.CancelledError:
                pass
            finally:
                # Always drop the session reference, even if awaiting the tick
                # task raised, so later subscribers do not get stale snapshots.
                async with self._lock:
                    if self._current_session is session:
                        self._current_session = None
@dataclass
class _Session:
bus: ProgressBus
kind: Literal["single", "dialog"]
total_turns: int
started_at: float = field(default_factory=time.monotonic)
turn: int = 0
seed_used: int | None = None
def elapsed(self) -> float:
return time.monotonic() - self.started_at
def set_seed(self, seed: int) -> None:
self.seed_used = seed
async def turn_complete(self, turn_index: int) -> None:
self.turn = turn_index
await self.bus.publish(
ProgressEvent(
type="turn_complete",
elapsed_s=self.elapsed(),
payload={
"turn": turn_index,
"total_turns": self.total_turns,
"kind": self.kind,
},
),
)
async def _tick_loop(self) -> None:
try:
while True:
await asyncio.sleep(0.5)
await self.bus.publish(
ProgressEvent(
type="tick",
elapsed_s=self.elapsed(),
payload={
"kind": self.kind,
"turn": self.turn,
"total_turns": self.total_turns,
},
),
)
except asyncio.CancelledError:
pass
def snapshot_event(self) -> ProgressEvent | None:
return ProgressEvent(
type="tick",
elapsed_s=self.elapsed(),
payload={
"kind": self.kind,
"turn": self.turn,
"total_turns": self.total_turns,
},
)
# Process-wide bus instance, created on the first ``get_bus()`` call.
_BUS: ProgressBus | None = None


def get_bus() -> ProgressBus:
    """Return the shared ``ProgressBus``, creating it on first use."""
    global _BUS
    bus = _BUS
    if bus is None:
        bus = _BUS = ProgressBus()
    return bus