tiny-vllm / tiny_vllm /engine.py
enCoder's picture
Add GitHub Pages demo and recording functionality
39fa862
"""LLMEngine: orchestrates scheduler + block manager + model runner + sampler.
Public surface:
    engine = LLMEngine(EngineConfig(...))
    await engine.startup()
    rid = engine.add_request(prompt_text, SamplingParams(...))
    async for delta in engine.stream(rid):
        ...
A single background task (`_run_loop`) drives the model. Per-request output
goes through asyncio queues so the HTTP layer can stream incrementally. A
second pub/sub channel emits engine-state snapshots for the visualization UI.
"""
from __future__ import annotations
import asyncio
import itertools
import json
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, TextIO
from .block_manager import BlockManager
from .config import EngineConfig, SamplingParams
from .model_runner import ModelRunner
from .request import Sequence, SequenceStatus
from .sampler import Sampler
from .scheduler import Scheduler
@dataclass
class StreamItem:
    """One incremental chunk of output for a single request.

    The engine produces one of these for every step in which a request emits
    text or finishes; callers receive them through ``LLMEngine.stream``.
    """

    # Request this chunk belongs to.
    request_id: str
    # Text decoded since the previous chunk for this request.
    new_text: str
    # Token ids sampled in this step (empty on e.g. an abort notification).
    new_token_ids: list[int]
    # True on the last chunk the request will ever produce.
    finished: bool
    # Why the request ended ("length", "stop", "abort"); None while running.
    finish_reason: Optional[str] = field(default=None)
    # Decoded text of all generated tokens so far ("" when not filled in).
    cumulative_text: str = field(default="")
@dataclass
class EngineEvent:
    """A timestamped engine-state notification delivered to event subscribers.

    ``type`` names the event (the engine emits e.g. "request" and "step");
    ``payload`` carries the event-specific data, which the engine also tries
    to mirror into its JSON recording.
    """

    # Engine step counter at the moment of emission.
    step: int
    # time.monotonic() reading at emission (not wall-clock time).
    timestamp: float
    # Event kind.
    type: str
    # Event-specific data; every instance gets its own fresh dict.
    payload: dict = field(default_factory=dict)
class LLMEngine:
    """Async engine that drives one model through a scheduling loop.

    ``startup`` builds the model runner, block manager, scheduler and sampler.
    Requests are submitted with ``add_request``, and their incremental output
    is delivered as ``StreamItem`` objects through ``stream``. Engine-state
    events go to ``subscribe_events`` queues and, when ``record_path`` is
    configured, to a JSONL recording.
    """

    def __init__(self, config: EngineConfig) -> None:
        self.config = config
        # Heavy components; populated by startup().
        self.model_runner: Optional[ModelRunner] = None
        self.block_manager: Optional[BlockManager] = None
        self.scheduler: Optional[Scheduler] = None
        self.sampler: Optional[Sampler] = None
        # request_id -> asyncio.Queue[StreamItem]. stream() removes the entry
        # once it has delivered the final item (or its consumer went away).
        self._output_queues: dict[str, asyncio.Queue[StreamItem]] = {}
        # request_id -> live Sequence (for inspection / abort). The entry is
        # removed as soon as the request finishes or is aborted. The run loop
        # relies on that to recognise sequences aborted mid-forward.
        self._sequences: dict[str, Sequence] = {}
        # request_id -> length of text already emitted (incremental detokenization)
        self._prev_text_len: dict[str, int] = {}
        # request_id -> seeded torch.Generator, kept for the request's lifetime.
        # Successive sampling steps then advance one random stream instead of
        # re-seeding (and redrawing the same numbers) every step.
        self._generators: dict[str, object] = {}
        # event subscribers
        self._event_subscribers: list[asyncio.Queue[EngineEvent]] = []
        # control
        self._stop = asyncio.Event()
        self._step_idx = 0
        self._run_task: Optional[asyncio.Task] = None
        self._wake = asyncio.Event()
        # recording (for the static GH-Pages replay)
        self._record_fh: Optional[TextIO] = None
        self._record_t0: float = 0.0

    # ---- lifecycle ------------------------------------------------------

    async def startup(self) -> None:
        """Load the model, build the engine components and start the run loop.

        The model is constructed in the default executor so the event loop is
        not blocked while it loads.
        """
        loop = asyncio.get_running_loop()
        self.model_runner = await loop.run_in_executor(None, ModelRunner, self.config)
        self.block_manager = BlockManager(
            num_blocks=self.config.num_blocks,
            block_size=self.config.block_size,
            enable_prefix_caching=self.config.enable_prefix_caching,
        )
        self.scheduler = Scheduler(self.config, self.block_manager)
        self.sampler = Sampler(self.model_runner.device)
        # Open the recorder *after* the block manager exists so the initial
        # snapshot we write is valid. With buffering=1 (line buffering), each
        # event line reaches the file as soon as it is written.
        if self.config.record_path:
            self._record_fh = open(
                self.config.record_path, "w", buffering=1, encoding="utf-8"
            )
            self._record_t0 = time.monotonic()
            self._record({
                "type": "snapshot",
                "step": 0,
                "timestamp": 0.0,
                "payload": self.snapshot(),
            })
        self._run_task = asyncio.create_task(self._run_loop())

    async def shutdown(self) -> None:
        """Stop the run loop, terminate outstanding streams and close the recording.

        Cleanup also runs when the run loop has already died with an exception.
        That exception still propagates to the caller afterwards.
        """
        self._stop.set()
        self._wake.set()
        try:
            if self._run_task is not None:
                try:
                    await asyncio.wait_for(self._run_task, timeout=5)
                except asyncio.TimeoutError:
                    self._run_task.cancel()
        finally:
            # No request still in flight can be served any more. End each
            # stream so its consumer does not wait forever.
            for rid in list(self._sequences):
                self._close_request(rid, finish_reason="abort")
            if self._record_fh is not None:
                try:
                    self._record_fh.close()
                except OSError:
                    pass
                self._record_fh = None

    # ---- request submission --------------------------------------------

    def add_request(
        self,
        prompt: str | list[int],
        sampling_params: SamplingParams,
        request_id: Optional[str] = None,
    ) -> str:
        """Tokenize (if needed), validate and enqueue a request; return its id.

        Raises:
            RuntimeError: the engine has not been started.
            ValueError: the request id is already in use, the prompt is empty,
                or the prompt is not shorter than ``max_model_len``.
        """
        if self.model_runner is None:
            raise RuntimeError("engine not started")
        rid = request_id or uuid.uuid4().hex
        # Reusing a live id would overwrite the other request's queue and
        # sequence and break its stream.
        if rid in self._sequences or rid in self._output_queues:
            raise ValueError(f"duplicate request_id {rid!r}")
        if isinstance(prompt, str):
            token_ids = self.model_runner.encode(prompt)
            prompt_text = prompt
        else:
            token_ids = list(prompt)
            prompt_text = self.model_runner.decode(token_ids)
        if not token_ids:
            raise ValueError("empty prompt")
        if len(token_ids) >= self.config.max_model_len:
            raise ValueError(
                f"prompt length {len(token_ids)} >= max_model_len {self.config.max_model_len}"
            )
        seq = Sequence(
            prompt_token_ids=token_ids,
            sampling_params=sampling_params,
            request_id=rid,
        )
        self._sequences[rid] = seq
        self._output_queues[rid] = asyncio.Queue()
        self._prev_text_len[rid] = 0
        assert self.scheduler is not None
        self.scheduler.add(seq)
        self._emit("request", {
            "request_id": rid,
            "seq_id": seq.seq_id,
            "prompt": prompt_text,
            "prompt_len": len(token_ids),
            "max_tokens": sampling_params.max_tokens,
        })
        self._wake.set()
        return rid

    def abort(self, request_id: str) -> bool:
        """Abort a live request.

        Returns the scheduler's verdict. On success, the request's stream
        receives a final item with ``finish_reason="abort"``.
        """
        seq = self._sequences.get(request_id)
        if seq is None:
            return False
        assert self.scheduler is not None
        ok = self.scheduler.abort(seq.seq_id)
        if ok:
            self._close_request(request_id, finish_reason="abort")
        return ok

    async def stream(self, request_id: str) -> AsyncIterator[StreamItem]:
        """Yield the request's StreamItems until (and including) the finished one.

        The request's output queue is released when the stream ends, including
        when the consumer stops iterating early. A request can therefore be
        streamed only once.

        Raises:
            KeyError: unknown request id, or its stream was already consumed.
        """
        q = self._output_queues.get(request_id)
        if q is None:
            raise KeyError(request_id)
        try:
            while True:
                item = await q.get()
                yield item
                if item.finished:
                    break
        finally:
            # Only remove our own queue, never one registered later under the same id.
            if self._output_queues.get(request_id) is q:
                del self._output_queues[request_id]

    # ---- event subscriptions -------------------------------------------

    def subscribe_events(self) -> asyncio.Queue[EngineEvent]:
        """Register and return a bounded queue that receives engine events."""
        q: asyncio.Queue[EngineEvent] = asyncio.Queue(maxsize=self.config.event_buffer)
        self._event_subscribers.append(q)
        return q

    def unsubscribe_events(self, q: asyncio.Queue[EngineEvent]) -> None:
        """Remove a subscriber queue; unknown queues are ignored."""
        try:
            self._event_subscribers.remove(q)
        except ValueError:
            pass

    def _emit(self, event_type: str, payload: dict) -> None:
        """Publish an event to subscribers (when enabled) and to the recording.

        A subscriber whose queue is full loses its oldest event to make room.
        """
        recording = self._record_fh is not None
        # An explicitly requested recording must not be silently empty just
        # because live event fan-out is disabled.
        if not self.config.emit_events and not recording:
            return
        ev = EngineEvent(
            step=self._step_idx,
            timestamp=time.monotonic(),
            type=event_type,
            payload=payload,
        )
        if self.config.emit_events:
            for q in list(self._event_subscribers):
                try:
                    q.put_nowait(ev)
                except asyncio.QueueFull:
                    try:
                        q.get_nowait()
                    except asyncio.QueueEmpty:
                        pass
                    try:
                        q.put_nowait(ev)
                    except asyncio.QueueFull:
                        pass
        # Mirror into the on-disk recording (timestamps re-based to t0).
        if recording:
            self._record({
                "type": ev.type,
                "step": ev.step,
                "timestamp": ev.timestamp - self._record_t0,
                "payload": ev.payload,
            })

    def _record(self, ev: dict) -> None:
        """Append one compact JSON line to the recording.

        This is best effort: any serialization or write failure is ignored,
        so recording can never take down the engine loop.
        """
        fh = self._record_fh
        if fh is None:
            return
        try:
            fh.write(json.dumps(ev, separators=(",", ":")) + "\n")
        except Exception:
            pass

    # ---- inspection ----------------------------------------------------

    def snapshot(self) -> dict:
        """Return a JSON-friendly view of the block pool, the queues and the key config."""
        assert self.block_manager is not None and self.scheduler is not None

        def seq_view(s: Sequence) -> dict:
            return {
                "seq_id": s.seq_id,
                "request_id": s.request_id,
                "status": s.status.value,
                "prompt_len": s.prompt_len,
                "num_generated": len(s.output_token_ids),
                "num_computed_tokens": s.num_computed_tokens,
                "num_cached_prefix_tokens": s.num_cached_prefix_tokens,
                "block_table": list(s.block_table),
            }

        return {
            "step": self._step_idx,
            "block_pool": self.block_manager.snapshot(),
            "waiting": [seq_view(s) for s in self.scheduler.waiting],
            "running": [seq_view(s) for s in self.scheduler.running],
            "config": {
                "model": self.config.model,
                "block_size": self.config.block_size,
                "num_blocks": self.config.num_blocks,
                "max_num_seqs": self.config.max_num_seqs,
                "max_num_batched_tokens": self.config.max_num_batched_tokens,
                "prefix_caching": self.config.enable_prefix_caching,
            },
        }

    # ---- main loop -----------------------------------------------------

    async def _run_loop(self) -> None:
        """Background task: run engine steps until shutdown, idling while there is no work."""
        assert self.scheduler is not None and self.model_runner is not None
        while not self._stop.is_set():
            if not self.scheduler.has_work:
                self._wake.clear()
                try:
                    await asyncio.wait_for(self._wake.wait(), timeout=1.0)
                except asyncio.TimeoutError:
                    pass
                continue
            await self._step()

    async def _step(self) -> None:
        """Run one engine step: schedule, forward, sample, then publish outputs and events."""
        assert self.scheduler is not None and self.model_runner is not None
        self._step_idx += 1
        t0 = time.monotonic()
        sched = self.scheduler.schedule()
        if sched.is_empty:
            # Nothing got through this step (probably starved on blocks).
            await asyncio.sleep(0.01)
            return

        model_input = self.model_runner.prepare_input(sched.scheduled)
        # Run blocking model forward off-thread.
        loop = asyncio.get_running_loop()
        logits = await loop.run_in_executor(None, self.model_runner.execute, model_input)

        # abort() can run while the forward above is in flight. It has already
        # handed the request back to the scheduler and closed its stream, so
        # such sequences are dropped here. Otherwise we would append tokens to,
        # push output for, and free blocks of a dead sequence. The row index
        # into `logits` follows the order of sched.scheduled and is kept.
        live = [
            (row, item)
            for row, item in enumerate(sched.scheduled)
            if self._sequences.get(item.seq.request_id) is item.seq
        ]

        # Update num_computed_tokens AFTER forward (the K/V is now stored).
        for _, item in live:
            item.seq.num_computed_tokens += item.num_tokens

        new_tokens = self._sample(live, logits)
        finished_now = self._apply_new_tokens(live, new_tokens)
        step_deltas = await self._publish_outputs(live, new_tokens, finished_now)

        # Emit engine events for the UI.
        self._emit("step", {
            "duration_ms": (time.monotonic() - t0) * 1000,
            "num_seqs": len(sched.scheduled),
            "num_tokens": sched.total_tokens,
            "num_prefill_seqs": sum(1 for it in sched.scheduled if it.is_prefill),
            "num_decode_seqs": sum(1 for it in sched.scheduled if not it.is_prefill),
            "preempted": sched.preempted,
            "newly_admitted": sched.newly_admitted,
            "finished": [s.request_id for s in finished_now],
            "deltas": step_deltas,
            "snapshot": self.snapshot(),
        })
        # Yield control between steps so the HTTP layer can ship bytes.
        await asyncio.sleep(0)

    def _sample(self, live: list, logits) -> dict[int, int]:
        """Sample one token per live sequence whose prompt is fully computed.

        Returns a mapping from seq_id to the sampled token. Sequences still in
        (chunked) prefill are skipped, because the last token of their chunk
        is not yet the final prompt token.
        """
        ready = [
            (row, item)
            for row, item in live
            if item.seq.num_computed_tokens >= item.seq.prompt_len
        ]
        if not ready:
            return {}
        assert self.sampler is not None
        import torch  # local; cheap

        rows = torch.tensor([row for row, _ in ready], device=logits.device)
        sampling_logits = logits.index_select(0, rows)
        params = [item.seq.sampling_params for _, item in ready]
        generators = [self._generator_for(item.seq, logits.device) for _, item in ready]
        token_ids = self.sampler.sample(sampling_logits, params, generators)
        return {item.seq.seq_id: tok for (_, item), tok in zip(ready, token_ids)}

    def _generator_for(self, seq: Sequence, device):
        """Return the request's persistent seeded generator, or None for unseeded requests.

        The generator is created once, on first use. Re-seeding a fresh
        generator every step would make every step draw the same random numbers.
        """
        seed = seq.sampling_params.seed
        if seed is None:
            return None
        gen = self._generators.get(seq.request_id)
        if gen is None:
            import torch  # local; cheap

            gen = torch.Generator(device=device).manual_seed(seed)
            self._generators[seq.request_id] = gen
        return gen

    def _apply_new_tokens(self, live: list, new_tokens: dict[int, int]) -> list[Sequence]:
        """Append sampled tokens, register filled blocks, and retire finished sequences.

        Returns the sequences that finished this step. They are already
        removed from the running queue and their blocks are freed.
        """
        assert self.scheduler is not None and self.block_manager is not None
        finished_now: list[Sequence] = []
        for _, item in live:
            seq = item.seq
            sampled = seq.seq_id in new_tokens
            if sampled:
                seq.append_output_token(new_tokens[seq.seq_id])
            # Blocks are hashed only once their KV exists. Register the blocks
            # newly filled according to the just-finalized num_computed_tokens
            # (prompt blocks during prefill, generated-token blocks afterwards).
            self.block_manager.register_filled_blocks(seq, prev_computed=0)
            if sampled:
                tok = new_tokens[seq.seq_id]
                if self._should_stop(seq, tok):
                    seq.status = SequenceStatus.FINISHED
                    seq.finish_reason = self._stop_reason(seq, tok)
                    finished_now.append(seq)
        # Free finished sequences.
        for seq in finished_now:
            if seq in self.scheduler.running:
                self.scheduler.running.remove(seq)
            self.block_manager.free(seq)
        return finished_now

    async def _publish_outputs(
        self,
        live: list,
        new_tokens: dict[int, int],
        finished_now: list[Sequence],
    ) -> list[dict]:
        """Push a StreamItem for every sequence that produced output this step.

        Returns the per-step delta records (new text or completion) that feed
        the replay UI's text panes through the "step" event. Bookkeeping for
        finished requests is dropped here.
        """
        assert self.model_runner is not None
        finished_ids = {seq.seq_id for seq in finished_now}
        step_deltas: list[dict] = []
        for _, item in live:
            seq = item.seq
            rid = seq.request_id
            sampled = seq.seq_id in new_tokens
            if not (sampled or seq.seq_id in finished_ids):
                continue
            new_text, new_text_len = self.model_runner.detokenize_incremental(
                seq.all_token_ids(), self._prev_text_len.get(rid, 0)
            )
            self._prev_text_len[rid] = new_text_len
            is_done = seq.status == SequenceStatus.FINISHED
            q = self._output_queues.get(rid)
            if q is not None:
                await q.put(StreamItem(
                    request_id=rid,
                    new_text=new_text,
                    new_token_ids=[new_tokens[seq.seq_id]] if sampled else [],
                    finished=is_done,
                    finish_reason=seq.finish_reason,
                    cumulative_text=self.model_runner.tokenizer.decode(
                        seq.output_token_ids, skip_special_tokens=True
                    ),
                ))
            if new_text or is_done:
                step_deltas.append({
                    "request_id": rid,
                    "new_text": new_text,
                    "finished": is_done,
                    "finish_reason": seq.finish_reason,
                })
            if is_done:
                self._forget(rid)
        return step_deltas

    # ---- helpers -------------------------------------------------------

    def _should_stop(self, seq: Sequence, last_token: int) -> bool:
        """Decide whether ``seq`` must stop after emitting ``last_token``.

        Stops on max_tokens, on EOS (unless ignore_eos is set), on a stop
        token, or on reaching max_model_len.
        """
        sp = seq.sampling_params
        if len(seq.output_token_ids) >= sp.max_tokens:
            return True
        if not sp.ignore_eos:
            eos = self.model_runner.eos_token_id if self.model_runner else None
            if eos is not None and last_token == eos:
                return True
        if last_token in sp.stop_token_ids:
            return True
        if seq.total_len >= self.config.max_model_len:
            return True
        return False

    def _stop_reason(self, seq: Sequence, last_token: int) -> str:
        """Map a stop decision to a finish reason: "length" for length limits, else "stop"."""
        sp = seq.sampling_params
        if len(seq.output_token_ids) >= sp.max_tokens:
            return "length"
        if seq.total_len >= self.config.max_model_len:
            return "length"
        return "stop"

    def _forget(self, request_id: str) -> None:
        """Drop the engine-side bookkeeping for a request (stream() owns the output queue)."""
        self._sequences.pop(request_id, None)
        self._prev_text_len.pop(request_id, None)
        self._generators.pop(request_id, None)

    def _close_request(self, request_id: str, finish_reason: str) -> None:
        """Send a terminal StreamItem (if a queue still exists) and forget the request.

        The bookkeeping is dropped even when the consumer's queue is already
        gone, so an abandoned stream does not leak engine state.
        """
        q = self._output_queues.get(request_id)
        if q is not None:
            q.put_nowait(StreamItem(
                request_id=request_id,
                new_text="",
                new_token_ids=[],
                finished=True,
                finish_reason=finish_reason,
            ))
        self._forget(request_id)