File size: 17,376 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
c32c359
 
 
 
39fa862
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
 
 
 
 
 
 
 
 
 
 
c32c359
 
 
 
 
 
 
 
 
 
39fa862
 
 
 
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
c32c359
 
39fa862
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
 
 
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
 
 
 
 
 
 
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
39fa862
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""LLMEngine: orchestrates scheduler + block manager + model runner + sampler.

Public surface:

  engine = LLMEngine(EngineConfig(...))
  await engine.startup()
  rid = engine.add_request(prompt_text, SamplingParams(...))
  async for delta in engine.stream(rid):
      ...

A single background task (`_run_loop`) drives the model.  Per-request output
goes through asyncio queues so the HTTP layer can stream incrementally.  A
second pub/sub channel emits engine-state snapshots for the visualization UI.
"""
from __future__ import annotations

import asyncio
import itertools
import json
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, TextIO

from .block_manager import BlockManager
from .config import EngineConfig, SamplingParams
from .model_runner import ModelRunner
from .request import Sequence, SequenceStatus
from .sampler import Sampler
from .scheduler import Scheduler


@dataclass
class StreamItem:
    """One incremental output update for a single request.

    Instances are put on the per-request output queue and yielded by
    ``LLMEngine.stream``; the item with ``finished=True`` is the last one.
    """

    # Identifier of the request this update belongs to.
    request_id: str
    # Text produced since the previous update (empty on the abort item).
    new_text: str
    # Token ids sampled in this step; the run loop supplies at most one.
    new_token_ids: list[int]
    # True on the terminal item; ``stream`` stops iterating after it.
    finished: bool
    # Set by the engine to "length", "stop" or "abort" when finished.
    finish_reason: Optional[str] = None
    # Decode of all output tokens so far (special tokens skipped); left
    # empty on the item emitted for an aborted request.
    cumulative_text: str = ""


@dataclass
class EngineEvent:
    """An engine-state event delivered to event subscribers.

    Built by ``LLMEngine._emit`` and also mirrored to the on-disk recording
    when one is open.
    """

    # Value of the engine's step counter when the event was emitted.
    step: int
    # ``time.monotonic()`` reading taken at emission time.
    timestamp: float
    # Event kind; the engine emits "request" and "step".
    type: str
    # Event-specific data, passed through unchanged from the emitter.
    payload: dict = field(default_factory=dict)


class LLMEngine:
    def __init__(self, config: EngineConfig) -> None:
        self.config = config
        self.model_runner: Optional[ModelRunner] = None
        self.block_manager: Optional[BlockManager] = None
        self.scheduler: Optional[Scheduler] = None
        self.sampler: Optional[Sampler] = None

        # request_id → asyncio.Queue[StreamItem]
        self._output_queues: dict[str, asyncio.Queue[StreamItem]] = {}
        # request_id → Sequence (for inspection / abort)
        self._sequences: dict[str, Sequence] = {}
        # tracker for incremental detokenization
        self._prev_text_len: dict[str, int] = {}
        # event subscribers
        self._event_subscribers: list[asyncio.Queue[EngineEvent]] = []
        # control
        self._stop = asyncio.Event()
        self._step_idx = 0
        self._run_task: Optional[asyncio.Task] = None
        self._wake = asyncio.Event()
        # recording (for the static GH-Pages replay)
        self._record_fh: Optional[TextIO] = None
        self._record_t0: float = 0.0

    # ---- lifecycle ------------------------------------------------------

    async def startup(self) -> None:
        """Build the engine components and launch the background run loop.

        The model runner is constructed in the default executor so that the
        model load does not block the event loop.  If ``config.record_path``
        is set, the recording file is opened and an initial snapshot written
        before the run loop starts.
        """
        cfg = self.config
        event_loop = asyncio.get_running_loop()

        # Heavy step: run the ModelRunner constructor in a worker thread.
        self.model_runner = await event_loop.run_in_executor(
            None, lambda: ModelRunner(self.config)
        )

        self.block_manager = BlockManager(
            num_blocks=cfg.num_blocks,
            block_size=cfg.block_size,
            enable_prefix_caching=cfg.enable_prefix_caching,
        )
        self.scheduler = Scheduler(cfg, self.block_manager)
        self.sampler = Sampler(self.model_runner.device)

        # The recorder is opened only now: snapshot() needs the block manager
        # and scheduler to exist for the first record to be valid.
        record_path = cfg.record_path
        if record_path:
            self._record_fh = open(record_path, "w", buffering=1)
            self._record_t0 = time.monotonic()
            first_record = {
                "type": "snapshot",
                "step": 0,
                "timestamp": 0.0,
                "payload": self.snapshot(),
            }
            self._record(first_record)

        self._run_task = asyncio.create_task(self._run_loop())

    async def shutdown(self) -> None:
        self._stop.set()
        self._wake.set()
        if self._run_task is not None:
            try:
                await asyncio.wait_for(self._run_task, timeout=5)
            except asyncio.TimeoutError:
                self._run_task.cancel()
        if self._record_fh is not None:
            try:
                self._record_fh.close()
            except Exception:
                pass
            self._record_fh = None

    # ---- request submission --------------------------------------------

    def add_request(
        self,
        prompt: str | list[int],
        sampling_params: SamplingParams,
        request_id: Optional[str] = None,
    ) -> str:
        if self.model_runner is None:
            raise RuntimeError("engine not started")
        if isinstance(prompt, str):
            token_ids = self.model_runner.encode(prompt)
            prompt_text = prompt
        else:
            token_ids = list(prompt)
            prompt_text = self.model_runner.decode(token_ids)
        if not token_ids:
            raise ValueError("empty prompt")
        if len(token_ids) >= self.config.max_model_len:
            raise ValueError(
                f"prompt length {len(token_ids)} >= max_model_len {self.config.max_model_len}"
            )
        rid = request_id or uuid.uuid4().hex
        seq = Sequence(
            prompt_token_ids=token_ids,
            sampling_params=sampling_params,
            request_id=rid,
        )
        self._sequences[rid] = seq
        self._output_queues[rid] = asyncio.Queue()
        self._prev_text_len[rid] = 0
        assert self.scheduler is not None
        self.scheduler.add(seq)
        self._emit("request", {
            "request_id": rid,
            "seq_id": seq.seq_id,
            "prompt": prompt_text,
            "prompt_len": len(token_ids),
            "max_tokens": sampling_params.max_tokens,
        })
        self._wake.set()
        return rid

    def abort(self, request_id: str) -> bool:
        seq = self._sequences.get(request_id)
        if seq is None:
            return False
        assert self.scheduler is not None
        ok = self.scheduler.abort(seq.seq_id)
        if ok:
            self._close_request(request_id, finish_reason="abort")
        return ok

    async def stream(self, request_id: str) -> AsyncIterator[StreamItem]:
        q = self._output_queues.get(request_id)
        if q is None:
            raise KeyError(request_id)
        while True:
            item = await q.get()
            yield item
            if item.finished:
                break

    # ---- event subscriptions -------------------------------------------

    def subscribe_events(self) -> asyncio.Queue[EngineEvent]:
        q: asyncio.Queue[EngineEvent] = asyncio.Queue(maxsize=self.config.event_buffer)
        self._event_subscribers.append(q)
        return q

    def unsubscribe_events(self, q: asyncio.Queue[EngineEvent]) -> None:
        try:
            self._event_subscribers.remove(q)
        except ValueError:
            pass

    def _emit(self, event_type: str, payload: dict) -> None:
        if not self.config.emit_events:
            return
        ev = EngineEvent(
            step=self._step_idx,
            timestamp=time.monotonic(),
            type=event_type,
            payload=payload,
        )
        for q in list(self._event_subscribers):
            try:
                q.put_nowait(ev)
            except asyncio.QueueFull:
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
                try:
                    q.put_nowait(ev)
                except asyncio.QueueFull:
                    pass
        # Mirror into the on-disk recording (timestamps re-based to t0).
        if self._record_fh is not None:
            self._record({
                "type": ev.type,
                "step": ev.step,
                "timestamp": ev.timestamp - self._record_t0,
                "payload": ev.payload,
            })

    def _record(self, ev: dict) -> None:
        fh = self._record_fh
        if fh is None:
            return
        try:
            fh.write(json.dumps(ev, separators=(",", ":")) + "\n")
        except Exception:
            pass

    # ---- inspection ----------------------------------------------------

    def snapshot(self) -> dict:
        assert self.block_manager is not None and self.scheduler is not None
        def seq_view(s: Sequence) -> dict:
            return {
                "seq_id": s.seq_id,
                "request_id": s.request_id,
                "status": s.status.value,
                "prompt_len": s.prompt_len,
                "num_generated": len(s.output_token_ids),
                "num_computed_tokens": s.num_computed_tokens,
                "num_cached_prefix_tokens": s.num_cached_prefix_tokens,
                "block_table": list(s.block_table),
            }
        return {
            "step": self._step_idx,
            "block_pool": self.block_manager.snapshot(),
            "waiting": [seq_view(s) for s in self.scheduler.waiting],
            "running": [seq_view(s) for s in self.scheduler.running],
            "config": {
                "model": self.config.model,
                "block_size": self.config.block_size,
                "num_blocks": self.config.num_blocks,
                "max_num_seqs": self.config.max_num_seqs,
                "max_num_batched_tokens": self.config.max_num_batched_tokens,
                "prefix_caching": self.config.enable_prefix_caching,
            },
        }

    # ---- main loop -----------------------------------------------------

    async def _run_loop(self) -> None:
        """Background task: schedule, run the model, sample and publish.

        Runs until ``_stop`` is set (checked at the top of every iteration).
        Each iteration:

          1. waits on ``_wake`` (re-checking at least once a second) while the
             scheduler reports no work;
          2. asks the scheduler for a batch and runs the forward pass in the
             default executor so the event loop stays responsive;
          3. samples one token for every scheduled sequence whose prompt has
             been fully computed;
          4. appends the tokens, applies the stop conditions and frees the
             sequences that finished;
          5. pushes ``StreamItem``s to the per-request queues and emits a
             "step" event.
        """
        assert self.scheduler is not None and self.model_runner is not None
        assert self.block_manager is not None
        import torch  # local import, done once rather than on every step

        loop = asyncio.get_running_loop()
        # request_id -> generator for requests with an explicit seed.  The
        # generator must persist across steps: re-seeding a fresh generator
        # on every step would restart it from the same RNG state each time,
        # so every token of the request would be sampled with identical
        # random draws.
        seeded_generators: dict = {}

        while not self._stop.is_set():
            if not self.scheduler.has_work:
                self._wake.clear()
                try:
                    await asyncio.wait_for(self._wake.wait(), timeout=1.0)
                except asyncio.TimeoutError:
                    pass
                continue

            self._step_idx += 1
            t0 = time.monotonic()
            sched = self.scheduler.schedule()
            if sched.is_empty:
                # Nothing got through this step (probably starved on blocks).
                await asyncio.sleep(0.01)
                continue

            model_input = self.model_runner.prepare_input(sched.scheduled)
            # Run blocking model forward off-thread.
            logits = await loop.run_in_executor(None, self.model_runner.execute, model_input)

            # Update num_computed_tokens AFTER forward (the K/V is now stored).
            for item in sched.scheduled:
                item.seq.num_computed_tokens += item.num_tokens

            # Sample only for sequences that have finished prefill (i.e., the
            # last token in their chunk is the *final* prompt token).  Keep
            # each item's batch position so its logits row can be selected.
            sampling = [
                (idx, item) for idx, item in enumerate(sched.scheduled)
                if item.seq.num_computed_tokens >= item.seq.prompt_len
            ]

            new_tokens: dict[int, int] = {}
            if sampling:
                sampling_indices = [idx for idx, _ in sampling]
                sampling_items = [item for _, item in sampling]
                sampling_logits = logits.index_select(
                    0, torch.tensor(sampling_indices, device=logits.device)
                )
                params = [item.seq.sampling_params for item in sampling_items]
                generators = []
                for item in sampling_items:
                    seed = item.seq.sampling_params.seed
                    if seed is None:
                        generators.append(None)
                        continue
                    rid = item.seq.request_id
                    gen = seeded_generators.get(rid)
                    if gen is None:
                        # First sampled token of this request: seed once.
                        gen = torch.Generator(device=logits.device).manual_seed(seed)
                        seeded_generators[rid] = gen
                    generators.append(gen)
                token_ids = self.sampler.sample(sampling_logits, params, generators)
                for item, tok in zip(sampling_items, token_ids):
                    new_tokens[item.seq.seq_id] = tok

            # Apply new tokens, check stopping, register filled blocks.
            finished_now: list[Sequence] = []
            for item in sched.scheduled:
                seq = item.seq
                if seq.seq_id in new_tokens:
                    tok = new_tokens[seq.seq_id]
                    seq.append_output_token(tok)
                    # The just-produced token's KV is written on the NEXT step
                    # (when the token is fed as input), so blocks are
                    # registered from the just-finalized num_computed_tokens.
                    # NOTE(review): prev_computed is always 0 here - confirm
                    # against BlockManager that this is intended.
                    self.block_manager.register_filled_blocks(seq, prev_computed=0)

                    if self._should_stop(seq, tok):
                        seq.status = SequenceStatus.FINISHED
                        seq.finish_reason = self._stop_reason(seq, tok)
                        finished_now.append(seq)
                else:
                    # Still in prefill; just register newly filled prompt blocks.
                    self.block_manager.register_filled_blocks(seq, prev_computed=0)

            # Free finished sequences.
            for seq in finished_now:
                if seq in self.scheduler.running:
                    self.scheduler.running.remove(seq)
                self.block_manager.free(seq)

            # Emit outputs to per-request queues, and collect per-step deltas
            # for the event stream (powers the replay UI's text panes).
            step_deltas: list[dict] = []
            for item in sched.scheduled:
                seq = item.seq
                rid = seq.request_id
                # Only sequences that sampled a token have output; finished
                # sequences are always among them.
                if seq.seq_id not in new_tokens:
                    continue
                # The request was closed (aborted) while the forward pass was
                # in flight: its terminal item is already queued, so emit
                # nothing more and do not resurrect its bookkeeping.
                if rid not in self._sequences:
                    continue
                new_text, new_text_len = self.model_runner.detokenize_incremental(
                    seq.all_token_ids(), self._prev_text_len.get(rid, 0)
                )
                self._prev_text_len[rid] = new_text_len
                is_done = seq.status == SequenceStatus.FINISHED
                si = StreamItem(
                    request_id=rid,
                    new_text=new_text,
                    new_token_ids=[new_tokens[seq.seq_id]],
                    finished=is_done,
                    finish_reason=seq.finish_reason,
                    cumulative_text=self.model_runner.tokenizer.decode(
                        seq.output_token_ids, skip_special_tokens=True
                    ),
                )
                q = self._output_queues.get(rid)
                if q is not None:
                    await q.put(si)
                if new_text or is_done:
                    step_deltas.append({
                        "request_id": rid,
                        "new_text": new_text,
                        "finished": is_done,
                        "finish_reason": seq.finish_reason,
                    })
                if is_done:
                    self._sequences.pop(rid, None)
                    self._prev_text_len.pop(rid, None)

            # Drop generators of requests that are gone (finished or aborted).
            if seeded_generators:
                for rid in [r for r in seeded_generators if r not in self._sequences]:
                    del seeded_generators[rid]

            # Emit engine events for the UI.  Skip building the payload (which
            # includes a full snapshot) when events are disabled, since _emit
            # would discard it anyway.
            if self.config.emit_events:
                self._emit("step", {
                    "duration_ms": (time.monotonic() - t0) * 1000,
                    "num_seqs": len(sched.scheduled),
                    "num_tokens": sched.total_tokens,
                    "num_prefill_seqs": sum(1 for it in sched.scheduled if it.is_prefill),
                    "num_decode_seqs": sum(1 for it in sched.scheduled if not it.is_prefill),
                    "preempted": sched.preempted,
                    "newly_admitted": sched.newly_admitted,
                    "finished": [s.request_id for s in finished_now],
                    "deltas": step_deltas,
                    "snapshot": self.snapshot(),
                })

            # Yield control between steps so the HTTP layer can ship bytes.
            await asyncio.sleep(0)

    # ---- helpers -------------------------------------------------------

    def _should_stop(self, seq: Sequence, last_token: int) -> bool:
        sp = seq.sampling_params
        if len(seq.output_token_ids) >= sp.max_tokens:
            return True
        if not sp.ignore_eos:
            eos = self.model_runner.eos_token_id if self.model_runner else None
            if eos is not None and last_token == eos:
                return True
        if last_token in sp.stop_token_ids:
            return True
        if seq.total_len >= self.config.max_model_len:
            return True
        return False

    def _stop_reason(self, seq: Sequence, last_token: int) -> str:
        sp = seq.sampling_params
        if len(seq.output_token_ids) >= sp.max_tokens:
            return "length"
        if seq.total_len >= self.config.max_model_len:
            return "length"
        return "stop"

    def _close_request(self, request_id: str, finish_reason: str) -> None:
        q = self._output_queues.get(request_id)
        if q is None:
            return
        q.put_nowait(StreamItem(
            request_id=request_id,
            new_text="",
            new_token_ids=[],
            finished=True,
            finish_reason=finish_reason,
        ))
        self._sequences.pop(request_id, None)
        self._prev_text_len.pop(request_id, None)