# tiny-vllm: tiny_vllm/request.py
# Part of tiny-vllm, a minimal continuous-batching LLM engine
# (author: enCoder, revision c32c359).
from __future__ import annotations
import enum
import itertools
import time
from dataclasses import dataclass, field
from typing import Optional
from .config import SamplingParams
class SequenceStatus(enum.Enum):
    """Lifecycle state of a Sequence inside the engine."""

    # Queued; prefill has not been started yet.
    WAITING = "waiting"
    # Prefill is being processed in chunks and has not completed.
    PREFILLING = "prefilling"
    # Prefill is done; the sequence is in the decode loop.
    RUNNING = "running"
    # The sequence has finished.
    FINISHED = "finished"
    # Evicted to free capacity; prefill restarts when capacity returns.
    PREEMPTED = "preempted"
_seq_counter = itertools.count()
def _next_seq_id() -> int:
return next(_seq_counter)
@dataclass
class Sequence:
    """One in-flight request.

    The token sequence is `prompt_token_ids + output_token_ids`.
    `num_computed_tokens` tracks how many tokens already have their KV
    materialized in the paged cache. Anything past that boundary is either
    waiting prefill (during PREFILLING) or the next token to sample (RUNNING).
    """

    prompt_token_ids: list[int]
    sampling_params: SamplingParams
    request_id: str
    arrival_time: float = field(default_factory=time.monotonic)
    seq_id: int = field(default_factory=_next_seq_id)
    output_token_ids: list[int] = field(default_factory=list)
    status: SequenceStatus = SequenceStatus.WAITING

    # Paged KV bookkeeping (filled in by the BlockManager).
    block_table: list[int] = field(default_factory=list)
    num_computed_tokens: int = 0  # tokens with KV in the cache
    num_cached_prefix_tokens: int = 0  # tokens served from prefix cache hits

    # Outputs / streaming
    finish_reason: Optional[str] = None

    # ---- helpers --------------------------------------------------------

    @property
    def prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    @property
    def total_len(self) -> int:
        """Number of tokens in the full sequence (prompt + output)."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def all_token_ids(self) -> list[int]:
        """Return a new list: the prompt tokens followed by the output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_token(self, position: int) -> int:
        """Return the token at *position* in `prompt_token_ids + output_token_ids`.

        Indexing follows Python sequence semantics on the combined sequence,
        so `get_token(i) == all_token_ids()[i]` for every valid `i`
        (including negative ones), without building the concatenated list.

        Raises:
            IndexError: if *position* is outside `[-total_len, total_len)`.
        """
        prompt_len = len(self.prompt_token_ids)
        total = prompt_len + len(self.output_token_ids)
        if position < 0:
            # Resolve negative indices against the *combined* sequence.
            # Indexing the prompt list directly would count from the end of
            # the prompt and silently skip any generated output tokens.
            position += total
        if not 0 <= position < total:
            raise IndexError("token position out of range")
        if position < prompt_len:
            return self.prompt_token_ids[position]
        return self.output_token_ids[position - prompt_len]

    @property
    def num_uncomputed_prompt_tokens(self) -> int:
        """Prompt tokens whose KV is not yet in the cache (never negative).

        Clamped at 0 because `num_computed_tokens` can exceed the prompt
        length once output tokens have also been computed.
        """
        return max(0, self.prompt_len - self.num_computed_tokens)

    def append_output_token(self, token_id: int) -> None:
        """Append *token_id* to the generated output tokens."""
        self.output_token_ids.append(token_id)
@dataclass
class Request:
    """A user-submitted request before it becomes a Sequence.

    Holds only the submission payload (id, prompt tokens, sampling params);
    the scheduling and paged-KV bookkeeping fields live on Sequence.
    """

    # Identifier of the request.
    request_id: str
    # Prompt, already converted to token ids.
    prompt_token_ids: list[int]
    # Sampling configuration for this request (type from .config).
    sampling_params: SamplingParams