| from __future__ import annotations | |
| import dataclasses | |
| import logging | |
| from typing import TYPE_CHECKING, List, Optional | |
| import torch | |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | |
| from sglang.srt.managers.overlap_utils import FutureIndices | |
| from sglang.srt.managers.schedule_batch import Req | |
| from sglang.srt.model_executor.forward_batch_info import PPProxyTensors | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.scheduler import GenerationBatchResult | |
| from sglang.srt.speculative.eagle_info import EagleDraftInput | |
# Module-level logger for this file (emits the auto-truncation warning in
# validate_input_length).
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class GenerationBatchResult:
    """Result of one generation forward step.

    Holds the model outputs plus bookkeeping used for output processing,
    overlap scheduling and speculative decoding. ``@dataclasses.dataclass``
    provides the keyword constructor that ``from_pp_proxy`` relies on
    (``cls(...)``); every field defaults to None/False.
    """

    logits_output: Optional[LogitsProcessorOutput] = None
    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

    # For output processing
    extend_input_len_per_req: Optional[List[int]] = None
    extend_logprob_start_len_per_req: Optional[List[int]] = None

    # For overlap scheduling
    copy_done: Optional[torch.cuda.Event] = None
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Copy tensors to CPU in overlap scheduling.

        Only the tensors which are needed for processing results are copied,
        e.g., next_token_ids, logits outputs. Every copy is issued with
        ``non_blocking=True`` and ``copy_done`` is recorded after all copies
        have been enqueued (presumably so the consumer can synchronize on it
        before reading the CPU tensors -- confirm at the call site).

        Args:
            return_logprob: When True, also copy ``next_token_logits`` and
                ``input_token_logprobs`` of ``logits_output``.
        """

        def to_cpu(tensor):
            # None passes through so optional outputs stay None.
            if tensor is None:
                return None
            return tensor.to("cpu", non_blocking=True)

        logits_output = self.logits_output
        # logits_output is declared Optional: skip its fields rather than
        # raising AttributeError when it is absent.
        if logits_output is not None:
            if return_logprob:
                logits_output.next_token_logits = to_cpu(
                    logits_output.next_token_logits
                )
                logits_output.input_token_logprobs = to_cpu(
                    logits_output.input_token_logprobs
                )
            # Hidden states are copied independently of return_logprob.
            logits_output.hidden_states = to_cpu(logits_output.hidden_states)

        # next_token_ids is required here; a missing value should fail loudly.
        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
        self.accept_lens = to_cpu(self.accept_lens)
        self.allocate_lens = to_cpu(self.allocate_lens)
        self.copy_done.record()

    @classmethod
    def from_pp_proxy(
        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
    ):
        """Build a result from pipeline-parallel proxy tensors.

        Args:
            logits_output: Stored as-is in ``logits_output``.
            next_pp_outputs: Must contain ``"next_token_ids"``; the two
                per-request length lists are optional and default to None.
            can_run_cuda_graph: Stored as-is.
        """
        # TODO(lsyin): refactor PP and avoid using dict
        proxy_dict = next_pp_outputs.tensors
        return cls(
            logits_output=logits_output,
            pp_hidden_states_proxy_tensors=None,
            next_token_ids=next_pp_outputs["next_token_ids"],
            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
            extend_logprob_start_len_per_req=proxy_dict.get(
                "extend_logprob_start_len_per_req", None
            ),
            can_run_cuda_graph=can_run_cuda_graph,
        )
def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate input length.

    An input of exactly ``max_req_input_len`` tokens is accepted; that is
    also the length auto-truncation produces.

    Args:
        req: The request containing input_ids to validate. When truncation
            happens, ``req.origin_input_ids`` is replaced in place.
        max_req_input_len: Maximum allowed input length (inclusive).
        allow_auto_truncate: Whether to truncate long inputs instead of
            rejecting them.

    Returns:
        Error message if validation fails, None if successful.
    """
    # Inclusive limit: the truncation below keeps exactly max_req_input_len
    # tokens and treats that as valid, so an untruncated input of the same
    # length must be valid too (and the error message below says "exceeds").
    if len(req.origin_input_ids) <= max_req_input_len:
        return None

    if allow_auto_truncate:
        logger.warning(
            "Request length is longer than the KV cache pool size or "
            "the max context length. Truncated. "
            f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
        )
        req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
        return None

    return (
        f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
        f"the maximum allowed length ({max_req_input_len} tokens). "
        f"Use a shorter input or enable --allow-auto-truncate."
    )
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
    """Flatten the logprob-related fields of ``result`` into a plain dict.

    The two per-request length lists are taken from ``result`` itself; every
    other entry is read from ``result.logits_output`` via the attribute with
    the same name as the key.

    Raises:
        AssertionError: If ``result.logits_output`` is None.
    """
    output = result.logits_output
    assert output is not None

    # Order matters only for dict iteration order; it mirrors the key order
    # callers have always seen.
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    flat = {
        "extend_input_len_per_req": result.extend_input_len_per_req,
        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
    }
    flat.update((field, getattr(output, field)) for field in logprob_fields)
    return flat
def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild logprob results from pipeline-parallel proxy tensors.

    Each logprob field of the returned ``LogitsProcessorOutput`` is looked up
    in ``next_pp_outputs`` under the key of the same name; a missing key
    raises whatever ``next_pp_outputs[key]`` raises.

    Returns:
        ``(logits_output, extend_input_len_per_req,
        extend_logprob_start_len_per_req)``.
    """
    logprob_keys = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    logprob_kwargs = {key: next_pp_outputs[key] for key in logprob_keys}
    rebuilt = LogitsProcessorOutput(
        # Logits and hidden states are not sent because they are large.
        next_token_logits=None,
        hidden_states=None,
        **logprob_kwargs,
    )
    input_lens = next_pp_outputs["extend_input_len_per_req"]
    start_lens = next_pp_outputs["extend_logprob_start_len_per_req"]
    return rebuilt, input_lens, start_lens
Xet Storage Details
- Size: 7.32 kB
- Xet hash: 01361be670b91eb3c0be4beade0982115128672f24ce9629f450e4ed248ccf8a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.