| """LLMEngine: orchestrates scheduler + block manager + model runner + sampler. |
| |
| Public surface: |
| |
| engine = LLMEngine(EngineConfig(...)) |
| await engine.startup() |
| rid = engine.add_request(prompt_text, SamplingParams(...)) |
| async for delta in engine.stream(rid): |
| ... |
| |
| A single background task (`_run_loop`) drives the model. Per-request output |
| goes through asyncio queues so the HTTP layer can stream incrementally. A |
| second pub/sub channel emits engine-state snapshots for the visualization UI. |
| """ |
| from __future__ import annotations |
|
|
| import asyncio |
| import itertools |
| import json |
| import time |
| import uuid |
| from collections import deque |
| from dataclasses import dataclass, field |
| from typing import AsyncIterator, Optional, TextIO |
|
|
| from .block_manager import BlockManager |
| from .config import EngineConfig, SamplingParams |
| from .model_runner import ModelRunner |
| from .request import Sequence, SequenceStatus |
| from .sampler import Sampler |
| from .scheduler import Scheduler |
|
|
|
|
@dataclass
class StreamItem:
    """One incremental output update delivered to a request's stream.

    The engine enqueues one of these per request each time a token is
    sampled for it, plus a final empty one when a request is closed
    explicitly (for example on abort).
    """

    # Identifier of the request this update belongs to.
    request_id: str
    # Text produced since the previous update (may be empty).
    new_text: str
    # Token ids sampled since the previous update (may be empty).
    new_token_ids: list[int]
    # True on the last item of a stream; consumers stop reading after it.
    finished: bool
    # Why the request ended; the engine leaves it None while still running.
    finish_reason: Optional[str] = None
    # Full decoded output so far; left empty when the producer does not fill it.
    cumulative_text: str = ""
|
|
|
|
@dataclass
class EngineEvent:
    """An engine-state notification fanned out to event subscribers."""

    # Engine step counter at the time the event was created.
    step: int
    # Creation time; the engine stamps this with time.monotonic().
    timestamp: float
    # Event kind; the engine emits "request" and "step".
    type: str
    # Event-specific data; each instance gets its own dict by default.
    payload: dict = field(default_factory=dict)
|
|
|
|
class LLMEngine:
    """Asynchronous engine that drives scheduling, model execution and sampling.

    A single background task (`_run_loop`) repeatedly asks the scheduler for a
    batch, runs the model in the default executor, samples one token for every
    sequence whose prompt has been fully computed, and pushes the result onto
    a per-request asyncio queue that `stream()` reads from. Engine events are
    fanned out to subscriber queues and, optionally, appended to a JSON-lines
    recording file.

    All methods are meant to be called from the event loop the engine was
    started on; nothing here is thread-safe.
    """

    def __init__(self, config: EngineConfig) -> None:
        """Store the configuration; heavy components are built in `startup()`."""
        self.config = config
        self.model_runner: Optional[ModelRunner] = None
        self.block_manager: Optional[BlockManager] = None
        self.scheduler: Optional[Scheduler] = None
        self.sampler: Optional[Sampler] = None

        # request_id -> queue of StreamItem consumed by stream().
        self._output_queues: dict[str, asyncio.Queue[StreamItem]] = {}
        # request_id -> Sequence, for every request that has not finished.
        self._sequences: dict[str, Sequence] = {}
        # request_id -> length value returned by the last incremental
        # detokenize call (fed back in on the next call).
        self._prev_text_len: dict[str, int] = {}
        # request_id -> torch.Generator for seeded requests. The generator has
        # to survive across steps: re-seeding it every step would make every
        # decode step draw from the same RNG state.
        self._generators: dict[str, object] = {}
        # Queues handed out by subscribe_events().
        self._event_subscribers: list[asyncio.Queue[EngineEvent]] = []

        self._stop = asyncio.Event()
        self._step_idx = 0
        self._run_task: Optional[asyncio.Task] = None
        # Set by add_request()/shutdown() to wake an idle run loop.
        self._wake = asyncio.Event()

        # Optional JSON-lines recording of emitted events.
        self._record_fh: Optional[TextIO] = None
        self._record_t0: float = 0.0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def startup(self) -> None:
        """Build the engine components and start the background run loop.

        The model runner is constructed in the default executor so the event
        loop is not blocked while it is being created. When
        `config.record_path` is set, a recording file is opened and an
        initial snapshot is written to it.
        """
        loop = asyncio.get_running_loop()
        self.model_runner = await loop.run_in_executor(None, ModelRunner, self.config)
        self.block_manager = BlockManager(
            num_blocks=self.config.num_blocks,
            block_size=self.config.block_size,
            enable_prefix_caching=self.config.enable_prefix_caching,
        )
        self.scheduler = Scheduler(self.config, self.block_manager)
        self.sampler = Sampler(self.model_runner.device)

        if self.config.record_path:
            # Line-buffered so each event hits the file as it is written;
            # explicit encoding so the file does not depend on the locale.
            self._record_fh = open(
                self.config.record_path, "w", buffering=1, encoding="utf-8"
            )
            self._record_t0 = time.monotonic()
            self._record({
                "type": "snapshot",
                "step": 0,
                "timestamp": 0.0,
                "payload": self.snapshot(),
            })

        self._run_task = asyncio.create_task(self._run_loop())

    async def shutdown(self) -> None:
        """Stop the run loop, finish pending streams and close the recording.

        Waits up to five seconds for the run loop to exit. If the run loop
        task failed, its exception propagates from here, but only after
        pending streams have been terminated and the recording file closed.
        """
        self._stop.set()
        self._wake.set()
        try:
            if self._run_task is not None:
                try:
                    await asyncio.wait_for(self._run_task, timeout=5)
                except asyncio.TimeoutError:
                    # wait_for already cancels on timeout; repeating the
                    # request is harmless and keeps the intent explicit.
                    self._run_task.cancel()
        finally:
            # Nothing will produce output any more: hand every open stream a
            # terminal item so consumers blocked in stream() can return.
            self._finish_outstanding("abort")
            fh, self._record_fh = self._record_fh, None
            if fh is not None:
                try:
                    fh.close()
                except OSError:
                    # Best effort: a failed final flush must not mask the
                    # real shutdown outcome.
                    pass

    # ------------------------------------------------------------------
    # Request API
    # ------------------------------------------------------------------

    def add_request(
        self,
        prompt: str | list[int],
        sampling_params: SamplingParams,
        request_id: Optional[str] = None,
    ) -> str:
        """Queue a new generation request and return its request id.

        Args:
            prompt: Prompt text (encoded with the model runner) or a list of
                token ids (decoded only to report the text in the event).
            sampling_params: Sampling settings for this request.
            request_id: Caller-chosen id; a random hex id is generated when
                omitted or empty.

        Raises:
            RuntimeError: If `startup()` has not completed.
            ValueError: If the prompt is empty, is not shorter than
                `config.max_model_len`, or `request_id` is already in flight.
        """
        if self.model_runner is None:
            raise RuntimeError("engine not started")
        if isinstance(prompt, str):
            token_ids = self.model_runner.encode(prompt)
            prompt_text = prompt
        else:
            token_ids = list(prompt)
            prompt_text = self.model_runner.decode(token_ids)
        if not token_ids:
            raise ValueError("empty prompt")
        if len(token_ids) >= self.config.max_model_len:
            raise ValueError(
                f"prompt length {len(token_ids)} >= max_model_len {self.config.max_model_len}"
            )
        rid = request_id or uuid.uuid4().hex
        if rid in self._sequences:
            # Registering it again would silently replace the running
            # request's queue and bookkeeping.
            raise ValueError(f"request_id {rid!r} is already in flight")
        seq = Sequence(
            prompt_token_ids=token_ids,
            sampling_params=sampling_params,
            request_id=rid,
        )
        self._sequences[rid] = seq
        self._output_queues[rid] = asyncio.Queue()
        self._prev_text_len[rid] = 0
        assert self.scheduler is not None
        self.scheduler.add(seq)
        self._emit("request", {
            "request_id": rid,
            "seq_id": seq.seq_id,
            "prompt": prompt_text,
            "prompt_len": len(token_ids),
            "max_tokens": sampling_params.max_tokens,
        })
        self._wake.set()
        return rid

    def abort(self, request_id: str) -> bool:
        """Abort an in-flight request.

        Returns:
            False when the request is unknown or the scheduler did not abort
            it; True otherwise, in which case the request's stream receives a
            terminal item with finish reason "abort".
        """
        seq = self._sequences.get(request_id)
        if seq is None:
            return False
        assert self.scheduler is not None
        ok = self.scheduler.abort(seq.seq_id)
        if ok:
            self._close_request(request_id, finish_reason="abort")
        return ok

    async def stream(self, request_id: str) -> AsyncIterator[StreamItem]:
        """Yield the request's stream items up to and including the final one.

        The request's queue is dropped from the engine as soon as the
        terminal item has been dequeued, so finished requests do not
        accumulate.

        Raises:
            KeyError: If the request id has no output queue (unknown id, or
                the stream has already been read to completion).
        """
        q = self._output_queues.get(request_id)
        if q is None:
            raise KeyError(request_id)
        while True:
            item = await q.get()
            if item.finished:
                # Clean up *before* yielding: a consumer that stops iterating
                # on the terminal item never resumes this generator. The
                # identity check protects a newer request reusing the id.
                if self._output_queues.get(request_id) is q:
                    del self._output_queues[request_id]
                yield item
                return
            yield item

    # ------------------------------------------------------------------
    # Events and recording
    # ------------------------------------------------------------------

    def subscribe_events(self) -> asyncio.Queue[EngineEvent]:
        """Register and return a bounded queue that receives engine events."""
        q: asyncio.Queue[EngineEvent] = asyncio.Queue(maxsize=self.config.event_buffer)
        self._event_subscribers.append(q)
        return q

    def unsubscribe_events(self, q: asyncio.Queue[EngineEvent]) -> None:
        """Stop delivering events to `q`; unknown queues are ignored."""
        try:
            self._event_subscribers.remove(q)
        except ValueError:
            pass

    def _emit(self, event_type: str, payload: dict) -> None:
        """Deliver an event to all subscribers and to the recording file.

        Does nothing when `config.emit_events` is false. A full subscriber
        queue drops its oldest entry to make room for the new event.
        """
        if not self.config.emit_events:
            return
        ev = EngineEvent(
            step=self._step_idx,
            timestamp=time.monotonic(),
            type=event_type,
            payload=payload,
        )
        # Iterate over a copy so (un)subscribing during delivery is safe.
        for q in list(self._event_subscribers):
            try:
                q.put_nowait(ev)
            except asyncio.QueueFull:
                # Slow consumer: discard the oldest event and retry once.
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
                try:
                    q.put_nowait(ev)
                except asyncio.QueueFull:
                    pass

        if self._record_fh is not None:
            self._record({
                "type": ev.type,
                "step": ev.step,
                # Recorded timestamps are relative to the start of recording.
                "timestamp": ev.timestamp - self._record_t0,
                "payload": ev.payload,
            })

    def _record(self, ev: dict) -> None:
        """Append one event as a compact JSON line; failures are ignored.

        Recording is best effort: a payload that cannot be serialized or a
        failing write must never take the engine down.
        """
        fh = self._record_fh
        if fh is None:
            return
        try:
            fh.write(json.dumps(ev, separators=(",", ":")) + "\n")
        except (OSError, TypeError, ValueError):
            pass

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def snapshot(self) -> dict:
        """Return a plain-dict view of the engine state.

        Contains the current step, the block manager's own snapshot, one
        entry per waiting and running sequence, and selected config values.
        """
        assert self.block_manager is not None and self.scheduler is not None

        def seq_view(s: Sequence) -> dict:
            return {
                "seq_id": s.seq_id,
                "request_id": s.request_id,
                "status": s.status.value,
                "prompt_len": s.prompt_len,
                "num_generated": len(s.output_token_ids),
                "num_computed_tokens": s.num_computed_tokens,
                "num_cached_prefix_tokens": s.num_cached_prefix_tokens,
                "block_table": list(s.block_table),
            }

        return {
            "step": self._step_idx,
            "block_pool": self.block_manager.snapshot(),
            "waiting": [seq_view(s) for s in self.scheduler.waiting],
            "running": [seq_view(s) for s in self.scheduler.running],
            "config": {
                "model": self.config.model,
                "block_size": self.config.block_size,
                "num_blocks": self.config.num_blocks,
                "max_num_seqs": self.config.max_num_seqs,
                "max_num_batched_tokens": self.config.max_num_batched_tokens,
                "prefix_caching": self.config.enable_prefix_caching,
            },
        }

    # ------------------------------------------------------------------
    # Run loop
    # ------------------------------------------------------------------

    async def _run_loop(self) -> None:
        """Background task: run engine steps until `shutdown()` is requested.

        While the scheduler has no work the loop sleeps until woken (or for
        at most one second). If a step raises, every outstanding request is
        terminated with finish reason "error" before the exception
        propagates, so stream consumers are not left waiting forever.
        """
        assert self.scheduler is not None and self.model_runner is not None
        try:
            while not self._stop.is_set():
                if not self.scheduler.has_work:
                    self._wake.clear()
                    try:
                        await asyncio.wait_for(self._wake.wait(), timeout=1.0)
                    except asyncio.TimeoutError:
                        pass
                    continue

                if await self._step():
                    # Give other tasks (HTTP handlers, aborts) a turn.
                    await asyncio.sleep(0)
                else:
                    # There is work but nothing could be scheduled; back off
                    # briefly instead of spinning.
                    await asyncio.sleep(0.01)
        except Exception:
            self._finish_outstanding("error")
            raise

    async def _step(self) -> bool:
        """Run one schedule/execute/sample/publish cycle.

        Returns:
            False when the scheduler produced an empty batch, True otherwise.
        """
        assert self.scheduler is not None and self.model_runner is not None
        assert self.block_manager is not None
        loop = asyncio.get_running_loop()

        self._step_idx += 1
        t0 = time.monotonic()
        sched = self.scheduler.schedule()
        if sched.is_empty:
            return False

        model_input = self.model_runner.prepare_input(sched.scheduled)
        # The forward pass runs in the default executor so the event loop
        # stays responsive while the model is busy.
        logits = await loop.run_in_executor(None, self.model_runner.execute, model_input)

        # abort() may have run while we were awaiting the model. A sequence
        # that is no longer registered already received its terminal item,
        # so it must not be advanced or published again. The row index is
        # kept because it is used to pick the sequence's row out of `logits`.
        live = [
            (row, item)
            for row, item in enumerate(sched.scheduled)
            if self._sequences.get(item.seq.request_id) is item.seq
        ]

        for _, item in live:
            item.seq.num_computed_tokens += item.num_tokens

        # Only sequences whose whole prompt has been computed get a token.
        to_sample = [
            (row, item)
            for row, item in live
            if item.seq.num_computed_tokens >= item.seq.prompt_len
        ]
        new_tokens = self._sample(logits, to_sample)

        finished_now: list[Sequence] = []
        for _, item in live:
            seq = item.seq
            sampled = seq.seq_id in new_tokens
            if sampled:
                seq.append_output_token(new_tokens[seq.seq_id])
            self.block_manager.register_filled_blocks(seq, prev_computed=0)
            if sampled and self._should_stop(seq, new_tokens[seq.seq_id]):
                seq.status = SequenceStatus.FINISHED
                seq.finish_reason = self._stop_reason(seq, new_tokens[seq.seq_id])
                finished_now.append(seq)

        for seq in finished_now:
            if seq in self.scheduler.running:
                self.scheduler.running.remove(seq)
            self.block_manager.free(seq)

        step_deltas: list[dict] = []
        for _, item in live:
            seq = item.seq
            # A sequence can only have finished in this step if it was
            # sampled, so "was sampled" covers both cases.
            if seq.seq_id not in new_tokens:
                continue
            rid = seq.request_id
            new_text, new_text_len = self.model_runner.detokenize_incremental(
                seq.all_token_ids(), self._prev_text_len.get(rid, 0)
            )
            self._prev_text_len[rid] = new_text_len
            is_done = seq.status == SequenceStatus.FINISHED
            si = StreamItem(
                request_id=rid,
                new_text=new_text,
                new_token_ids=[new_tokens[seq.seq_id]],
                finished=is_done,
                finish_reason=seq.finish_reason,
                cumulative_text=self.model_runner.tokenizer.decode(
                    seq.output_token_ids, skip_special_tokens=True
                ),
            )
            q = self._output_queues.get(rid)
            if q is not None:
                # Output queues are unbounded, so this never raises, and not
                # awaiting here keeps publishing free of suspension points.
                q.put_nowait(si)
            if new_text or is_done:
                step_deltas.append({
                    "request_id": rid,
                    "new_text": new_text,
                    "finished": is_done,
                    "finish_reason": seq.finish_reason,
                })
            if is_done:
                self._sequences.pop(rid, None)
                self._prev_text_len.pop(rid, None)
                self._generators.pop(rid, None)

        # Skip building the payload (snapshot() walks the whole engine state)
        # when it would be thrown away by _emit anyway.
        if self.config.emit_events:
            self._emit("step", {
                "duration_ms": (time.monotonic() - t0) * 1000,
                "num_seqs": len(sched.scheduled),
                "num_tokens": sched.total_tokens,
                "num_prefill_seqs": sum(1 for it in sched.scheduled if it.is_prefill),
                "num_decode_seqs": sum(1 for it in sched.scheduled if not it.is_prefill),
                "preempted": sched.preempted,
                "newly_admitted": sched.newly_admitted,
                "finished": [s.request_id for s in finished_now],
                "deltas": step_deltas,
                "snapshot": self.snapshot(),
            })
        return True

    def _sample(self, logits, to_sample: list) -> dict[int, int]:
        """Sample one token for each `(logits_row, scheduled_item)` pair.

        Returns:
            Mapping of seq_id to the sampled token; empty when there is
            nothing to sample.
        """
        if not to_sample:
            return {}
        import torch  # deferred so importing this module does not pull in torch

        rows = [row for row, _ in to_sample]
        items = [item for _, item in to_sample]
        picked = logits.index_select(0, torch.tensor(rows, device=logits.device))
        params = [item.seq.sampling_params for item in items]
        generators = [self._generator_for(item.seq, logits.device) for item in items]
        token_ids = self.sampler.sample(picked, params, generators)
        return {item.seq.seq_id: tok for item, tok in zip(items, token_ids)}

    def _generator_for(self, seq: Sequence, device):
        """Return the request's persistent RNG, or None for unseeded requests.

        The generator is created and seeded once per request and then reused,
        so its state advances from step to step.
        """
        seed = seq.sampling_params.seed
        if seed is None:
            return None
        gen = self._generators.get(seq.request_id)
        if gen is None:
            import torch

            gen = torch.Generator(device=device).manual_seed(seed)
            self._generators[seq.request_id] = gen
        return gen

    # ------------------------------------------------------------------
    # Stop conditions and request teardown
    # ------------------------------------------------------------------

    def _should_stop(self, seq: Sequence, last_token: int) -> bool:
        """Return True when `seq` must finish after producing `last_token`.

        A sequence stops when it has produced `max_tokens` tokens, when
        `last_token` is the EOS token or one of `stop_token_ids` (both
        skipped when `ignore_eos` is set), or when its total length has
        reached `config.max_model_len`.
        """
        sp = seq.sampling_params
        if len(seq.output_token_ids) >= sp.max_tokens:
            return True
        if not sp.ignore_eos:
            eos = self.model_runner.eos_token_id if self.model_runner else None
            if eos is not None and last_token == eos:
                return True
            if last_token in sp.stop_token_ids:
                return True
        return seq.total_len >= self.config.max_model_len

    def _stop_reason(self, seq: Sequence, last_token: int) -> str:
        """Return "length" when a length limit was hit, otherwise "stop"."""
        hit_max_tokens = len(seq.output_token_ids) >= seq.sampling_params.max_tokens
        hit_model_len = seq.total_len >= self.config.max_model_len
        return "length" if (hit_max_tokens or hit_model_len) else "stop"

    def _close_request(self, request_id: str, finish_reason: str) -> None:
        """Send a terminal item to the request's stream and drop its state."""
        q = self._output_queues.get(request_id)
        if q is not None:
            q.put_nowait(StreamItem(
                request_id=request_id,
                new_text="",
                new_token_ids=[],
                finished=True,
                finish_reason=finish_reason,
            ))
        # Bookkeeping is dropped even when there is no queue to notify.
        self._sequences.pop(request_id, None)
        self._prev_text_len.pop(request_id, None)
        self._generators.pop(request_id, None)

    def _finish_outstanding(self, finish_reason: str) -> None:
        """Close every request that is still registered with the engine."""
        for rid in list(self._sequences):
            self._close_request(rid, finish_reason=finish_reason)
|
|