| from __future__ import annotations |
|
|
| import enum |
| import itertools |
| import time |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| from .config import SamplingParams |
|
|
|
|
class SequenceStatus(enum.Enum):
    """Lifecycle states a sequence moves through.

    Every member's value is simply its name in lowercase.
    """

    # Not yet scheduled (the state a new Sequence starts in).
    WAITING = "waiting"
    # Prompt tokens are still being prefilled into the KV cache.
    PREFILLING = "prefilling"
    # Prefill done; the next output token is sampled.
    RUNNING = "running"
    # Generation has ended.
    FINISHED = "finished"
    # NOTE(review): presumably removed from the running set to free cache
    # space and eligible for rescheduling — confirm against the scheduler.
    PREEMPTED = "preempted"
|
|
|
|
| _seq_counter = itertools.count() |
|
|
|
|
| def _next_seq_id() -> int: |
| return next(_seq_counter) |
|
|
|
|
@dataclass
class Sequence:
    """One in-flight request.

    The token sequence is `prompt_token_ids + output_token_ids`.
    `num_computed_tokens` tracks how many tokens already have their KV
    materialized in the paged cache. Anything past that boundary is either
    waiting prefill (during PREFILLING) or the next token to sample (RUNNING).
    """

    # --- Request identity and inputs ------------------------------------
    prompt_token_ids: list[int]
    sampling_params: SamplingParams
    request_id: str
    arrival_time: float = field(default_factory=time.monotonic)
    # Distinct per-process id drawn from the module-level counter.
    seq_id: int = field(default_factory=_next_seq_id)

    # --- Generation state -----------------------------------------------
    output_token_ids: list[int] = field(default_factory=list)
    status: SequenceStatus = SequenceStatus.WAITING

    # --- Paged KV-cache bookkeeping -------------------------------------
    # NOTE(review): presumably the ids of the cache blocks backing this
    # sequence, in token order — confirm against the block manager.
    block_table: list[int] = field(default_factory=list)
    # Count of leading tokens whose KV is already materialized in the cache.
    num_computed_tokens: int = 0
    # NOTE(review): appears to count prompt tokens reused from a prefix
    # cache rather than recomputed — confirm against the scheduler.
    num_cached_prefix_tokens: int = 0

    # --- Completion -----------------------------------------------------
    # None until something records why the sequence stopped.
    finish_reason: Optional[str] = None

    @property
    def prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    @property
    def total_len(self) -> int:
        """Number of prompt tokens plus generated tokens."""
        return self.prompt_len + len(self.output_token_ids)

    def all_token_ids(self) -> list[int]:
        """Return a new list: the prompt followed by the generated tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_token(self, position: int) -> int:
        """Return the token at ``position`` in the full prompt+output sequence.

        Negative positions count from the end of the *full* sequence, as with
        Python list indexing (``-1`` is the most recently appended token).

        Raises:
            IndexError: if ``position`` lies outside the sequence.
        """
        num_prompt = len(self.prompt_token_ids)
        if position < 0:
            # Normalize against the full length first; otherwise a negative
            # position would take the prompt branch below and index from the
            # end of the prompt alone, ignoring generated tokens.
            position += num_prompt + len(self.output_token_ids)
            if position < 0:
                raise IndexError("token position out of range")
        if position < num_prompt:
            return self.prompt_token_ids[position]
        # Positions past the end raise IndexError from the list lookup.
        return self.output_token_ids[position - num_prompt]

    @property
    def num_uncomputed_prompt_tokens(self) -> int:
        """Prompt tokens whose KV has not been computed yet (never negative)."""
        return max(0, self.prompt_len - self.num_computed_tokens)

    def append_output_token(self, token_id: int) -> None:
        """Append one generated token to the output."""
        self.output_token_ids.append(token_id)
|
|
|
|
@dataclass
class Request:
    """Raw inputs of a user-submitted request, prior to becoming a Sequence.

    Carries nothing but what the submitter provides: an id, the tokenized
    prompt, and the sampling configuration.
    """

    # Identifier of this request.
    request_id: str
    # Prompt as a list of token ids.
    prompt_token_ids: list[int]
    # Sampling configuration to apply while generating.
    sampling_params: SamplingParams
|
|