# File size: 5,356 Bytes
# 422829d
"""Progress event bus for in-flight generations.

Endpoints (`/api/generate`, `/api/generate/dialog`) wrap their work in
`bus.session(...)` which emits `start` and `done`/`error` events plus a
0.5s `tick` heartbeat. Dialog mode also emits `turn_complete` between
adapter calls. Subscribers receive events via `subscribe()` (used by
the SSE endpoint).
"""
from __future__ import annotations

import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import AsyncIterator, Literal


# Closed set of event kinds that may appear in ``ProgressEvent.type``.
EventType = Literal[
    "start",
    "tick",
    "turn_complete",
    "done",
    "error",
]


@dataclass
class ProgressEvent:
    """A single progress notification carried on the bus."""

    # Event kind; one of the ``EventType`` literals.
    type: EventType
    # Elapsed seconds supplied by whoever emits the event.
    elapsed_s: float
    # Event-specific fields, flattened into the output of ``to_dict``.
    payload: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the event as one flat dict.

        The result starts with ``type`` and ``elapsed_s`` (rounded to two
        decimals), followed by the payload fields. The two envelope keys are
        reserved: a payload entry named ``type`` or ``elapsed_s`` is ignored
        rather than allowed to overwrite the real event type or timing.
        """
        envelope = {"type": self.type, "elapsed_s": round(self.elapsed_s, 2)}
        # First spread fixes the key order, last spread makes the envelope
        # values win over any colliding payload keys.
        return {**envelope, **self.payload, **envelope}


class ProgressBus:
    """Fan-out hub that hands progress events to every current subscriber.

    Each subscriber owns an unbounded ``asyncio.Queue``; ``publish`` puts the
    event on every queue registered at that moment. One session at a time is
    remembered as "current" so that a subscriber joining mid-run can be
    primed with a snapshot event.
    """

    def __init__(self) -> None:
        self._subscribers: list[asyncio.Queue[ProgressEvent]] = []
        # Guards both ``_subscribers`` and ``_current_session``.
        self._lock = asyncio.Lock()
        self._current_session: "_Session | None" = None

    async def publish(self, event: ProgressEvent) -> None:
        """Put *event* on the queue of every subscriber registered right now."""
        # Copy the list under the lock, deliver outside it, so subscribers
        # joining or leaving never wait behind the queue puts.
        async with self._lock:
            subs = list(self._subscribers)
        for q in subs:
            await q.put(event)

    @asynccontextmanager
    async def subscribe(self) -> AsyncIterator[asyncio.Queue[ProgressEvent]]:
        """Register a fresh queue for the duration of the ``async with`` block.

        Yields:
            The subscriber's queue. If a session is in flight and it provides
            a snapshot event, that event is already on the queue.
        """
        q: asyncio.Queue[ProgressEvent] = asyncio.Queue()
        async with self._lock:
            self._subscribers.append(q)
            if self._current_session is not None:
                snapshot = self._current_session.snapshot_event()
                if snapshot is not None:
                    await q.put(snapshot)
        try:
            yield q
        finally:
            # Always unregister, even when the consumer exits with an error.
            async with self._lock:
                if q in self._subscribers:
                    self._subscribers.remove(q)

    @asynccontextmanager
    async def session(
        self, kind: Literal["single", "dialog"], total_turns: int = 1,
    ) -> AsyncIterator["_Session"]:
        """Run one generation under progress reporting.

        Publishes ``start`` on entry, starts the session's tick task, and on
        exit publishes exactly one terminal event: ``done`` on success, or
        ``error`` when the body raises or is cancelled. The original
        exception (including ``asyncio.CancelledError``) is always re-raised.

        Args:
            kind: ``"single"`` or ``"dialog"``; echoed in the event payloads.
            total_turns: Number of turns reported in the event payloads.

        Yields:
            The ``_Session`` object tracking this run.
        """
        session = _Session(bus=self, kind=kind, total_turns=total_turns)
        async with self._lock:
            self._current_session = session
        await self.publish(
            ProgressEvent(
                type="start",
                elapsed_s=0.0,
                payload={"kind": kind, "total_turns": total_turns, "turn": 0},
            ),
        )
        ticker = asyncio.create_task(session._tick_loop())
        try:
            yield session
            await self.publish(
                ProgressEvent(
                    type="done",
                    elapsed_s=session.elapsed(),
                    payload={
                        "kind": kind,
                        "seed_used": session.seed_used,
                        "turn": session.turn,
                        "total_turns": total_turns,
                    },
                ),
            )
        except asyncio.CancelledError:
            # CancelledError is a BaseException on Python 3.8+, so the generic
            # handler below never sees it. Without this branch a cancelled
            # generation would end with no terminal event for subscribers.
            await self.publish(
                ProgressEvent(
                    type="error",
                    elapsed_s=session.elapsed(),
                    payload={"message": "cancelled"},
                ),
            )
            raise
        except Exception as exc:
            await self.publish(
                ProgressEvent(
                    type="error",
                    elapsed_s=session.elapsed(),
                    # Exceptions raised without arguments stringify to "";
                    # fall back to the class name so the message is never blank.
                    payload={"message": str(exc) or type(exc).__name__},
                ),
            )
            raise
        finally:
            # Stop the heartbeat and wait for the task to finish unwinding.
            ticker.cancel()
            try:
                await ticker
            except asyncio.CancelledError:
                pass
            async with self._lock:
                # Only clear the slot if a newer session has not replaced us.
                if self._current_session is session:
                    self._current_session = None


@dataclass
class _Session:
    bus: ProgressBus
    kind: Literal["single", "dialog"]
    total_turns: int
    started_at: float = field(default_factory=time.monotonic)
    turn: int = 0
    seed_used: int | None = None

    def elapsed(self) -> float:
        return time.monotonic() - self.started_at

    def set_seed(self, seed: int) -> None:
        self.seed_used = seed

    async def turn_complete(self, turn_index: int) -> None:
        self.turn = turn_index
        await self.bus.publish(
            ProgressEvent(
                type="turn_complete",
                elapsed_s=self.elapsed(),
                payload={
                    "turn": turn_index,
                    "total_turns": self.total_turns,
                    "kind": self.kind,
                },
            ),
        )

    async def _tick_loop(self) -> None:
        try:
            while True:
                await asyncio.sleep(0.5)
                await self.bus.publish(
                    ProgressEvent(
                        type="tick",
                        elapsed_s=self.elapsed(),
                        payload={
                            "kind": self.kind,
                            "turn": self.turn,
                            "total_turns": self.total_turns,
                        },
                    ),
                )
        except asyncio.CancelledError:
            pass

    def snapshot_event(self) -> ProgressEvent | None:
        return ProgressEvent(
            type="tick",
            elapsed_s=self.elapsed(),
            payload={
                "kind": self.kind,
                "turn": self.turn,
                "total_turns": self.total_turns,
            },
        )


# Module-level bus instance; stays ``None`` until ``get_bus`` is first called.
_BUS: ProgressBus | None = None


def get_bus() -> ProgressBus:
    """Return the shared ``ProgressBus``, constructing it on the first call."""
    global _BUS
    bus = _BUS
    if bus is None:
        bus = ProgressBus()
        _BUS = bus
    return bus