# xtc/custom_generate/generate.py
import torch
import torch.nn as nn
import random
import logging
import copy
from typing import Union, List, Optional
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput
logger = logging.getLogger(__name__)
class XTCLogitsWarper(LogitsProcessor):
"""
LogitsWarper that implements Exclude Top Choices (XTC).
"""
def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
self.threshold = threshold
self.probability = probability
self.filter_value = filter_value
self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.probability <= 0.0 or random.random() >= self.probability:
return scores
# Sort scores descending
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
        # Boolean mask over the sorted order: True = remove this token
        sorted_indices_to_remove = torch.zeros_like(probs, dtype=torch.bool)
        # XTC: drop a sorted token iff the *next* (less probable) token also
        # clears the threshold, so the least probable viable token survives
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
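        # Worked example (threshold=0.1, sorted probs [0.60, 0.25, 0.10, 0.05]):
        #   probs[..., 1:] >= 0.1  ->  [True, True, False]
        #   so 0.60 and 0.25 are removed; 0.10 and 0.05 remain samplable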
# Scatter back to original indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# Safety: Check if protected tokens would be removed
if self.protected_token_ids:
for pid in self.protected_token_ids:
if indices_to_remove[:, pid].any():
# If any protected token is targeted, abort XTC for this step
return scores
# Apply the filter
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
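
# Minimal standalone sketch of the warper (illustrative toy values, not part
# of the decoding pipeline). With probability=1.0 XTC fires on every call.
# Here two tokens clear the 0.1 threshold (probs ~0.66 and ~0.25), so the more
# probable one is masked while the least probable viable token survives:
#
#   warper = XTCLogitsWarper(threshold=0.1, probability=1.0)
#   logits = torch.tensor([[3.0, 2.0, 1.0, -4.0]])
#   out = warper(torch.empty(1, 0, dtype=torch.long), logits)
#   # out[0, 0] == -inf; the other three logits are unchanged
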
def _xtc_decoding(
model,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool = False,
streamer: "BaseStreamer" = None,
**model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
"""
Custom decoding loop that ensures XTC is applied during sampling.
"""
# 1. Retrieve XTC params from the config (injected by the generate wrapper)
xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
xtc_probability = getattr(generation_config, "xtc_probability", 0.0)
# Identify tokens to protect
protected_ids = []
if generation_config.eos_token_id is not None:
if isinstance(generation_config.eos_token_id, list):
protected_ids.extend(generation_config.eos_token_id)
else:
protected_ids.append(generation_config.eos_token_id)
# Check for custom protected tokens injected via config
custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
if custom_protected:
protected_ids.extend(custom_protected)
# 2. Inject XTC into the LogitsProcessorList
if xtc_probability > 0:
xtc_warper = XTCLogitsWarper(
threshold=xtc_threshold,
probability=xtc_probability,
protected_token_ids=protected_ids
)
logits_processor.append(xtc_warper)
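        # Appending places XTC after any processors already in the list
        # (temperature, top-p, ...), so it operates on their filtered scores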
# 3. Initialization
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    # Sampling is unconditional in this loop (tokens are drawn via multinomial)
# Init output tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# Track finished sequences
batch_size, cur_length = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
this_peer_finished = False
# 4. Decoding Loop
while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if synced_gpus and this_peer_finished:
continue
next_token_logits = outputs.logits[:, -1, :]
# Apply Logits Processors (XTC happens here)
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores
if return_dict_in_generate and output_scores:
scores += (next_token_scores,)
if return_dict_in_generate and output_attentions:
decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
if return_dict_in_generate and output_hidden_states:
decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))
# Sample (Multinomial)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Finished sequences emit pad_token_id instead of the sampled token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# Update inputs
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
if streamer is not None:
streamer.end()
if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
def generate(model, *args, **kwargs):
"""
Wrapper function that prepares parameters and calls the internal decoding loop.
"""
    # 1. Extract XTC parameters via .pop() so they are removed from kwargs;
    # otherwise generate() would reject them as unused `model_kwargs`
xtc_probability = kwargs.pop("xtc_probability", 0.0)
xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)
    # 2. Prepare GenerationConfig
    # Deep-copy in both cases so that the XTC attributes injected below never
    # mutate the caller's config (or the model's default) in place
    generation_config = kwargs.get("generation_config", None)
    if generation_config is None:
        generation_config = copy.deepcopy(model.generation_config)
    else:
        generation_config = copy.deepcopy(generation_config)
# Force sampling (XTC doesn't work with greedy)
generation_config.do_sample = True
# 3. Inject XTC params into the config object
# Python allows dynamic attribute assignment
generation_config.xtc_probability = xtc_probability
generation_config.xtc_threshold = xtc_threshold
generation_config.xtc_protected_tokens = xtc_protected_tokens
# Update kwargs with the modified config
kwargs["generation_config"] = generation_config
    # 4. Hand off to the standard generate(), passing _xtc_decoding as the
    # custom decoding loop to execute
return GenerationMixin.generate(
model, *args, custom_generate=_xtc_decoding, **kwargs
)
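

# Usage sketch (assumptions flagged): an end-to-end run of the wrapper above.
# The checkpoint name is a placeholder; substitute any decoder-only model.
# This relies on a transformers version whose generate() accepts a callable
# `custom_generate`, which is what the wrapper passes through.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    # xtc_probability / xtc_threshold are pop()ed by generate() above, so they
    # never reach the model as unused kwargs
    output_ids = generate(
        model,
        **inputs,
        max_new_tokens=50,
        xtc_probability=0.5,
        xtc_threshold=0.1,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))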