# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from abc import ABC, abstractmethod
from collections import deque

from ...utils.metrics import attach_tracer, traced
from .cache import PagedAttentionCache
from .requests import RequestState, RequestStatus


class Scheduler(ABC):
    """
    Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
    requests from when they are added to the waiting queue to when they are scheduled for processing. Different
    schedulers implement different strategies for prioritizing and batching requests.
    """

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
        self.active_requests: dict[str, RequestState] = {}
        self.waiting_requests: dict[str, RequestState] = {}
        self.waiting_requests_order: deque[str] = deque()
        self.cache = cache
        self.retain_cache_on_finish = retain_cache_on_finish
        self._cancellation_lock = threading.Lock()
        self._requests_to_cancel: set[str] = set()

    @traced
    def add_waiting_request(self, state: RequestState):
        """Adds a request to the waiting list."""
        if self.retain_cache_on_finish and state.request_id in self.active_requests:
            old_state = self.active_requests.pop(state.request_id)
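            # The new request is assumed to resend the previous prompt plus new tokens: drop the prefix that is
            # already in the cache so only the new suffix is prefilled, and reuse the old blocks and position.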
            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]  # XXX: check for indexing error?
            state.allocated_blocks = old_state.allocated_blocks
            state.position_offset = old_state.position_offset
        self.waiting_requests[state.request_id] = state
        self.waiting_requests_order.append(state.request_id)

    @abstractmethod
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        """Schedules requests for the next batch based on available token budget. This method selects which requests
        should be processed in the current batch, considering the token budget and the scheduler's prioritization rules.
        The token_budget is the maximum number of tokens that can be processed in this batch."""
        pass

    @traced
    def has_pending_requests(self) -> bool:
        """Checks if there are requests ready to be processed."""
        return bool(self.active_requests or self.waiting_requests)

    @traced
    def finish_request(self, request_id: str, evict_from_cache: bool = True):
        """Completes processing of a request and optionally frees its allocated cache blocks. This method is called
        when a request has finished generation or encountered an error.
        """
        if evict_from_cache:
            self.cache.free_blocks(request_id)
            if request_id in self.active_requests:
                del self.active_requests[request_id]

    @traced
    def get_active_request_static_outputs(self, request_id: str) -> list[int]:
        """Gets generated tokens for an active request."""
        if request_id in self.active_requests:
            return self.active_requests[request_id].static_outputs
        return []

    @traced
    def set_request_cancellation(self, request_id: str):
        """Marks a request for cancellation."""
        with self._cancellation_lock:
            self._requests_to_cancel.add(request_id)

    @traced
    def clear_cancelled_requests(self):
        """Remove all cancelled requests from active and waiting queues."""
        with self._cancellation_lock:
            for request_id in self._requests_to_cancel:
                if request_id in self.active_requests:
                    del self.active_requests[request_id]
                if request_id in self.waiting_requests:
                    del self.waiting_requests[request_id]
                if request_id in self.waiting_requests_order:
                    self.waiting_requests_order.remove(request_id)
                self.cache.free_blocks(request_id)
            self._requests_to_cancel = set()

    @traced
    def request_is_cancelled(self, request_id: str) -> bool:
        """Checks if a request has been cancelled or removed."""
        return request_id in self._requests_to_cancel or (
            request_id not in self.active_requests and request_id not in self.waiting_requests
        )

    @traced
    def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
        """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
        accommodate the next tokens. It calculates how many blocks are needed based on the request's current
        cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
        objects. Returns a boolean indicating if the allocation was successful or not.
        """
        # 1. we compute how many free slots are left in the blocks already allocated to this request
        # 2. if that is not enough for the next tokens (or nothing is allocated yet), we allocate more blocks
        current_len = state.current_len()
        occupancy = state.allocated_blocks * self.cache.block_size - current_len
        if occupancy < len_next_tokens or state.allocated_blocks == 0:
            blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
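            # Illustrative numbers (not from the source): with block_size = 32, 5 free slots left and 40 incoming
            # tokens, blocks_needed = ((40 - 5 + 1) // 32) + 1 = 2, i.e. two additional blocks are requested.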
            allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
            if allocated is None:
                return False
            state.allocated_blocks += allocated
        return True

    @traced(span_name="prepare_request")
    def _prepare_request_for_processing(
        self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
    ):
        """Prepares a request for processing in the current batch."""
        request_tokens = (
            state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
        )
        if len(request_tokens) < token_budget:
            # Can process the entire prompt/remainder
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                state.status = RequestStatus.PREFILLING
                request_ids_to_remove_from_waiting.add(state.request_id)
            elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
                state.status = RequestStatus.PREFILLING
                state.prompt_ids = state.remaining_prompt_ids
                state.remaining_prompt_ids = []
        else:
            # Need to split the request
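            # Illustrative numbers (not from the source): a 500-token prompt with a 256-token budget keeps the
            # first 256 tokens in prompt_ids for this pass and stores the remaining 244 in remaining_prompt_ids.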
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                state.status = RequestStatus.PREFILLING_SPLIT
                request_ids_to_remove_from_waiting.add(state.request_id)
            elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
                state.status = RequestStatus.PREFILLING_SPLIT
            state.remaining_prompt_ids = request_tokens[token_budget:]
            state.prompt_ids = request_tokens[:token_budget]


@attach_tracer()
class FIFOScheduler(Scheduler):
    """This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
    prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
    when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
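
    # Illustrative ordering (not from the source): if two requests are already decoding and a new prompt is
    # waiting, the two decoding requests are scheduled first and the new prompt only gets whatever token
    # budget remains (possibly being split across batches).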

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
        """Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
        scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
        or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
        """
        super().__init__(cache, retain_cache_on_finish)
        self.safety_margin = safety_margin

    @traced
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        priority_states: list[RequestState] = []
        second_priority_states: list[RequestState] = []
        scheduled_requests = []

        for state in self.active_requests.values():
            if state.status == RequestStatus.DECODING:
                priority_states.append(state)
            if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
                second_priority_states.append(state)

        # Add waiting requests to second priority
        for req_id in self.waiting_requests_order:
            second_priority_states.append(self.waiting_requests[req_id])

        candidates = priority_states + second_priority_states
        request_ids_to_remove_from_waiting = set()
        safety_margins = self.safety_margin * self.cache.num_blocks

        for state in candidates:
            # If free blocks are below the safety margin, we only accept decoding requests or the first prefill request
            num_free_blocks = self.cache.get_num_free_blocks()
            outside_safety_margin = num_free_blocks < safety_margins
            if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
                break

            self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
            request_len = len(state.prompt_ids)
            # Don't schedule the request if we can't allocate the blocks it needs; if the cache has no free
            # blocks left at all, stop scheduling for this batch.
            if not self._allocate_blocks_if_needed(state, request_len):
                if len(self.cache._free_blocks) == 0:
                    break
                continue

            @traced
            def _add_to_scheduled_requests(state: RequestState):
                scheduled_requests.append(state)

            _add_to_scheduled_requests(state)

            token_budget -= request_len

            @traced
            def _remove_from_waiting_requests(state: RequestState):
                req_id = state.request_id
                if req_id in self.waiting_requests:
                    del self.waiting_requests[req_id]
                    request_ids_to_remove_from_waiting.add(req_id)

            _remove_from_waiting_requests(state)

            if token_budget == 0:
                break

        self.waiting_requests_order = deque(
            [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
        )

        return scheduled_requests


# FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
@attach_tracer()
class PrefillFirstScheduler(Scheduler):
    """Scheduler that prioritizes split prefill requests over decoding requests. This scheduler ensures that split
    prefill requests (which are continuations of partially processed prompts) are completed before processing new
    decoding requests."""

    @traced
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        priority_states: list[RequestState] = []
        second_priority_states: list[RequestState] = []
        scheduled_requests = []

        for state in self.active_requests.values():
            # XXX: when the cache is full, a state can stay in `PREFILLING_SPLIT`, so we need to take those into account
            if state.status in [RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER]:
                priority_states.append(state)
            elif state.status == RequestStatus.DECODING:
                second_priority_states.append(state)

        for req_id in self.waiting_requests_order:
            second_priority_states.append(self.waiting_requests[req_id])

        candidates = priority_states + second_priority_states

        request_ids_to_remove_from_waiting = set()

        for state in candidates:
            self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
            request_len = len(state.prompt_ids)
            # Don't schedule the request if we can't allocate the blocks it needs; if the cache has no free
            # blocks left at all, stop scheduling for this batch.
            if not self._allocate_blocks_if_needed(state, request_len):
                if len(self.cache._free_blocks) == 0:
                    break
                continue

            @traced
            def _add_to_scheduled_requests(state: RequestState):
                scheduled_requests.append(state)

            _add_to_scheduled_requests(state)

            token_budget -= request_len

            @traced
            def _remove_from_waiting_requests(state: RequestState):
                req_id = state.request_id
                if req_id in self.waiting_requests:
                    del self.waiting_requests[req_id]
                    request_ids_to_remove_from_waiting.add(req_id)

            _remove_from_waiting_requests(state)

            if token_budget == 0:
                break

        self.waiting_requests_order = deque(
            [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
        )

        return scheduled_requests


SCHEDULER_MAPPING = {
    "fifo": FIFOScheduler,
    "prefill_first": PrefillFirstScheduler,
}
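

# Illustrative usage (a sketch; `cache` and `state` are assumed to be a PagedAttentionCache and a RequestState
# built elsewhere, e.g. by the continuous batch processor):
#
#   scheduler = SCHEDULER_MAPPING["fifo"](cache, retain_cache_on_finish=False)
#   scheduler.add_waiting_request(state)
#   batch = scheduler.schedule_batch(token_budget=256)
#   for req in batch:
#       ...  # run prefill/decode for the scheduled requests outside the scheduler
#   scheduler.finish_request(state.request_id)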