import copy
import logging
import random
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import (
    GenerationConfig,
    LogitsProcessor,
    LogitsProcessorList,
    StoppingCriteriaList,
)
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerationMixin

logger = logging.getLogger(__name__)


class XTCLogitsWarper(LogitsProcessor):
    """
    LogitsWarper that implements Exclude Top Choices (XTC).

    With probability `probability` at each step, every token whose next,
    less-probable neighbor still has a probability of at least `threshold`
    is masked out, so sampling falls through to the lower-ranked choices.
    Tokens in `protected_token_ids` (e.g. EOS) are never removed; if the
    mask would hit one of them, XTC is skipped for that step.
    """

    def __init__(
        self,
        threshold: float,
        probability: float,
        protected_token_ids: Optional[List[int]] = None,
        filter_value: float = -float("Inf"),
    ):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value
        self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # XTC only fires with the configured probability.
        if self.probability <= 0.0 or random.random() >= self.probability:
            return scores

        # Sort scores descending
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Create a mask for removal
        sorted_indices_to_remove = torch.zeros_like(probs, dtype=torch.bool)

        # XTC logic: drop a token if the next, less probable token is still
        # above the threshold; the least probable qualifying token survives.
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

        # Scatter the mask back to the original (unsorted) vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Safety: Check if protected tokens would be removed
        if self.protected_token_ids:
            for pid in self.protected_token_ids:
                if indices_to_remove[:, pid].any():
                    # If any protected token is targeted, abort XTC for this step
                    return scores

        # Apply the filter
        return scores.masked_fill(indices_to_remove, self.filter_value)
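
# Worked example (a minimal sketch, not used by the sampler itself): with
# threshold=0.2 and probability=1.0, softmax probabilities [0.5, 0.3, 0.15, 0.05]
# lose only the top token, because its runner-up (0.3) is still above the
# threshold while the next one (0.15) is not.
def _demo_xtc_masking() -> None:
    warper = XTCLogitsWarper(threshold=0.2, probability=1.0)
    logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
    filtered = warper(input_ids=torch.empty((1, 0), dtype=torch.long), scores=logits)
    print(filtered)  # token 0 is masked to -inf; tokens 1-3 keep their logits
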
def _xtc_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional[BaseStreamer] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
    """
    Custom decoding loop that ensures XTC is applied during sampling.
    """
    # 1. Retrieve XTC params from the config (injected by the `generate` wrapper)
    xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
    xtc_probability = getattr(generation_config, "xtc_probability", 0.0)

    # Identify tokens to protect (EOS plus any user-supplied ids)
    protected_ids = []
    if generation_config.eos_token_id is not None:
        if isinstance(generation_config.eos_token_id, list):
            protected_ids.extend(generation_config.eos_token_id)
        else:
            protected_ids.append(generation_config.eos_token_id)

    # Check for custom protected tokens injected via config
    custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
    if custom_protected:
        protected_ids.extend(custom_protected)

    # 2. Inject XTC into the LogitsProcessorList
    if xtc_probability > 0:
        xtc_warper = XTCLogitsWarper(
            threshold=xtc_threshold,
            probability=xtc_probability,
            protected_token_ids=protected_ids,
        )
        logits_processor.append(xtc_warper)

    # 3. Initialization
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

    # Init output tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Track finished sequences
    batch_size, cur_length = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

    this_peer_finished = False

    # 4. Decoding loop. This loop always samples (multinomial), so XTC is
    # guaranteed to apply; no separate do_sample flag is consulted here.
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # Apply logits processors (XTC happens here)
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store per-step outputs when requested
        if return_dict_in_generate and output_scores:
            scores += (next_token_scores,)
        if return_dict_in_generate and output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
            )
        if return_dict_in_generate and output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
            )

        # Sample (multinomial)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # EOS check: finished sequences keep emitting the pad token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Update inputs
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return input_ids
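
# Alternative wiring (a sketch, assuming `model`/`tokenizer` are a loaded causal
# LM pair whose tokenizer defines eos_token_id): the warper also works with the
# stock sampling loop, because `model.generate` merges a user-supplied
# `logits_processor` list into its own. Useful if you do not need the custom
# loop above.
def generate_with_stock_loop(model, tokenizer, prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    processors = LogitsProcessorList(
        [XTCLogitsWarper(threshold=0.1, probability=0.5, protected_token_ids=[tokenizer.eos_token_id])]
    )
    output_ids = model.generate(
        **inputs, do_sample=True, max_new_tokens=64, logits_processor=processors
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
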
def generate(model, *args, **kwargs):
    """
    Wrapper function that prepares parameters and calls the internal decoding loop.
    """
    # 1. Extract XTC parameters from kwargs using .pop().
    # This prevents the "unused model_kwargs" error because they are removed from kwargs.
    xtc_probability = kwargs.pop("xtc_probability", 0.0)
    xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)

    # 2. Prepare the GenerationConfig. Copy whichever config we start from, so
    # neither the model's default nor a caller-supplied config is mutated.
    generation_config = kwargs.get("generation_config", None)
    if generation_config is None:
        # No config passed: copy the model's default
        generation_config = copy.deepcopy(model.generation_config)
    else:
        generation_config = copy.deepcopy(generation_config)

    # Force sampling (XTC doesn't work with greedy decoding)
    generation_config.do_sample = True

    # 3. Inject XTC params into the config object. GenerationConfig allows
    # dynamic attribute assignment, and _xtc_decoding reads them back via getattr.
    generation_config.xtc_probability = xtc_probability
    generation_config.xtc_threshold = xtc_threshold
    generation_config.xtc_protected_tokens = xtc_protected_tokens

    # Update kwargs with the modified config
    kwargs["generation_config"] = generation_config

    # 4. Call standard generation with `custom_generate=_xtc_decoding`, which
    # replaces the default decoding loop (this requires a transformers version
    # whose `generate` accepts a callable for `custom_generate`).
    return GenerationMixin.generate(
        model,
        *args,
        custom_generate=_xtc_decoding,
        **kwargs,
    )
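
# End-to-end usage (a sketch; "gpt2" is just a small placeholder checkpoint, and
# the callable-`custom_generate` path needs a recent transformers release):
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    enc = tok("Once upon a time", return_tensors="pt")
    out = generate(lm, **enc, max_new_tokens=40, xtc_probability=0.5, xtc_threshold=0.1)
    print(tok.decode(out[0], skip_special_tokens=True))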