File size: 9,046 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""Continuous-batching scheduler with chunked prefill.

A scheduling step produces a SchedulerOutput listing which sequences run and
how many tokens each one advances.  Two phases each step:

  1. Decodes.  Every RUNNING sequence wants exactly one new token; we must
     ensure each has space for it.  If a sequence needs a new block and the
     pool is dry, we *preempt* the most recently admitted running sequence —
     free its KV blocks and push it back to the front of the waiting queue
     so it restarts prefill later (recompute-style preemption, as in vLLM).

  2. Prefill chunks.  With remaining token budget, pull from `waiting`.  A
     newly waiting sequence is admitted (prompt blocks allocated via the
     block manager, with prefix-cache hits taken).  Then we plan a chunk of
     up to `min(remaining_prefill, budget)` tokens.  Chunked prefill lets a
     long prompt share the budget with concurrent decodes instead of
     stalling them.
"""
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field

from .block_manager import BlockManager
from .config import EngineConfig
from .request import Sequence, SequenceStatus


@dataclass
class ScheduledSeq:
    """One sequence's share of a single scheduling step."""

    # The sequence being advanced this step.
    seq: Sequence
    # Number of tokens to forward for this sequence in this step.
    num_tokens: int
    # True when the tokens are a prefill chunk, False for a decode token.
    is_prefill: bool


@dataclass
class SchedulerOutput:
    """Plan produced by one scheduling step."""

    # Per-sequence work items to run this step.
    scheduled: list[ScheduledSeq] = field(default_factory=list)
    # seq_ids that were preempted while building this plan.
    preempted: list[int] = field(default_factory=list)
    # seq_ids that were admitted while building this plan.
    newly_admitted: list[int] = field(default_factory=list)
    # Sum of `num_tokens` over the scheduled items.
    total_tokens: int = 0

    @property
    def is_empty(self) -> bool:
        """Return True when no sequence was scheduled."""
        return False if self.scheduled else True


class Scheduler:
    """Plan each engine step: decodes for running sequences, then prefill.

    State:
      * ``waiting`` -- deque of sequences that are not (or no longer) in
        decode.  New sequences are appended; preempted ones are put back near
        the front.  The head may already own KV blocks while its prompt is
        being ingested in chunks across several steps.
      * ``running`` -- sequences in decode, in the order they were promoted
        (oldest first), so the last element is the youngest and the preferred
        preemption victim.
    """

    def __init__(self, config: EngineConfig, block_manager: BlockManager) -> None:
        self.config = config
        self.block_manager = block_manager
        self.waiting: deque[Sequence] = deque()
        self.running: list[Sequence] = []
        # NOTE: nothing in this class updates this list; "youngest first"
        # preemption is derived from position in `running` instead.  Kept
        # only so existing attribute access does not break.
        self._admission_order: list[int] = []

    # ---- queue ops ------------------------------------------------------

    def add(self, seq: Sequence) -> None:
        """Enqueue ``seq`` at the back of the waiting queue."""
        self.waiting.append(seq)

    def abort(self, seq_id: int) -> bool:
        """Cancel the sequence identified by ``seq_id``.

        The sequence is removed from whichever queue holds it, its KV blocks
        are released, and it is marked FINISHED with ``finish_reason`` set to
        ``"abort"``.

        Returns:
            True if a matching sequence was found, False otherwise.
        """
        target = next((s for s in self.waiting if s.seq_id == seq_id), None)
        if target is not None:
            self.waiting.remove(target)
            # A partially prefilled sequence sits in `waiting` while already
            # owning KV blocks; release them, otherwise they leak.
            if target.block_table:
                self.block_manager.free(target)
        else:
            target = next((s for s in self.running if s.seq_id == seq_id), None)
            if target is None:
                return False
            self.running.remove(target)
            self.block_manager.free(target)
        target.status = SequenceStatus.FINISHED
        target.finish_reason = "abort"
        return True

    @property
    def has_work(self) -> bool:
        """True while any sequence is waiting or running."""
        return bool(self.waiting) or bool(self.running)

    # ---- scheduling -----------------------------------------------------

    def _num_active(self) -> int:
        """Count sequences in `running` that are not FINISHED."""
        return sum(1 for s in self.running if s.status != SequenceStatus.FINISHED)

    def _evict(self, victim: Sequence) -> None:
        """Release ``victim``'s KV blocks and mark it for a restart.

        Computed-token progress is forgotten; generated outputs are kept so
        the user-visible sequence is preserved.  (vLLM full-recompute would
        discard outputs too; we keep them so streaming makes sense.)
        """
        self.block_manager.free(victim)
        victim.num_computed_tokens = 0
        victim.num_cached_prefix_tokens = 0
        victim.status = SequenceStatus.PREEMPTED

    def _preempt_one(self) -> Sequence | None:
        """Free the youngest running sequence and re-enqueue it for restart.

        The victim is pushed to the *front* of `waiting`.

        Returns:
            The preempted sequence, or None when nothing is running.
        """
        if not self.running:
            return None
        victim = self.running.pop()  # youngest by insertion order
        self._evict(victim)
        self.waiting.appendleft(victim)
        return victim

    def _reserve_decode_slot(self, seq: Sequence, out: SchedulerOutput) -> bool:
        """Make room for one more token of ``seq``, preempting once if needed.

        Any preempted seq_id is recorded in ``out.preempted``.

        Returns:
            True if ``seq`` has a slot and can decode this step; False if it
            was preempted itself or there is still no room.
        """
        try:
            self.block_manager.append_slot(seq)
            return True
        except RuntimeError:
            pass
        # Out of blocks: free space by preempting the youngest running
        # sequence -- which may be `seq` itself.
        victim = self._preempt_one()
        if victim is None:
            # Cannot happen while `seq` is still in `running`; skip defensively.
            return False
        out.preempted.append(victim.seq_id)
        if victim is seq:
            return False
        try:
            self.block_manager.append_slot(seq)
        except RuntimeError:
            # Still no room -- give up on this seq this step.
            return False
        return True

    def _preempt_for_head(self, out: SchedulerOutput, promoted: set[int]) -> int | None:
        """Evict one running sequence to make room for ``waiting[0]``.

        Differences from `_preempt_one`, all needed for prefill-time use:
          * the victim is queued *behind* the head, so the sequence we are
            making room for is still the one retried next;
          * sequences promoted to `running` during this step (``promoted``)
            are never chosen, and a head that was itself preempted this step
            may not evict anyone -- together these stop two sequences from
            evicting each other inside one step;
          * if the victim was already scheduled this step, its entry is
            removed from ``out`` so the engine never runs a sequence whose
            blocks were just freed.

        Returns:
            The number of scheduled tokens handed back to the budget (0 if
            the victim was not scheduled), or None if nothing may be evicted.
        """
        if self.waiting[0].seq_id in out.preempted:
            return None
        victim_idx = next(
            (
                i
                for i in range(len(self.running) - 1, -1, -1)
                if self.running[i].seq_id not in promoted
            ),
            None,
        )
        if victim_idx is None:
            return None

        victim = self.running.pop(victim_idx)
        self._evict(victim)
        self.waiting.insert(1, victim)
        out.preempted.append(victim.seq_id)

        refund = sum(item.num_tokens for item in out.scheduled if item.seq is victim)
        if refund:
            out.scheduled = [item for item in out.scheduled if item.seq is not victim]
            out.total_tokens -= refund
        return refund

    def _promote_head(self, seq: Sequence, promoted: set[int]) -> None:
        """Move the waiting head ``seq`` into `running` as a decoding sequence."""
        self.waiting.popleft()
        seq.status = SequenceStatus.RUNNING
        self.running.append(seq)
        promoted.add(seq.seq_id)

    def schedule(self) -> SchedulerOutput:
        """Build the plan for the next engine step.

        Phase 1 gives every RUNNING sequence one decode token (preempting the
        youngest running sequence when the block pool is dry).  Phase 2 spends
        the remaining token budget on prefill chunks for the head of
        `waiting`, admitting it first if it owns no blocks yet.

        Returns:
            A SchedulerOutput listing scheduled work, preempted seq_ids,
            newly admitted seq_ids and the total number of tokens planned.
        """
        out = SchedulerOutput()
        budget = self.config.max_num_batched_tokens

        # --- Phase 1: decodes for already-running sequences ---
        # Iterate over a snapshot: preemption mutates `self.running`.
        for seq in list(self.running):
            if seq.status != SequenceStatus.RUNNING:
                continue  # e.g. preempted earlier in this very loop
            if budget <= 0:
                break
            if not self._reserve_decode_slot(seq, out):
                continue
            out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=1, is_prefill=False))
            budget -= 1
            out.total_tokens += 1

        # --- Phase 2: prefill chunks (admitting new sequences as needed) ---
        max_concurrent = self.config.max_num_seqs
        active_count = self._num_active()
        promoted: set[int] = set()  # seq_ids moved to `running` in this phase

        while self.waiting and budget > 0 and active_count < max_concurrent:
            seq = self.waiting[0]

            # Admit if needed.
            if not seq.block_table:
                ok, _ = self.block_manager.can_allocate_initial(seq)
                if not ok:
                    # Try to free up space by evicting a running seq.  If
                    # nothing may be evicted, we're stuck for this step.
                    refund = self._preempt_for_head(out, promoted)
                    if refund is None:
                        break
                    budget += refund
                    active_count = self._num_active()
                    continue
                self.block_manager.admit(seq)
                out.newly_admitted.append(seq.seq_id)
                seq.status = SequenceStatus.PREFILLING

            # Plan a chunk.
            remaining = seq.num_uncomputed_prompt_tokens
            chunk = min(remaining, budget)
            if chunk <= 0:
                # Prompt already fully cached (shouldn't happen due to admit
                # capping, but defensive): move straight to RUNNING.
                self._promote_head(seq, promoted)
                active_count += 1
                continue

            # Make sure block_table covers num_computed + chunk.
            try:
                self.block_manager.ensure_blocks_for_chunk(seq, chunk)
            except RuntimeError:
                # Couldn't expand.  Try eviction; otherwise give up.
                refund = self._preempt_for_head(out, promoted)
                if refund is None:
                    break
                budget += refund
                active_count = self._num_active()
                continue

            out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=chunk, is_prefill=True))
            budget -= chunk
            out.total_tokens += chunk

            if chunk != remaining:
                # Still has more prompt to chew through; leave at head of
                # waiting queue with a partial block_table.
                break  # one partial prefill per step keeps things tidy

            # This step finishes prompt ingestion -> seq becomes RUNNING.
            self._promote_head(seq, promoted)
            active_count += 1

        return out

    # ---- post-step ------------------------------------------------------

    def finalize_step(self, scheduled: list[ScheduledSeq]) -> list[Sequence]:
        """Called after the model has produced new tokens.

        For every scheduled item the block manager is asked to register the
        sequence's filled blocks; sequences whose status is FINISHED are
        dropped from `running` and their blocks are freed.

        Returns the list of sequences that just finished this step (so the
        engine can ship the final output to the caller).
        """
        finished: list[Sequence] = []
        for item in scheduled:
            seq = item.seq
            # NOTE(review): prev_computed is always 0 here -- presumably the
            # block manager rescans from the first block; confirm.
            self.block_manager.register_filled_blocks(seq, prev_computed=0)
            if seq.status == SequenceStatus.FINISHED:
                if seq in self.running:
                    self.running.remove(seq)
                self.block_manager.free(seq)
                finished.append(seq)
        return finished