from dataclasses import replace
from typing import TYPE_CHECKING, Optional, Union

import torch
import torch.nn as nn

from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import (
    ALL_CACHE_NAMES,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
    GenerateNonBeamOutput,
    GenerationMixin,
)
from transformers.utils import ModelOutput, logging


if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

# `transformers.utils.logging` is used instead of the stdlib `logging` because it provides
# `warning_once`, which is called in `generate` below.
logger = logging.get_logger(__name__)


def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
    """
    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
    specific ModelOutput subclass from the list provided.
    """
    if not model_outputs:
        raise ValueError("Input list is empty.")

    # Infer the class from the first object in the list
    model_output_cls = type(model_outputs[0])

    # Ensure all objects are of the same type
    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
        raise ValueError("All elements in the list should be of the same type.")

    def _concat(data):
        """
        Concatenate a list of attribute values (tensors, nested tuples of tensors, or scalars) along the batch
        dimension.
        """
        if any(x is None for x in data):
            return None
        if isinstance(data[0], torch.Tensor):
            return torch.cat(data, dim=0)
        elif isinstance(data[0], tuple):
            # Nested tuples, e.g. `past_key_values` in the legacy tuple-of-tuples format
            if isinstance(data[0][0], tuple):
                return tuple(
                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                    for i in range(len(data[0]))
                )
            else:
                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
        elif isinstance(data[0], (int, float)):
            # Scalars (e.g. a loss value) are gathered into a 1D tensor
            return torch.tensor(data)
        else:
            raise TypeError(f"Unexpected attribute type: {type(data[0])}")

    # Gather each dataclass field from all outputs and concatenate it across the batch dimension
    concatenated_data = {
        k: _concat([getattr(model_output, k) for model_output in model_outputs])
        for k in model_output_cls.__dataclass_fields__
    }

    # Return a new object of the inferred class with the concatenated attributes
    return model_output_cls(**concatenated_data)
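

# Illustrative sketch (not part of the module's API): stacking two batch-size-1 outputs back into
# a batch of 2, as done after the `sequential=True` candidate loop below. `CausalLMOutput` and the
# shapes are arbitrary examples; note that the `config` argument is unused by the function body
# above, so `None` suffices here.
#
#     from transformers.modeling_outputs import CausalLMOutput
#     a = CausalLMOutput(logits=torch.randn(1, 5, 100))
#     b = CausalLMOutput(logits=torch.randn(1, 5, 100))
#     stacked = stack_model_outputs([a, b], config=None)
#     assert stacked.logits.shape == (2, 5, 100)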


def _ranking_fast(
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    cosine_matrix_mask: torch.LongTensor,
    alpha: float,
    beam_width: int,
) -> torch.LongTensor:
    """
    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as
    described in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best
    candidate for each row in the batch.
    """
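    # For each candidate v given context x, the contrastive score computed below is
    #     score(v) = (1 - alpha) * p(v | x) - alpha * max_j cos(h_v, h_{x_j})
    # i.e. the model confidence penalized by the maximum cosine similarity between the candidate's
    # hidden state and all context hidden states (the degeneration penalty).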
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)

    # Penalize `cosine_matrix` based on `cosine_matrix_mask`: masked (padding) positions are pushed to the
    # dtype's minimum so they can never win the max below
    cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
    cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
    cosine_matrix = cosine_matrix + cosine_matrix_mask

    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)
    next_top_k_probs = next_top_k_probs.view(-1)
    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))
    _, selected_idx = contrastive_score.max(dim=-1)
    return selected_idx
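

# Shape sketch for `_ranking_fast` (illustrative, hypothetical values): with batch_size=1, top_k=2
# and a 3-token context, the context hidden states are repeated once per candidate.
#
#     ctx = torch.randn(2, 3, 8)                 # (batch_size * top_k, context_len, hidden)
#     nxt = torch.randn(2, 1, 8)                 # (batch_size * top_k, 1, hidden)
#     probs = torch.tensor([[0.7, 0.3]])         # (batch_size, top_k) candidate probabilities
#     mask = torch.ones(2, 3, dtype=torch.long)  # (batch_size * top_k, context_len)
#     best = _ranking_fast(ctx, nxt, probs, mask, alpha=0.6, beam_width=2)  # shape: (1,)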


@torch.no_grad()
def _contrastive_search(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
    be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        model (`PreTrainedModel`):
            The model used for generation.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
        or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    if not model_kwargs["use_cache"]:
        raise ValueError("Contrastive search requires `use_cache=True`")
    if model._is_stateful:
        # Contrastive search needs to roll back the cache by one token, which stateful models cannot do
        raise ValueError(
            f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
        )

    # init values
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    top_k = generation_config.top_k
    penalty_alpha = generation_config.penalty_alpha
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    sequential = generation_config.low_memory

    # init attention / hidden states / scores tuples
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if the model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    # create the cosine_matrix_mask (which context positions count towards the degeneration penalty),
    # based on the relevant attention mask
    cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
    if model.config.is_encoder_decoder:
        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
            cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
    else:
        cosine_matrix_mask = model_kwargs["attention_mask"]
    cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)

    this_peer_finished = False

    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # if this is the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
        # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
        if model_kwargs.get("past_key_values") is None or (
            isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
            and model_kwargs["past_key_values"].get_seq_length() == 0
        ):
            # prepare inputs
            model_kwargs["use_cache"] = True
            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and
            # save the `encoder_outputs`
            outputs = model(
                **model_inputs,
                return_dict=True,
                output_hidden_states=True,
                output_attentions=output_attentions,
            )

            # the last decoder hidden states will be used to compute the degeneration penalty (cosine similarity
            # with previous tokens)
            if model.config.is_encoder_decoder:
                last_hidden_states = outputs.decoder_hidden_states[-1]
            else:
                last_hidden_states = outputs.hidden_states[-1]

            # next logit for contrastive search to select top-k candidate tokens;
            # `torch.float32` is needed to retain precision for later logits manipulations
            logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

            model_kwargs = model._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=model.config.is_encoder_decoder,
            )

            if not sequential:
                # Expand model inputs top_k times, for batched forward passes (akin to beam search)
                _, model_kwargs = model._expand_inputs_for_generation(
                    input_ids=input_ids,
                    expand_size=top_k,
                    is_encoder_decoder=model.config.is_encoder_decoder,
                    **model_kwargs,
                )

            past_key_values = model_kwargs.get("past_key_values")
            if past_key_values is None:
                raise ValueError(
                    f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                    "for contrastive search."
                )
            elif not (
                isinstance(past_key_values, DynamicCache)
                or (
                    isinstance(past_key_values, EncoderDecoderCache)
                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
                )
            ):
                raise ValueError(
                    f"Unsupported cache type: {type(past_key_values)}. Contrastive search requires "
                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
                )

        # contrastive search main logic start: contrastive search decoding consists of two steps:
        # (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty
        processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
        next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1)

        top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_logits:
                raw_logits += (logit_for_next_step,)
            if output_scores:
                scores += (processed_logit_for_next_step,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                )

        # Delete `outputs`, which may hold very large logits from the first iteration; keeping a reference
        # would keep those tensors alive through the next iteration
        del outputs

        if not sequential:
            # Replicate the new past_key_values to match the `top_k` candidates
            model_kwargs["past_key_values"].batch_repeat_interleave(top_k)

        if sequential:
            all_outputs = []
            for i in range(top_k):
                # compute the candidate tokens by the language model and collect their hidden states,
                # one candidate at a time to reduce peak memory
                next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)

                outputs = model(
                    **next_model_inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )
                # Remove the past key-values from the output since we don't need to stack them later
                outputs["past_key_values"] = None
                # Remove the last token from the past key-values since we don't want to append it at this point
                model_kwargs["past_key_values"].crop(-1)

                all_outputs.append(outputs)
            outputs = stack_model_outputs(all_outputs, model.config.get_text_config())

        else:
            # compute the candidate tokens by the language model and collect their hidden states;
            # assembles `top_k_ids` into a batch of size `batch_size * top_k`
            next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

            outputs = model(
                **next_model_inputs,
                return_dict=True,
                output_hidden_states=True,
                output_attentions=output_attentions,
            )

        # Drop the last reference to the big past key-values tensors, to avoid doubling the required memory
        # in the next loop iteration
        del next_model_inputs

        # the attribute name differs between encoder-decoder and decoder-only models
        if model.config.is_encoder_decoder:
            next_hidden = outputs.decoder_hidden_states[-1]
            full_hidden_states = outputs.decoder_hidden_states
        else:
            next_hidden = outputs.hidden_states[-1]
            full_hidden_states = outputs.hidden_states

        # `.float()` is needed to retain precision for later logits manipulations
        logits = outputs.logits[:, -1, :].float()
        context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

        # compute the degeneration penalty and re-rank the candidates based on the penalty and the model
        # confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
        # introduce (noticeable) slowdowns on single-device runs.
        selected_idx = _ranking_fast(
            context_hidden,
            next_hidden,
            top_k_probs,
            cosine_matrix_mask,
            penalty_alpha,
            top_k,
        )
        cosine_matrix_mask = torch.cat(
            [
                cosine_matrix_mask,
                cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1)),
            ],
            dim=-1,
        )
        selected_idx = selected_idx.to("cpu")

        # flattened positions of the selected candidates inside the (batch_size * top_k) dimension
        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])

        # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting the next top-k candidates; (5) the selected tokens'
        # scores (model confidence minus degeneration penalty); (6) decoder hidden_states
        next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
        next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
        next_hidden = next_hidden[range(batch_size), selected_idx, :]
        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

        next_decoder_hidden_states = ()
        for layer in full_hidden_states:
            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
            next_decoder_hidden_states += (layer,)

        # regenerate the past_key_values cache containing only the selected candidate
        if sequential:
            next_model_input = model.prepare_inputs_for_generation(
                top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
            )

            selected_outputs = model(
                **next_model_input,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=False,
            )
            next_past_key_values = selected_outputs["past_key_values"]

        else:
            next_past_key_values = None
            for possible_cache_name in ALL_CACHE_NAMES:
                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
            # select in-place, layer per layer, to save memory
            next_past_key_values.batch_select_indices(augmented_idx)

        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
        logit_for_next_step = logit_for_next_step.to(input_ids.device)

        # Rebuild the relevant parts of the model output for the selected token, for use in the next iteration
        if model.config.is_encoder_decoder:
            next_step_cross_attentions = ()
            next_step_decoder_attentions = ()
            if output_attentions:
                for layer in outputs.cross_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_cross_attentions += (layer,)
                for layer in outputs.decoder_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_decoder_attentions += (layer,)
            outputs = replace(
                outputs,
                past_key_values=next_past_key_values,
                decoder_hidden_states=next_decoder_hidden_states,
                decoder_attentions=next_step_decoder_attentions or None,
                cross_attentions=next_step_cross_attentions or None,
            )
        else:
            next_step_attentions = ()
            if output_attentions:
                for layer in outputs.attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_attentions += (layer,)
            outputs = replace(
                outputs,
                past_key_values=next_past_key_values,
                hidden_states=next_decoder_hidden_states,
                attentions=next_step_attentions or None,
            )

        # contrastive search main logic end

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=model.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for the next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # stop when each sequence is finished
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Contrastive search works by forward-looking at the next token, so the last token has to be excluded
        # from `past_key_values` to stay consistent with the other decoding methods
        if model_kwargs.get("past_key_values") is not None:
            model_kwargs["past_key_values"].crop(-1)

        if model.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids


def generate(model, *args, **kwargs):
    """Custom generate function for contrastive search decoding.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        cache_implementation (`str`, *optional*, defaults to `"dynamic_full"`):
            The cache implementation forwarded to [`~generation.GenerationMixin.generate`].

    All other arguments are forwarded to [`~generation.GenerationMixin.generate`]; in particular,
    `penalty_alpha` (`float`, the weight of the degeneration penalty) and `top_k` (`int`, the number
    of candidates considered at each step) control contrastive search.
    """
    cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
    text_config = model.config.get_text_config()
    uses_sliding_window = (
        "sliding_attention" in (getattr(text_config, "layer_types", None) or [])
        or (getattr(text_config, "sliding_window", None) or 0) > 0
    )
    if cache_implementation != "dynamic_full" and uses_sliding_window:
        logger.warning_once(
            "Contrastive search with sliding window attention requires `cache_implementation='dynamic_full'`. "
            "Using other cache types may break rollback and cause incorrect results."
        )

    generation_outputs = GenerationMixin.generate(
        model,
        *args,
        custom_generate=_contrastive_search,
        cache_implementation=cache_implementation,
        **kwargs,
    )
    return generation_outputs
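

if __name__ == "__main__":
    # Minimal usage sketch: "gpt2" is an arbitrary small checkpoint chosen for illustration, and the
    # hyperparameters follow the contrastive search paper's common setting (penalty_alpha=0.6, top_k=4);
    # any causal LM and values can be substituted.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("DeepMind Company is", return_tensors="pt")
    output_ids = generate(model, **inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))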