File size: 5,356 Bytes
"""Progress event bus for in-flight generations.

Endpoints (`/api/generate`, `/api/generate/dialog`) wrap their work in
`bus.session(...)` which emits `start` and `done`/`error` events plus a
0.5s `tick` heartbeat. Dialog mode also emits `turn_complete` between
adapter calls. Subscribers receive events via `subscribe()` (used by
the SSE endpoint).
"""
from __future__ import annotations
import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import AsyncIterator, Literal
# Closed set of event kinds that may travel over the progress bus.
EventType = Literal["start", "tick", "turn_complete", "done", "error"]


@dataclass
class ProgressEvent:
    """A single progress notification carried by the bus.

    The event is an envelope (``type`` and ``elapsed_s``) plus a free-form
    ``payload`` of event-specific fields.
    """

    # Kind of event; one of the literals listed in ``EventType``.
    type: EventType
    # Elapsed time in seconds; ``to_dict`` rounds it to two decimals.
    elapsed_s: float
    # Event-specific extras that ``to_dict`` flattens next to the envelope.
    payload: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the event flattened into one dict.

        ``type`` and ``elapsed_s`` (rounded to two decimals) come first,
        followed by the payload entries in their original order.

        Payload keys named ``type`` or ``elapsed_s`` are ignored: the
        envelope always wins, so extra data can never silently rewrite the
        event kind or its timestamp.

        Returns:
            A new dict; mutating it does not affect the event.
        """
        flat: dict = {"type": self.type, "elapsed_s": round(self.elapsed_s, 2)}
        for key, value in self.payload.items():
            # setdefault keeps the envelope value when a payload key collides.
            flat.setdefault(key, value)
        return flat
class ProgressBus:
    """Fan-out hub that relays progress events to every attached queue.

    The bus remembers at most one session at a time: starting a session
    stores it as the current one, and finishing clears that reference only
    if it still points at the finishing session.
    """

    def __init__(self) -> None:
        """Create a bus with no subscribers and no current session."""
        self._subscribers: list[asyncio.Queue[ProgressEvent]] = []
        # Guards both the subscriber list and the current-session reference.
        self._lock = asyncio.Lock()
        self._current_session: "_Session | None" = None

    async def publish(self, event: ProgressEvent) -> None:
        """Put *event* on every queue that is subscribed when called."""
        # Copy the list under the lock and deliver outside it, so delivery
        # never runs while the lock is held.
        async with self._lock:
            recipients = tuple(self._subscribers)
        for inbox in recipients:
            await inbox.put(event)

    @asynccontextmanager
    async def subscribe(self) -> AsyncIterator[asyncio.Queue[ProgressEvent]]:
        """Attach a fresh queue to the bus for the duration of the block.

        If a session is currently recorded, the queue is pre-loaded with
        that session's snapshot event. The queue is detached again when
        the ``async with`` block exits, normally or not.

        Yields:
            The queue on which published events will arrive.
        """
        inbox: asyncio.Queue[ProgressEvent] = asyncio.Queue()
        async with self._lock:
            self._subscribers.append(inbox)
            active = self._current_session
            snapshot = None if active is None else active.snapshot_event()
            if snapshot is not None:
                await inbox.put(snapshot)
        try:
            yield inbox
        finally:
            async with self._lock:
                if inbox in self._subscribers:
                    self._subscribers.remove(inbox)

    @asynccontextmanager
    async def session(
        self, kind: Literal["single", "dialog"], total_turns: int = 1,
    ) -> AsyncIterator["_Session"]:
        """Run one generation under progress reporting.

        Publishes ``start`` on entry and keeps a background tick task
        running while the body executes. Publishes ``done`` when the body
        returns, or ``error`` (then re-raises) when an ``Exception``
        escapes it. The tick task is always cancelled and awaited on exit.

        Args:
            kind: Generation mode reported in the emitted events.
            total_turns: Number of turns reported in the emitted events.

        Yields:
            The session object for the caller to report seed and turns on.
        """
        current = _Session(bus=self, kind=kind, total_turns=total_turns)
        async with self._lock:
            self._current_session = current

        start_fields = {"kind": kind, "total_turns": total_turns, "turn": 0}
        await self.publish(
            ProgressEvent(type="start", elapsed_s=0.0, payload=start_fields),
        )

        heartbeat = asyncio.create_task(current._tick_loop())
        try:
            yield current
            # The success event is published inside the try, so a failure
            # while publishing it is reported through the error path too.
            finished_at = current.elapsed()
            done_fields = {
                "kind": kind,
                "seed_used": current.seed_used,
                "turn": current.turn,
                "total_turns": total_turns,
            }
            await self.publish(
                ProgressEvent(type="done", elapsed_s=finished_at, payload=done_fields),
            )
        except Exception as err:
            failed_at = current.elapsed()
            await self.publish(
                ProgressEvent(
                    type="error",
                    elapsed_s=failed_at,
                    payload={"message": str(err)},
                ),
            )
            raise
        finally:
            # Stop the heartbeat and wait until it has actually finished.
            heartbeat.cancel()
            try:
                await heartbeat
            except asyncio.CancelledError:
                pass
            async with self._lock:
                # Do not clear a newer session that replaced this one.
                if self._current_session is current:
                    self._current_session = None
@dataclass
class _Session:
    """Mutable state of one in-flight generation plus the events it emits."""

    # Bus on which this session publishes its events.
    bus: ProgressBus
    # Generation mode, echoed in event payloads.
    kind: Literal["single", "dialog"]
    # Turn count, echoed in event payloads.
    total_turns: int
    # Monotonic clock reading taken when the session object is created.
    started_at: float = field(default_factory=time.monotonic)
    # Index passed to the most recent ``turn_complete`` call (0 initially).
    turn: int = 0
    # Seed recorded through ``set_seed``; None until one is recorded.
    seed_used: int | None = None

    def elapsed(self) -> float:
        """Return the seconds elapsed since ``started_at``."""
        now = time.monotonic()
        return now - self.started_at

    def set_seed(self, seed: int) -> None:
        """Record *seed* as the seed used by this session."""
        self.seed_used = seed

    async def turn_complete(self, turn_index: int) -> None:
        """Store *turn_index* as the current turn and publish ``turn_complete``."""
        self.turn = turn_index
        fields = {
            "turn": turn_index,
            "total_turns": self.total_turns,
            "kind": self.kind,
        }
        event = ProgressEvent(
            type="turn_complete", elapsed_s=self.elapsed(), payload=fields,
        )
        await self.bus.publish(event)

    def _tick_event(self) -> ProgressEvent:
        """Build a ``tick`` event describing the session's current state."""
        return ProgressEvent(
            type="tick",
            elapsed_s=self.elapsed(),
            payload={
                "kind": self.kind,
                "turn": self.turn,
                "total_turns": self.total_turns,
            },
        )

    async def _tick_loop(self) -> None:
        """Publish a ``tick`` every 0.5 seconds until the task is cancelled."""
        try:
            while True:
                await asyncio.sleep(0.5)
                await self.bus.publish(self._tick_event())
        except asyncio.CancelledError:
            # Cancellation is how this loop is stopped; finish quietly.
            return

    def snapshot_event(self) -> ProgressEvent | None:
        """Return a ``tick`` event reflecting the session's state right now."""
        return self._tick_event()
# Module-wide bus instance; created on the first get_bus() call.
_BUS: ProgressBus | None = None


def get_bus() -> ProgressBus:
    """Return the shared ``ProgressBus``, creating it on first use."""
    global _BUS
    bus = _BUS
    if bus is None:
        bus = _BUS = ProgressBus()
    return bus
|