| """ |
Wrappers around SGLang's logits processor that also return the auxiliary hidden states needed for EAGLE3.
| """ |
|
|
from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from sglang.srt.layers.logits_processor import (
    LogitsMetadata,
    LogitsProcessor,
    LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
|
|
|
|
@dataclass
class ReplacedLogitsProcessorEagle3Output:
    """
    Result of ``replaced_logits_processor_forward_for_eagle3``.

    Attributes:
        logits: Logits from the LM head, or ``None`` when logits were not
            requested (``return_logits=False``).
        aux_hidden_states: Captured hidden states, or ``None`` when the batch's
            capture-hidden mode does not request capture. When auxiliary hidden
            states were supplied, this is their concatenation along the last
            dimension; otherwise it is the final-layer hidden states.
        last_hidden_states: Final-layer hidden states, or ``None`` unless
            ``return_last_hidden_states=True``.
    """

    # Both fields can legitimately be None (see the producer), hence Optional.
    logits: Optional[torch.Tensor]
    aux_hidden_states: Optional[torch.Tensor]
    last_hidden_states: Optional[torch.Tensor] = None
|
|
|
|
def replaced_logits_processor_forward_for_eagle3(
    self,
    input_ids,
    hidden_states,
    lm_head,
    logits_metadata: Union[LogitsMetadata, ForwardBatch],
    aux_hidden_states: Optional[Sequence[torch.Tensor]] = None,
    return_last_hidden_states: bool = False,
    return_logits: bool = False,
) -> ReplacedLogitsProcessorEagle3Output:
    """
    Replacement for SGLang's ``LogitsProcessor.forward`` that exposes EAGLE3 aux hidden states.

    Adapted from
    https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py.
    Instead of SGLang's ``LogitsProcessorOutput``, it returns a
    ``ReplacedLogitsProcessorEagle3Output`` holding:

    - the (optional) logits;
    - the captured hidden states, i.e. the concatenated aux hidden states when
      they are given;
    - optionally the final-layer hidden states.

    Only decode/idle, target-verify and draft-extend-v2 forward modes are
    supported. In those modes no positions are pruned, so every row of
    ``hidden_states`` is kept.

    Args:
        self: The SGLang ``LogitsProcessor`` instance. This is a free function
            called with the processor as its first argument; its
            ``_get_logits`` and ``compute_logprobs_for_multi_item_scoring``
            methods are used.
        input_ids: Token ids. Only forwarded to multi-item scoring.
        hidden_states: Final-layer hidden states.
        lm_head: LM head passed through to ``self._get_logits`` and multi-item
            scoring.
        logits_metadata: A ``LogitsMetadata``, or a ``ForwardBatch`` that is
            converted with ``LogitsMetadata.from_forward_batch``.
        aux_hidden_states: Sequence of auxiliary hidden-state tensors. They are
            concatenated along the last dimension with ``torch.cat`` when
            hidden-state capture is requested.
        return_last_hidden_states: If True, also return ``hidden_states``.
        return_logits: If True, compute logits via ``self._get_logits``.

    Returns:
        A ``ReplacedLogitsProcessorEagle3Output``. When a multi-item scoring
        delimiter is configured and the batch is prefill-only, the result of
        ``self.compute_logprobs_for_multi_item_scoring`` is returned instead.

    Raises:
        RuntimeError: For unsupported forward modes.
        AssertionError: If ``extend_return_logprob`` is requested, or if
            capture is requested with a mode that is neither full nor last.
    """
    if isinstance(logits_metadata, ForwardBatch):
        logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

    # Multi-item scoring on prefill-only batches is delegated unchanged to the
    # original processor implementation.
    multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
    if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
        return self.compute_logprobs_for_multi_item_scoring(
            input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
        )

    forward_mode = logits_metadata.forward_mode
    if not (
        forward_mode.is_decode_or_idle()
        or forward_mode.is_target_verify()
        or forward_mode.is_draft_extend_v2()
    ):
        raise RuntimeError(
            f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}"
        )

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # It runs before any tensor work so an unsupported request fails fast.
    if logits_metadata.extend_return_logprob:
        raise AssertionError("extend_return_logprob is not supported")

    # In the supported modes the upstream code keeps every position (its
    # pruned states equal ``hidden_states`` and there are no sample indices),
    # so the unpruned tensors are used directly below.
    aux_list = list(aux_hidden_states) if aux_hidden_states is not None else None

    last_hidden_states = hidden_states if return_last_hidden_states else None
    logits = (
        self._get_logits(hidden_states, lm_head, logits_metadata)
        if return_logits
        else None
    )

    hidden_states_to_store: Optional[torch.Tensor] = None
    capture_mode = logits_metadata.capture_hidden_mode
    if capture_mode.need_capture():
        if not (capture_mode.is_full() or capture_mode.is_last()):
            raise AssertionError("Should never reach")
        # Without pruning, "full" and "last" capture store the same tensor:
        # the aux hidden states concatenated along the last dim when given,
        # otherwise the final-layer hidden states.
        hidden_states_to_store = (
            torch.cat(aux_list, dim=-1) if aux_list is not None else hidden_states
        )

    return ReplacedLogitsProcessorEagle3Output(
        logits=logits,
        aux_hidden_states=hidden_states_to_store,
        last_hidden_states=last_hidden_states,
    )
|
|
|
|
class LogitsProcessorForEAGLE3(torch.nn.Module):
    """
    ``nn.Module`` that routes an SGLang ``LogitsProcessor`` through
    ``replaced_logits_processor_forward_for_eagle3``.

    The original processor is kept as the child ``self.logits_processor``. It
    is passed as the ``self`` argument of the replacement forward, so its own
    helper methods are reused.
    """

    def __init__(
        self,
        logits_processor: LogitsProcessor,
        return_last_hidden_states: bool = False,
        return_logits: bool = False,
    ):
        """
        Args:
            logits_processor: The SGLang processor to wrap.
            return_last_hidden_states: Forwarded on every call; also return the
                final-layer hidden states.
            return_logits: Forwarded on every call; also compute logits.
        """
        super().__init__()
        self.logits_processor = logits_processor
        self.return_last_hidden_states = return_last_hidden_states
        self.return_logits = return_logits

    def forward(
        self,
        input_ids,
        hidden_states,
        lm_head,
        logits_metadata,
        aux_hidden_states: Optional[Sequence[torch.Tensor]] = None,
    ) -> ReplacedLogitsProcessorEagle3Output:
        """
        Run the EAGLE3 replacement forward with ``forward_mode`` forced to DECODE.

        Args:
            input_ids: Token ids.
            hidden_states: Final-layer hidden states.
            lm_head: LM head passed to the wrapped processor.
            logits_metadata: A ``LogitsMetadata`` or a ``ForwardBatch``. Its
                ``forward_mode`` is overwritten in place.
            aux_hidden_states: Sequence of auxiliary hidden-state tensors.

        Returns:
            The ``ReplacedLogitsProcessorEagle3Output`` from the replacement
            forward.
        """
        # Force the decode path. The replacement forward only accepts
        # decode-like modes, and on that path every position is kept. This is
        # set before the replacement forward converts a ForwardBatch into
        # LogitsMetadata, so that conversion already sees DECODE.
        # NOTE(review): this mutates the caller's object in place and is never
        # restored. Confirm downstream code does not rely on the original
        # forward mode.
        logits_metadata.forward_mode = ForwardMode.DECODE
        return replaced_logits_processor_forward_for_eagle3(
            self.logits_processor,
            input_ids,
            hidden_states,
            lm_head,
            logits_metadata,
            aux_hidden_states,
            return_last_hidden_states=self.return_last_hidden_states,
            return_logits=self.return_logits,
        )
|
|
|
|
def wrap_eagle3_logits_processors_in_module(
    module: nn.Module, return_full_logits: bool = False
):
    """
    Replace every SGLang ``LogitsProcessor`` inside ``module`` with a
    ``LogitsProcessorForEAGLE3`` wrapper, in place.

    Behavior:

    - Nested processors are handled: each dotted name from ``named_modules()``
      is resolved to its parent module, and the attribute is replaced there.
    - Processors that are already the child of a ``LogitsProcessorForEAGLE3``
      are skipped, so repeated calls do not double-wrap.
    - If ``module`` itself is a ``LogitsProcessor``, it is skipped because it
      cannot be replaced from inside itself.

    Args:
        module: Root module to patch in place.
        return_full_logits: Passed to the wrapper as
            ``return_last_hidden_states``.
            NOTE(review): despite the name, this does not set the wrapper's
            ``return_logits``. The original positional call mapped it the same
            way; confirm intent.
    """
    targets = []
    for name, submodule in module.named_modules():
        if not isinstance(submodule, LogitsProcessor):
            continue
        if not name:
            print("skipped wrapping the root module: it is itself a LogitsProcessor")
            continue
        parent_name, _, child_name = name.rpartition(".")
        parent = module.get_submodule(parent_name) if parent_name else module
        if isinstance(parent, LogitsProcessorForEAGLE3):
            # Already wrapped by an earlier call.
            continue
        targets.append((name, parent, child_name, submodule))

    # Replace only after iteration completes, so the module tree is not
    # modified underneath the named_modules() generator.
    for name, parent, child_name, submodule in targets:
        wrapped = LogitsProcessorForEAGLE3(
            submodule, return_last_hidden_states=return_full_logits
        )
        setattr(parent, child_name, wrapped)
        print(f"wrapped {name} with LogitsProcessorForEAGLE3")
|
|