"""
This file contains the wrapper for the SGL model.
"""
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
@dataclass
class ReplacedLogitsProcessorEagle3Output:
    """
    A dataclass to store the logits and aux hidden states needed for EAGLE3.
    """
    # Next-token logits computed from the pruned hidden states; left as None
    # when the replaced forward is called with return_logits=False.
    logits: Optional[torch.Tensor]
    # Hidden states captured for EAGLE3 (concatenated aux layers when
    # available, otherwise the plain hidden states); None when the capture
    # mode does not require them.
    aux_hidden_states: Optional[torch.Tensor]
    # Pruned last-position hidden states; populated only when the replaced
    # forward is called with return_last_hidden_states=True.
    last_hidden_states: Optional[torch.Tensor] = None
def replaced_logits_processor_forward_for_eagle3(
    self,
    input_ids,
    hidden_states,
    lm_head,
    logits_metadata: Union[LogitsMetadata, ForwardBatch],
    aux_hidden_states: Optional[list[torch.Tensor]] = None,
    return_last_hidden_states: bool = False,
    return_logits: bool = False,
) -> ReplacedLogitsProcessorEagle3Output:
    """
    This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py.
    The modification is to return the logits and aux hidden states instead of the last hidden states.

    Args:
        self: The SGLang ``LogitsProcessor`` whose ``_get_logits`` /
            ``compute_logprobs_for_multi_item_scoring`` methods are reused.
        input_ids: Token ids of the batch (forwarded to multi-item scoring).
        hidden_states: Hidden states of the last layer for this batch.
        lm_head: The language-model head used to project hidden states to logits.
        logits_metadata: A ``LogitsMetadata`` (or a ``ForwardBatch``, which is
            converted) describing the forward pass.
        aux_hidden_states: Optional list of per-layer aux hidden states for
            EAGLE3 (concatenated along the last dim when captured).
        return_last_hidden_states: If True, include the pruned hidden states in
            the output.
        return_logits: If True, compute and include the logits in the output.

    Returns:
        A ``ReplacedLogitsProcessorEagle3Output`` with the requested logits,
        aux hidden states, and last hidden states.

    Raises:
        RuntimeError: For forward modes other than decode/idle, target-verify,
            or draft-extend-v2.
        AssertionError: If ``extend_return_logprob`` is set, or the capture
            mode is unrecognized.
    """
    if isinstance(logits_metadata, ForwardBatch):
        logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

    # Check if multi-item scoring is enabled via server args (only for
    # prefill-only requests); in that case defer to the original helper.
    multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
    if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
        return self.compute_logprobs_for_multi_item_scoring(
            input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
        )

    # Get the last hidden states and last logits for the next token prediction.
    # In these modes the hidden states already contain only the last tokens,
    # so no pruning indices are needed.
    if (
        logits_metadata.forward_mode.is_decode_or_idle()
        or logits_metadata.forward_mode.is_target_verify()
        or logits_metadata.forward_mode.is_draft_extend_v2()
    ):
        pruned_states = hidden_states
        if aux_hidden_states is not None:
            aux_pruned_states = list(aux_hidden_states)
        sample_indices = None
        input_logprob_indices = None
    else:
        raise RuntimeError(
            f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}"
        )

    last_hidden_states = pruned_states if return_last_hidden_states else None

    if return_logits:
        # Compute logits for both input and sampled tokens.
        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
    else:
        logits = None

    # Get the aux hidden states required by EAGLE3 draft training.
    hidden_states_to_store: Optional[torch.Tensor] = None
    if logits_metadata.capture_hidden_mode.need_capture():
        if logits_metadata.capture_hidden_mode.is_full():
            if aux_hidden_states is not None:
                aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
                hidden_states_to_store = aux_hidden_states
            else:
                hidden_states_to_store = hidden_states
        elif logits_metadata.capture_hidden_mode.is_last():
            # Get the last token hidden states. If sample_indices is None,
            # pruned states only contain the last tokens already.
            if aux_hidden_states is not None:
                aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
                hidden_states_to_store = (
                    aux_pruned_states[sample_indices]
                    if sample_indices is not None
                    else aux_pruned_states
                )
            else:
                hidden_states_to_store = (
                    pruned_states[sample_indices]
                    if sample_indices is not None
                    else pruned_states
                )
        else:
            # Explicit raise (instead of `assert False`) so the check survives
            # running under `python -O`.
            raise AssertionError("Should never reach")

    if logits_metadata.extend_return_logprob:
        raise AssertionError("extend_return_logprob is not supported")

    # Decode mode or extend mode without return_logprob.
    return ReplacedLogitsProcessorEagle3Output(
        logits=logits,
        aux_hidden_states=hidden_states_to_store,
        last_hidden_states=last_hidden_states,
    )
class LogitsProcessorForEAGLE3(torch.nn.Module):
    """
    Wraps an SGLang ``LogitsProcessor`` so that its forward pass goes through
    ``replaced_logits_processor_forward_for_eagle3`` and returns the logits and
    aux hidden states needed for EAGLE3.
    """

    def __init__(
        self,
        logits_processor: LogitsProcessor,
        return_last_hidden_states: bool = False,
        return_logits: bool = False,
    ):
        super().__init__()
        # The wrapped original processor; its helper methods are reused by the
        # replaced forward function.
        self.logits_processor = logits_processor
        self.return_last_hidden_states = return_last_hidden_states
        self.return_logits = return_logits

    def forward(
        self,
        input_ids,
        hidden_states,
        lm_head,
        logits_metadata,
        aux_hidden_states: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorOutput:
        # Force DECODE so the replaced forward takes its supported branch.
        logits_metadata.forward_mode = ForwardMode.DECODE
        return replaced_logits_processor_forward_for_eagle3(
            self.logits_processor,
            input_ids,
            hidden_states,
            lm_head,
            logits_metadata,
            aux_hidden_states=aux_hidden_states,
            return_last_hidden_states=self.return_last_hidden_states,
            return_logits=self.return_logits,
        )
def wrap_eagle3_logits_processors_in_module(
    module: nn.Module, return_full_logits: bool = False
):
    """
    This function will wrap the SGLang's original logits processor with the modified one for EAGLE3.

    Args:
        module: The root module whose ``LogitsProcessor`` submodules are
            replaced in place with ``LogitsProcessorForEAGLE3`` wrappers.
        return_full_logits: Forwarded to the wrapper's second constructor
            argument.
            NOTE(review): positionally this lands on ``return_last_hidden_states``,
            not ``return_logits`` — confirm this mapping is intended.
    """
    # Materialize matches first: replacing submodules while iterating the
    # ``named_modules`` generator could skip or revisit entries.
    targets = [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, LogitsProcessor)
    ]
    for name, submodule in targets:
        if not name:
            # ``module`` itself is the processor; it cannot be replaced
            # in place from here.
            continue
        wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
        # ``named_modules`` yields dotted paths (e.g. "model.logits_processor").
        # A plain ``setattr(module, name, ...)`` would create a literal
        # attribute named "model.logits_processor" on the root module instead
        # of replacing the child, so resolve the parent module first and set
        # the leaf attribute on it.
        parent_path, _, attr_name = name.rpartition(".")
        parent = module.get_submodule(parent_path) if parent_path else module
        setattr(parent, attr_name, wrapped)
        print(f"wrapped {name} with LogitsProcessorForEAGLE3")
|