MRI/venv/lib/python3.13/site-packages/transformers/generation/continuous_batching/requests.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

import torch

from ...utils.logging import logging
from ...utils.metrics import traced


# We centralize the logger here to coordinate between logging and progress bar
logger = logging.getLogger("ContinuousBatchingLogger")
# logger.setLevel(logging.INFO)


def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # MPS memory reporting (PyTorch 2.0+)
        total_memory = torch.mps.driver_allocated_memory()
        allocated_memory = total_memory - torch.mps.recommended_max_memory()
        reserved_memory = 0  # MPS does not track reserved separately
    else:
        device = torch.device("cpu")
        total_memory = None
        reserved_memory = 0
        allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory
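

# Illustrative usage (an inline sketch, not part of the library API): read the
# memory breakdown before sizing a cache. `total_memory` is None on CPU-only
# hosts, so guard against it:
#
#     device, total, reserved, allocated = get_device_and_memory_breakdown()
#     if total is not None:
#         logger.info(f"{device}: {allocated}/{total} bytes allocated")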


class RequestStatus(Enum):
    """Status of a generation request through its lifecycle."""

    PENDING = "pending"
    PREFILLING = "prefilling"
    PREFILLING_SPLIT = "prefilling_split"
    SPLIT_PENDING_REMAINDER = "split_pending_remainder"
    DECODING = "decoding"
    FINISHED = "finished"
    FAILED = "failed"
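

# Typical lifecycle (illustrative summary of the statuses above and the
# RequestState.status setter below): a request starts PENDING, moves to
# PREFILLING (or PREFILLING_SPLIT / SPLIT_PENDING_REMAINDER when the prompt is
# prefilled in chunks), then DECODING, and ends as FINISHED or FAILED.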


@dataclass
class GenerationOutput:
    """Tracks the output of a generation request.

    Attributes:
        request_id (str): The ID of the generation request.
        prompt_ids (list[int]): The IDs of the prompt tokens.
        generated_tokens (list[int]): The generated tokens.
        logprobs (list[float]): The log probabilities of the generated tokens.
        error (Optional[str]): Any error message associated with the request. When None, the request was successful.
        status (RequestStatus): The status of the request.
        created_time (float): The time the request was created.
    """

    request_id: str
    prompt_ids: list[int] = field(default_factory=list)
    generated_tokens: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    error: Optional[str] = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.time)


@dataclass
class RequestState:
    """Tracks the state of a generation request through its lifecycle.

    Attributes:
        request_id (str): The ID of the generation request.
        full_prompt_ids (list[int] | None): The token IDs of the full prompt.
        prompt_ids (list[int] | None): The token IDs currently being processed.
        remaining_prompt_ids (list[int]): The token IDs remaining to be processed (for split requests).
        static_outputs (list[int]): The generated tokens.
        allocated_blocks (int): The number of blocks allocated to the request.
        position_offset (int): The current position in the sequence for position_ids.
        status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
            SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_token_id (int): The ID of the end-of-sequence token.
        created_time (float): The time the request was created.
        error (Optional[str]): Any error message associated with the request. When None, the request has had no error yet.
    """

    # Required fields
    request_id: str
    full_prompt_ids: Optional[list[int]] = None  # Full initial prompt
    prompt_ids: Optional[list[int]] = None  # Token IDs currently being processed (initial + generated)
    remaining_prompt_ids: list[int] = field(default_factory=list)  # For split requests, prefill left to process
    static_outputs: list[int] = field(default_factory=list)  # Generated tokens
    allocated_blocks: int = 0  # Number of blocks allocated to the request
    position_offset: int = 0  # Current position in the sequence for position_ids
    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
    max_new_tokens: int = 20  # Maximum number of new tokens to generate
    eos_token_id: int = -1  # ID of the end-of-sequence token
    created_time: float = field(default_factory=time.time)  # Time the request was created
    error: Optional[str] = None  # Error message if the request failed
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)

    @property
    def status(self) -> RequestStatus:
        return self._status

    @status.setter
    def status(self, value: RequestStatus):
        if self._status == RequestStatus.PENDING:
            self.lifespan = (time.time(), -1)
        elif value == RequestStatus.FINISHED:
            self.lifespan = (self.lifespan[0], time.time())
            self.log_end_of_request()
        self._status = value

    def log_end_of_request(self):
        prefill_len = len(self.full_prompt_ids)
        decode_len = self.generated_len()
        start_time = self.lifespan[0] - self.created_time
        end_time = self.lifespan[1] - self.created_time
        logger.info(
            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
        )

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""
        return self.position_offset

    def generated_len(self) -> int:
        """Get the number of tokens generated so far."""
        return len(self.static_outputs)

    # TODO: this logic seems one token off, check it out
    @traced
    def update_with_token(self, token_id: int) -> bool:
        """Update the request with a newly generated token and check for completion.

        Args:
            token_id: The token ID to add to the output sequence

        Returns:
            bool: True if the request is now complete, False otherwise
        """
        # Only update if we're in decoding state
        if self.status != RequestStatus.DECODING:
            return False

        is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
        is_max_len = self.generated_len() >= self.max_new_tokens

        # Only add the token if we're not finishing due to max length
        # (EOS tokens should still be added to the output)
        if not (is_max_len and not is_eos):
            self.static_outputs.extend([token_id])

        if is_eos or is_max_len:
            self.status = RequestStatus.FINISHED
            return True
        return False
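
    # Illustrative trace of the TODO above (no behaviour change implied): with
    # max_new_tokens=2 and no EOS, the first two calls each append a token and
    # return False; the third call sees generated_len() == max_new_tokens,
    # drops the incoming token, marks the request FINISHED and returns True,
    # so one extra decode step is needed before the request reports finished.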

    def __repr__(self):
        msg = [
            f"request_id={self.request_id}",
            f"status={self._status}",
            f"out_tokens={self.generated_len()}",
            f"query_length={len(self.prompt_ids)}",
            f"remaining_tokens={len(self.remaining_prompt_ids)}",
            f"kv_length={self.position_offset}",
            f"full_prompt_length={len(self.full_prompt_ids)}",
            f"allocated_blocks={self.allocated_blocks}",
            f"generated_tokens={self.static_outputs}",
        ]
        return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"

    def to_generation_output(self):
        """Convert the request state to a GenerationOutput object."""
        return GenerationOutput(
            request_id=self.request_id,
            prompt_ids=self.full_prompt_ids,
            status=self.status,
            generated_tokens=self.static_outputs,
            logprobs=[],
            error=self.error,
        )
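

# --- Illustrative usage sketch (not part of the library) --------------------
# A minimal walk-through of the lifecycle defined above: create a RequestState,
# move it to DECODING, feed it tokens through update_with_token, and read the
# result back as a GenerationOutput. The token IDs and eos_token_id are made up
# for the example; because of the relative imports at the top of this module,
# these classes are in practice imported from transformers rather than running
# this file directly.
if __name__ == "__main__":
    state = RequestState(
        request_id="req-0",
        full_prompt_ids=[101, 2023, 2003, 102],
        prompt_ids=[101, 2023, 2003, 102],
        max_new_tokens=3,
        eos_token_id=2,  # made-up EOS id for the example
    )
    state.status = RequestStatus.PREFILLING  # lifespan start is recorded here
    state.status = RequestStatus.DECODING
    for token in (57, 198, 2):  # 2 matches the EOS id above, so generation stops there
        if state.update_with_token(token):
            break
    output = state.to_generation_output()
    print(output.generated_tokens, output.status)  # [57, 198, 2] RequestStatus.FINISHED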