# XTC (Exclude Top Choices) sampling support for Hugging Face `generate`.
import torch
import torch.nn as nn
import random
import logging
import copy
from typing import Union, List, Optional
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput
# Module-level logger; not referenced in this chunk, kept for diagnostics.
logger = logging.getLogger(__name__)
class XTCLogitsWarper(LogitsProcessor):
    """
    LogitsWarper implementing Exclude Top Choices (XTC) sampling.

    With chance ``probability`` per decoding step, every sorted token that
    still has a viable alternative below it (next-most-likely token with
    probability >= ``threshold``) is masked to ``filter_value``. The least
    likely qualifying token is always kept, so the distribution stays valid.

    Args:
        threshold: Minimum probability a *less likely* token must have for
            the tokens ranked above it to be excluded.
        probability: Chance in [0, 1] that XTC is applied at a given step.
        protected_token_ids: Token ids that must never be excluded (e.g.
            EOS). If XTC would remove any of them, the step is skipped.
        filter_value: Logit value written over excluded tokens.

    Raises:
        ValueError: If ``probability`` is outside [0, 1] or ``threshold``
            is negative.
    """
    def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
        # Validate eagerly so misconfiguration fails at construction time,
        # not silently mid-generation.
        if not 0.0 <= probability <= 1.0:
            raise ValueError(f"`probability` must be in [0, 1], got {probability}")
        if threshold < 0.0:
            raise ValueError(f"`threshold` must be non-negative, got {threshold}")
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value
        self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Stochastic gate: only apply XTC `probability` of the time.
        if self.probability <= 0.0 or random.random() >= self.probability:
            return scores
        # Sort descending so position i+1 holds the next-most-likely token.
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)
        sorted_indices_to_remove = torch.zeros_like(probs, dtype=torch.bool)
        # XTC rule: drop sorted token i iff the token below it (i+1) still
        # has probability >= threshold. The last position is never dropped.
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        # Safety: if any protected token (EOS etc.) would be excluded for
        # any batch row, abort XTC for this step entirely. Vectorized via
        # index_select instead of a Python loop over tensor columns.
        if self.protected_token_ids:
            protected = torch.tensor(sorted(self.protected_token_ids), device=scores.device, dtype=torch.long)
            if indices_to_remove.index_select(1, protected).any():
                return scores
        # masked_fill is out-of-place: callers receive a new tensor.
        return scores.masked_fill(indices_to_remove, self.filter_value)
def _xtc_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: "BaseStreamer" = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
    """
    Custom multinomial-sampling decoding loop with XTC applied at each step.

    Mirrors the structure of `GenerationMixin._sample`: reads the XTC
    parameters the `generate` wrapper injected onto `generation_config`,
    appends an `XTCLogitsWarper` to `logits_processor` (when enabled), then
    samples token-by-token until `stopping_criteria` marks every sequence
    finished.

    Args:
        model: Model used for the per-step forward passes.
        input_ids: Current token ids, shape (batch, seq_len); one sampled
            column is appended per loop iteration.
        logits_processor: Per-step logits transforms. NOTE: mutated in
            place — the XTC warper is appended when `xtc_probability > 0`.
        stopping_criteria: Callables deciding per-sequence termination.
        generation_config: Generation settings carrying the dynamic
            attributes `xtc_threshold`, `xtc_probability` and
            `xtc_protected_tokens` set by the wrapper.
        synced_gpus: When True, a finished peer keeps running forward
            passes so distributed peers stay lock-stepped.
        streamer: Optional streamer fed each newly sampled token.
        **model_kwargs: Extra forward-pass inputs (attention mask, caches, ...).

    Returns:
        `GenerateDecoderOnlyOutput` when `return_dict_in_generate` is set,
        otherwise the grown `input_ids` tensor.
    """
    # 1. Retrieve XTC params from the config (injected by the generate wrapper).
    xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
    xtc_probability = getattr(generation_config, "xtc_probability", 0.0)
    # Protect EOS ids so XTC can never strip the model's ability to stop.
    protected_ids = []
    if generation_config.eos_token_id is not None:
        # eos_token_id may be a single id or a list of ids.
        if isinstance(generation_config.eos_token_id, list):
            protected_ids.extend(generation_config.eos_token_id)
        else:
            protected_ids.append(generation_config.eos_token_id)
    # Check for custom protected tokens injected via config.
    custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
    if custom_protected:
        protected_ids.extend(custom_protected)
    # 2. Inject XTC into the LogitsProcessorList (only when it can trigger).
    if xtc_probability > 0:
        xtc_warper = XTCLogitsWarper(
            threshold=xtc_threshold,
            probability=xtc_probability,
            protected_token_ids=protected_ids
        )
        logits_processor.append(xtc_warper)
    # 3. Initialization of bookkeeping from the generation config.
    # NOTE(review): relies on the private `_pad_token_tensor` attribute —
    # tied to a specific transformers version; confirm on upgrade.
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    # Ensure sampling is on.
    # NOTE(review): `do_sample` is assigned but never read below — the loop
    # samples unconditionally via torch.multinomial.
    do_sample = True
    # Init output tuples (populated only when the caller asked for them).
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    # Track which batch rows are still generating (1 = unfinished, 0 = done).
    batch_size, cur_length = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
    this_peer_finished = False
    # 4. Decoding loop: one forward pass + one sampled token per iteration.
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Under synced_gpus a finished peer still runs the forward pass
        # above (keeping collectives aligned) but skips sampling below.
        if synced_gpus and this_peer_finished:
            continue
        next_token_logits = outputs.logits[:, -1, :]
        # Apply Logits Processors (XTC happens here).
        next_token_scores = logits_processor(input_ids, next_token_logits)
        # Store per-step tensors when requested.
        if return_dict_in_generate and output_scores:
            scores += (next_token_scores,)
        if return_dict_in_generate and output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
        if return_dict_in_generate and output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))
        # Sample the next token from the processed distribution (multinomial).
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        # Finished rows keep emitting pad_token_id instead of new tokens.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # Append the sampled column and advance cache/bookkeeping.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
        # A row stays unfinished only while no stopping criterion fires on it.
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
    if streamer is not None:
        streamer.end()
    if return_dict_in_generate:
        # NOTE(review): `cross_attentions` is initialized above but never
        # populated nor returned here; decoder-only outputs have none.
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return input_ids
def generate(model, *args, **kwargs):
    """
    Drop-in `generate` wrapper that enables XTC (Exclude Top Choices) sampling.

    Accepts three extra keyword arguments on top of the usual `generate`
    kwargs and routes decoding through `_xtc_decoding`:

    Args:
        model: The model whose generation is being wrapped.
        xtc_probability: Chance per step that XTC fires (default 0.0, disabled).
        xtc_threshold: XTC probability threshold (default 0.1).
        xtc_protected_tokens: Extra token ids XTC must never remove.

    Returns:
        Whatever `GenerationMixin.generate` returns (a sequences tensor, or
        a `GenerateDecoderOnlyOutput` when `return_dict_in_generate` is set).
    """
    # 1. Pop XTC params out of kwargs so HF's "unused model_kwargs"
    # validation does not reject the unknown keys.
    xtc_probability = kwargs.pop("xtc_probability", 0.0)
    xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)
    # 2. Work on a deep copy of the generation config — either the model's
    # default or the caller-supplied one. Copying the caller's config too
    # avoids mutating their object with the dynamic XTC attributes below
    # (the previous version modified a user-passed config in place, and
    # carried an unreachable duplicate None-check).
    generation_config = kwargs.get("generation_config")
    if generation_config is None:
        generation_config = copy.deepcopy(model.generation_config)
    else:
        generation_config = copy.deepcopy(generation_config)
    # Force sampling: XTC is defined for stochastic decoding, not greedy.
    generation_config.do_sample = True
    # 3. Inject XTC params as dynamic attributes for _xtc_decoding to read.
    generation_config.xtc_probability = xtc_probability
    generation_config.xtc_threshold = xtc_threshold
    generation_config.xtc_protected_tokens = xtc_protected_tokens
    kwargs["generation_config"] = generation_config
    # 4. Delegate to standard generation, routing the decode loop to
    # `_xtc_decoding` via the `custom_generate` hook.
    return GenerationMixin.generate(
        model, *args, custom_generate=_xtc_decoding, **kwargs
    )