# tiny-vllm: tiny_vllm/scheduler.py
# Minimal continuous-batching LLM engine (commit c32c359).
"""Continuous-batching scheduler with chunked prefill.
A scheduling step produces a SchedulerOutput listing which sequences run and
how many tokens each one advances. Two phases each step:
1. Decodes. Every RUNNING sequence wants exactly one new token; we must
ensure each has space for it. If a sequence needs a new block and the
pool is dry, we *preempt* the most recently admitted running sequence —
free its KV blocks and push it back to the front of the waiting queue
so it restarts prefill later (recompute-style preemption, as in vLLM).
2. Prefill chunks. With remaining token budget, pull from `waiting`. A
newly waiting sequence is admitted (prompt blocks allocated via the
block manager, with prefix-cache hits taken). Then we plan a chunk of
up to `min(remaining_prefill, budget)` tokens. Chunked prefill lets a
long prompt share the budget with concurrent decodes instead of
stalling them.
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from .block_manager import BlockManager
from .config import EngineConfig
from .request import Sequence, SequenceStatus
@dataclass
class ScheduledSeq:
    """A single sequence's slice of work for one scheduling step."""

    # The sequence being advanced.
    seq: Sequence
    # Tokens forwarded for this sequence in this step (1 for a decode, the
    # chunk size for a prefill chunk).
    num_tokens: int
    # True for a prompt (prefill) chunk, False for a decode token.
    is_prefill: bool
@dataclass
class SchedulerOutput:
    """The plan produced by one scheduling step."""

    # Work items (decodes and prefill chunks) to run this step.
    scheduled: list[ScheduledSeq] = field(default_factory=list)
    # seq_ids preempted this step (KV blocks freed, re-queued for recompute).
    preempted: list[int] = field(default_factory=list)
    # seq_ids admitted this step (prompt blocks allocated via the block manager).
    newly_admitted: list[int] = field(default_factory=list)
    # Total tokens across all scheduled work items.
    total_tokens: int = 0

    @property
    def is_empty(self) -> bool:
        """Whether this step has no work scheduled at all."""
        return len(self.scheduled) == 0
class Scheduler:
    """Continuous-batching scheduler with chunked prefill.

    State:

    * ``waiting``: sequences whose prompt is not yet fully ingested. This
      covers new requests, preempted sequences awaiting recompute, and
      admitted sequences that are only partially prefilled (kept at the head).
    * ``running``: sequences that finished prefill and now decode one token
      per step. They are kept in the order they entered ``running``; the
      last one is the youngest and is the first preemption victim.

    KV-cache blocks are allocated and released through ``block_manager``.
    """

    def __init__(self, config: EngineConfig, block_manager: BlockManager) -> None:
        self.config = config
        self.block_manager = block_manager
        self.waiting: deque[Sequence] = deque()
        self.running: list[Sequence] = []
        # NOTE(review): never read or updated in this class. Preemption uses
        # ``running`` insertion order instead. Kept for compatibility.
        self._admission_order: list[int] = []

    # ---- queue ops ------------------------------------------------------
    def add(self, seq: Sequence) -> None:
        """Enqueue a new sequence at the back of ``waiting``."""
        self.waiting.append(seq)

    def abort(self, seq_id: int) -> bool:
        """Cancel the sequence ``seq_id`` wherever it is queued.

        The sequence is removed from ``waiting`` or ``running``, and any KV
        blocks it holds are released. It is then marked FINISHED with
        ``finish_reason = "abort"``.

        Returns:
            True if a matching sequence was found and aborted, else False.
        """
        for seq in list(self.waiting):
            if seq.seq_id == seq_id:
                self.waiting.remove(seq)
                # An admitted but partially prefilled sequence waits at the
                # head of the queue with a block_table. Release those blocks
                # instead of leaking them.
                if seq.block_table:
                    self.block_manager.free(seq)
                seq.status = SequenceStatus.FINISHED
                seq.finish_reason = "abort"
                return True
        for seq in list(self.running):
            if seq.seq_id == seq_id:
                self.running.remove(seq)
                self.block_manager.free(seq)
                seq.status = SequenceStatus.FINISHED
                seq.finish_reason = "abort"
                return True
        return False

    @property
    def has_work(self) -> bool:
        """True while any sequence is waiting or running."""
        return bool(self.waiting) or bool(self.running)

    # ---- scheduling -----------------------------------------------------
    def _preempt_one(self) -> Sequence | None:
        """Free the youngest running sequence and re-enqueue it for restart.

        Returns:
            The preempted sequence, or None if nothing is running.
        """
        if not self.running:
            return None
        victim = self.running.pop()  # youngest by insertion order
        self.block_manager.free(victim)
        # Restart: forget computed-token progress, but keep generated outputs
        # so the user-visible sequence is preserved. (vLLM's full recompute
        # would discard outputs too; we keep them so streaming makes sense.)
        victim.num_computed_tokens = 0
        victim.num_cached_prefix_tokens = 0
        victim.status = SequenceStatus.PREEMPTED
        # Front of the queue, so it is the next candidate for re-admission.
        self.waiting.appendleft(victim)
        return victim

    def schedule(self) -> SchedulerOutput:
        """Plan one engine step.

        Phase 1 reserves one decode token for every RUNNING sequence. If the
        block pool is dry, it preempts the youngest running sequences. Phase 2
        spends the remaining token budget on prefill chunks from ``waiting``.
        The scheduled tokens never exceed ``config.max_num_batched_tokens``.
        """
        out = SchedulerOutput()
        budget = self._schedule_decodes(out, self.config.max_num_batched_tokens)
        self._schedule_prefills(out, budget)
        return out

    def _schedule_decodes(self, out: SchedulerOutput, budget: int) -> int:
        """Phase 1: schedule one decode token per RUNNING sequence.

        Returns:
            The token budget left over for prefill.
        """
        for seq in list(self.running):
            # Skip sequences preempted earlier in this loop (no longer RUNNING).
            if seq.status != SequenceStatus.RUNNING:
                continue
            if budget <= 0:
                break
            if not self._reserve_decode_slot(seq, out):
                continue  # `seq` itself was preempted to free blocks.
            out.scheduled.append(
                ScheduledSeq(seq=seq, num_tokens=1, is_prefill=False)
            )
            budget -= 1
            out.total_tokens += 1
        return budget

    def _reserve_decode_slot(self, seq: Sequence, out: SchedulerOutput) -> bool:
        """Ensure ``seq`` has KV space for one more token, preempting if needed.

        While ``append_slot`` fails, the youngest running sequence is
        preempted and its id is recorded in ``out.preempted``. Phase 1 walks
        ``running`` front to back, and only the tail is ever popped. So every
        victim sits at or after ``seq`` and has not been scheduled yet this
        step. The loop ends at the latest when ``seq`` itself becomes the
        youngest and is preempted.

        Returns:
            True if a slot was reserved for ``seq``; False if ``seq`` was
            preempted (or, defensively, if nothing was left to preempt).
        """
        while True:
            try:
                self.block_manager.append_slot(seq)
                return True
            except RuntimeError:
                pass  # Pool is dry: fall through and preempt someone.
            victim = self._preempt_one()
            if victim is None:
                # Unreachable while `seq` is in `running`. Give up on it for
                # this step rather than loop.
                return False
            out.preempted.append(victim.seq_id)
            if victim is seq:
                return False

    def _schedule_prefills(self, out: SchedulerOutput, budget: int) -> None:
        """Phase 2: spend the remaining budget on prefill chunks from ``waiting``.

        If the head of ``waiting`` holds no blocks yet, it is admitted first
        (prompt blocks allocated via the block manager). It then gets a chunk
        of up to ``budget`` prompt tokens. A sequence whose prompt is fully
        covered moves to ``running``. A partially prefilled one stays at the
        head of ``waiting`` and ends the phase.

        Running sequences are never preempted here, for two reasons. A victim
        may already hold a decode slot in ``out.scheduled`` for this step. And
        a re-queued victim would land ahead of the sequence being admitted.
        When blocks are short, the phase stops and retries next step.
        """
        max_concurrent = self.config.max_num_seqs
        active_count = sum(
            1 for s in self.running if s.status != SequenceStatus.FINISHED
        )
        while self.waiting and budget > 0 and active_count < max_concurrent:
            seq = self.waiting[0]

            # Admit if needed.
            if not seq.block_table:
                ok, _ = self.block_manager.can_allocate_initial(seq)
                if not ok:
                    break
                self.block_manager.admit(seq)
                out.newly_admitted.append(seq.seq_id)
                seq.status = SequenceStatus.PREFILLING

            # Plan a chunk.
            remaining = seq.num_uncomputed_prompt_tokens
            chunk = min(remaining, budget)
            if chunk <= 0:
                # Prompt already fully cached (shouldn't happen due to admit
                # capping, but defensive): move straight to RUNNING.
                self._promote_head_to_running(seq)
                active_count += 1
                continue

            # Make sure block_table covers num_computed + chunk.
            try:
                self.block_manager.ensure_blocks_for_chunk(seq, chunk)
            except RuntimeError:
                break

            out.scheduled.append(
                ScheduledSeq(seq=seq, num_tokens=chunk, is_prefill=True)
            )
            budget -= chunk
            out.total_tokens += chunk
            if chunk < remaining:
                # More prompt is left. Keep it at the head of `waiting` with
                # its partial block_table. One partial prefill per step keeps
                # things tidy.
                break
            # This chunk finishes prompt ingestion, so the seq becomes RUNNING.
            self._promote_head_to_running(seq)
            active_count += 1

    def _promote_head_to_running(self, seq: Sequence) -> None:
        """Pop ``seq`` (the head of ``waiting``), mark it RUNNING, append to ``running``."""
        self.waiting.popleft()
        seq.status = SequenceStatus.RUNNING
        self.running.append(seq)

    # ---- post-step ------------------------------------------------------
    def finalize_step(self, scheduled: list[ScheduledSeq]) -> list[Sequence]:
        """Book-keep after the model has produced new tokens for ``scheduled``.

        For every scheduled sequence, asks the block manager to register its
        filled KV blocks. Sequences whose status is FINISHED are removed from
        ``running`` (if present), and their KV blocks are freed here.

        Returns:
            The sequences that finished this step, so the engine can ship
            their final output to the caller. Their blocks are already freed.
        """
        finished: list[Sequence] = []
        for item in scheduled:
            seq = item.seq
            # NOTE(review): prev_computed is always 0. Confirm with
            # BlockManager.register_filled_blocks whether the pre-step
            # computed-token count was intended instead.
            self.block_manager.register_filled_blocks(seq, prev_computed=0)
            if seq.status != SequenceStatus.FINISHED:
                continue
            if seq in self.running:
                self.running.remove(seq)
            self.block_manager.free(seq)
            finished.append(seq)
        return finished