import torch
import torch.nn as nn
import random
import logging
import copy
from typing import Union, List, Optional

from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput

logger = logging.getLogger(__name__)

class XTCLogitsWarper(LogitsProcessor):
    """
    LogitsProcessor that implements Exclude Top Choices (XTC) sampling.

    Whenever at least two tokens have a probability of `threshold` or higher,
    all of them except the least probable such token are excluded with the
    given `probability`, steering sampling toward less obvious continuations.
    """

    def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value
        self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # XTC fires stochastically, on roughly `probability` of the decoding steps.
        if self.probability <= 0.0 or random.random() >= self.probability:
            return scores

        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        sorted_indices_to_remove = torch.zeros_like(probs, dtype=torch.bool)

        # Mark a token for removal if the next, less probable token still meets
        # the threshold: of all tokens above the threshold, only the least
        # probable one survives.
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # If a protected token (e.g. EOS) would be excluded, skip XTC for this
        # step rather than suppress the token.
        if self.protected_token_ids:
            for pid in self.protected_token_ids:
                if indices_to_remove[:, pid].any():
                    return scores

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
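

# Illustrative sketch (not part of the original module): how the warper acts on
# a toy distribution. With threshold=0.2 the two most likely tokens both meet
# the threshold, so XTC removes the more likely one. The tensor values are made up.
def _demo_xtc_warper():
    warper = XTCLogitsWarper(threshold=0.2, probability=1.0)
    logits = torch.tensor([[3.0, 2.5, 0.1, -1.0]])  # softmax ~ [0.60, 0.36, 0.03, 0.01]
    warped = warper(input_ids=None, scores=logits)
    # The former top token is masked to -inf; token 1 is now the most likely.
    assert warped[0, 0] == -float("Inf")
    assert torch.argmax(warped).item() == 1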


def _xtc_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
    """
    Custom decoding loop that ensures the XTC warper is applied during sampling.
    """
    xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
    xtc_probability = getattr(generation_config, "xtc_probability", 0.0)

    # Always protect the EOS token(s) so that XTC cannot exclude them.
    protected_ids = []
    if generation_config.eos_token_id is not None:
        if isinstance(generation_config.eos_token_id, list):
            protected_ids.extend(generation_config.eos_token_id)
        else:
            protected_ids.append(generation_config.eos_token_id)

    custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
    if custom_protected:
        protected_ids.extend(custom_protected)

    if xtc_probability > 0:
        xtc_warper = XTCLogitsWarper(
            threshold=xtc_threshold,
            probability=xtc_probability,
            protected_token_ids=protected_ids,
        )
        logits_processor.append(xtc_warper)

    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    # XTC only modifies the sampled distribution, so decoding always samples.
    do_sample = True

    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    batch_size, cur_length = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

    this_peer_finished = False

    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Under synced GPUs, a finished peer still runs the forward pass but
        # skips the rest of the loop body.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # All processors run here, including the XTC warper appended above.
        next_token_scores = logits_processor(input_ids, next_token_logits)

        if return_dict_in_generate and output_scores:
            scores += (next_token_scores,)
        if return_dict_in_generate and output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
        if return_dict_in_generate and output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))

        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished sequences keep emitting the pad token.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if streamer is not None:
            streamer.put(next_tokens.cpu())

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return input_ids


def generate(model, *args, **kwargs):
    """
    Wrapper that pops the XTC parameters, records them on the generation
    config, and dispatches generation to the custom decoding loop.
    """
    xtc_probability = kwargs.pop("xtc_probability", 0.0)
    xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)

    generation_config = kwargs.get("generation_config", None)
    if generation_config is None:
        generation_config = copy.deepcopy(model.generation_config)
    else:
        # Copy the caller's config so the XTC attributes do not leak back into it.
        generation_config = copy.deepcopy(generation_config)

    # XTC is a sampling technique, so sampling must be enabled.
    generation_config.do_sample = True

    generation_config.xtc_probability = xtc_probability
    generation_config.xtc_threshold = xtc_threshold
    generation_config.xtc_protected_tokens = xtc_protected_tokens

    kwargs["generation_config"] = generation_config

    return GenerationMixin.generate(
        model, *args, custom_generate=_xtc_decoding, **kwargs
    )
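

# Usage sketch (not part of the original module). "gpt2" is a stand-in
# checkpoint; any decoder-only causal LM should work, provided the installed
# transformers version accepts a callable for `custom_generate`.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    output_ids = generate(
        lm,
        **inputs,
        max_new_tokens=40,
        xtc_threshold=0.1,    # exclude top tokens while the runner-up prob is >= 0.1
        xtc_probability=0.5,  # apply XTC on about half of the steps
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))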