import logging
from typing import TYPE_CHECKING, Optional, Union

import torch
import torch.nn as nn
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateBeamOutput,
    GenerationMixin,
)

from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import ConstrainedBeamSearchScorer

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

logger = logging.getLogger(__name__)


def _constrained_beam_search(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
| | r""" |
| | Generates sequences of token ids for models with a language modeling head using **constrained beam search |
| | decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
| | |
| | Parameters: |
| | input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): |
| | The sequence used as a prompt for the generation. |
| | logits_processor (`LogitsProcessorList`): |
| | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
| | used to modify the prediction scores of the language modeling head applied at each generation step. |
| | stopping_criteria (`StoppingCriteriaList`): |
| | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
| | used to tell if the generation loop should stop. |
| | generation_config ([`~generation.GenerationConfig`]): |
| | The generation configuration to be used as parametrization of the decoding method. |
| | synced_gpus (`bool`): |
| | Whether to continue running the while loop until max_length (needed to avoid deadlocking with |
| | `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). |
| | model_kwargs: |
| | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
| | an encoder-decoder model the kwargs should include `encoder_outputs`. |
| | |
| | Return: |
| | [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or |
| | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
| | [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
| | `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if |
| | `model.config.is_encoder_decoder=True`. |
| | """ |
    if generation_config.constraints is not None or generation_config.force_words_ids is not None:
        constrained_wrong_parameter_msg = (
            "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
            "However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
            "mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
        )
        if generation_config.do_sample is True:
            raise ValueError(
                constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=generation_config.do_sample)
            )

    final_constraints = []
    if generation_config.constraints is not None:
        final_constraints = generation_config.constraints

    if generation_config.force_words_ids is not None:

        def typeerror():
            raise ValueError(
                "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
                f"of positive integers, but is {generation_config.force_words_ids}."
            )
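
        # For illustration only, the two accepted shapes (the token ids below are
        # made-up placeholders, not from any real tokenizer):
        #   force_words_ids=[[5, 6], [42]]      -> every hypothesis must contain the
        #                                          phrase [5, 6] and the word [42]
        #   force_words_ids=[[[5, 6], [5, 7]]]  -> every hypothesis must contain one
        #                                          of the alternative forms [5, 6] or [5, 7]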

        if (
            not isinstance(generation_config.force_words_ids, list)
            or len(generation_config.force_words_ids) == 0
        ):
            typeerror()

        for word_ids in generation_config.force_words_ids:
            # validate that `word_ids` is a non-empty list before indexing into it below
            if not isinstance(word_ids, list) or len(word_ids) == 0:
                typeerror()
            if isinstance(word_ids[0], list):
                # disjunctive case: a list of alternative token-id sequences for one word
                if any(not isinstance(token_ids, list) for token_ids in word_ids):
                    typeerror()
                if any(
                    any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                    for token_ids in word_ids
                ):
                    typeerror()

                constraint = DisjunctiveConstraint(word_ids)
            else:
                # phrasal case: a single token-id sequence that must appear verbatim
                if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                    typeerror()

                constraint = PhrasalConstraint(word_ids)
            final_constraints.append(constraint)

    constrained_beam_scorer = ConstrainedBeamSearchScorer(
        constraints=final_constraints,
        batch_size=input_ids.shape[0] // generation_config.num_beams,
        num_beams=generation_config.num_beams,
        device=input_ids.device,
        length_penalty=generation_config.length_penalty,
        do_early_stopping=generation_config.early_stopping,
        num_beam_hyps_to_keep=generation_config.num_return_sequences,
        max_length=generation_config.max_length,
    )

    # init values
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape[:2]
    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if the model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise the score of the first beam of each batch with 0 and the rest with -1e9, so that
    # on the first step only tokens of the first beam are considered and the beams do not all
    # sample the same tokens
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[1]
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # prepare model inputs
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = model(**model_inputs, return_dict=True)

        # synced_gpus: `model_kwargs` must be updated before a finished peer skips the rest of
        # the loop body, so that all ranks stay in sync
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=model.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        # a copy is needed to avoid keeping a hanging ref to outputs.logits, which may be very
        # large for the first iteration (the copy itself is always small)
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
        # (batch_size * num_beams, vocab_size)
        next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        scores_for_all_vocab = next_token_scores.clone()

        # store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if model.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # sample 1 + n_eos_tokens candidates per beam so that each beam keeps at least one
        # non-eos continuation
        n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
        )

        # recover the beam index and the in-vocab token id from the flattened index
        next_indices = (next_tokens / vocab_size).long()
        next_tokens = next_tokens % vocab_size

        # select the next beams, taking each hypothesis's constraint progress into account
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        # explicitly delete outputs: it may hold a large logits tensor, and dropping the
        # reference here keeps it from staying alive into the next iteration
        del outputs

        # reorder the cache so that each beam's past follows the beam it was selected from
        if model_kwargs.get("past_key_values", None) is not None:
            if hasattr(model, "_reorder_cache"):
                model_kwargs["past_key_values"] = model._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            else:
                model_kwargs["past_key_values"].reorder_cache(beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))

        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
            this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if model.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequence_outputs["sequences"]


def generate(model, *args, **kwargs):
    """Custom generate function for constrained beam search decoding.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        num_beams (`int`):
            The number of beams to use for beam search.
        constraints (`list[Constraint]`, *optional*):
            Custom constraints added to the generation to ensure that the output contains certain tokens, as
            defined by `Constraint` objects, in the most sensible way possible.
        force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
            List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple
            list of words that must be included, the opposite of `bad_words_ids`. If given a
            `list[list[list[int]]]`, this triggers a
            [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can
            allow different forms of each word.
        length_penalty (`float`):
            The length penalty to use for beam search.
        early_stopping (`bool`):
            Whether to stop beam search when sufficient beams have finished.
        num_return_sequences (`int`):
            The number of sequences to return.
        max_length (`int`):
            The maximum length of the generated sequence.
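
    Example (a sketch: the checkpoint name, prompt, and forced word below are illustrative
    placeholders; any compatible encoder-decoder or decoder-only checkpoint should work):

    ```python
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

    inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
    # `list[list[int]]`: each inner list is a phrase that must appear in the output
    force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids

    outputs = generate(
        model,
        **inputs,
        force_words_ids=force_words_ids,
        num_beams=4,
        num_return_sequences=1,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    ```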
| | """ |
| | generation_outputs = GenerationMixin.generate( |
| | model, *args, custom_generate=_constrained_beam_search, **kwargs |
| | ) |
| | return generation_outputs |
| |
|