File size: 2,641 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

import enum
import itertools
import time
from dataclasses import dataclass, field
from typing import Optional

from .config import SamplingParams


class SequenceStatus(enum.Enum):
    """Lifecycle state of a ``Sequence``.

    Member values are lowercase strings mirroring the member names.  A newly
    constructed ``Sequence`` defaults to ``WAITING``.
    """

    WAITING = "waiting"        # not yet started prefill
    PREFILLING = "prefilling"  # chunked prefill in progress
    RUNNING = "running"        # in decode loop
    FINISHED = "finished"      # presumably terminal; Sequence.finish_reason records why — confirm
    PREEMPTED = "preempted"    # evicted; will restart prefill when capacity returns


# Process-wide source of sequence ids; yields 0, 1, 2, ... in order.
_seq_counter = itertools.count(start=0)


def _next_seq_id() -> int:
    """Return a fresh sequence id, one greater than the previous call's."""
    fresh_id = next(_seq_counter)
    return fresh_id


@dataclass
class Sequence:
    """One in-flight request.

    The token sequence is `prompt_token_ids + output_token_ids`.
    `num_computed_tokens` tracks how many tokens already have their KV
    materialized in the paged cache.  Anything past that boundary is either
    waiting prefill (during PREFILLING) or the next token to sample (RUNNING).
    """

    prompt_token_ids: list[int]
    sampling_params: SamplingParams
    request_id: str
    # Defaults to time.monotonic() evaluated at construction time.
    arrival_time: float = field(default_factory=time.monotonic)
    # Defaults to the next value from the module-level `_seq_counter`.
    seq_id: int = field(default_factory=_next_seq_id)

    # Generated tokens, appended one at a time via append_output_token().
    output_token_ids: list[int] = field(default_factory=list)
    status: SequenceStatus = SequenceStatus.WAITING

    # Paged KV bookkeeping (filled in by the BlockManager).
    block_table: list[int] = field(default_factory=list)
    num_computed_tokens: int = 0          # tokens with KV in the cache
    num_cached_prefix_tokens: int = 0     # tokens served from prefix cache hits

    # Outputs / streaming
    finish_reason: Optional[str] = None

    # ---- helpers --------------------------------------------------------

    @property
    def prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    @property
    def total_len(self) -> int:
        """Number of tokens in the full sequence (prompt + generated output)."""
        return self.prompt_len + len(self.output_token_ids)

    def all_token_ids(self) -> list[int]:
        """Return a new list: the prompt tokens followed by the output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_token(self, position: int) -> int:
        """Return the token at ``position`` in the full token sequence.

        Equivalent to ``self.all_token_ids()[position]`` without building the
        concatenated list.  Negative positions count back from the end of the
        whole sequence (prompt + output), as with ordinary list indexing.

        Raises:
            IndexError: if ``position`` is outside ``[-total_len, total_len)``.
        """
        prompt_len = self.prompt_len
        if position < 0:
            # Normalize against the combined sequence.  Indexing the prompt
            # list directly with a negative value would wrap around the
            # *prompt* only and return the wrong token once outputs exist.
            position += prompt_len + len(self.output_token_ids)
            if position < 0:
                raise IndexError("token position out of range")
        if position < prompt_len:
            return self.prompt_token_ids[position]
        # Positions past the end raise IndexError from the output list.
        return self.output_token_ids[position - prompt_len]

    @property
    def num_uncomputed_prompt_tokens(self) -> int:
        """Prompt tokens whose KV is not yet in the cache.

        Clamped at 0, so it never goes negative once ``num_computed_tokens``
        moves past the prompt (e.g. during decode).
        """
        return max(0, self.prompt_len - self.num_computed_tokens)

    def append_output_token(self, token_id: int) -> None:
        """Append one generated token to ``output_token_ids``."""
        self.output_token_ids.append(token_id)


@dataclass
class Request:
    """A user-submitted request before it becomes a Sequence.

    Holds only the caller-supplied inputs; per-sequence state (status, KV
    block table, generated outputs) lives on ``Sequence``.

    NOTE: the positional field order here (``request_id`` first) differs from
    ``Sequence``'s (``prompt_token_ids, sampling_params, request_id``), so pass
    fields by keyword when building one from the other.
    """

    request_id: str
    prompt_token_ids: list[int]
    sampling_params: SamplingParams