<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Utilities for Generation

This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].

Most of those are only useful if you are studying the code of the generate methods in the library.

## Generate Outputs

The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
[`~utils.ModelOutput`]. This output is a data structure containing all the information returned
by [`~generation.GenerationMixin.generate`], but it can also be used as a tuple or a dictionary.

Here's an example:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
# `return_dict_in_generate=True` returns a `ModelOutput` subclass instead of a plain tensor,
# and `output_scores=True` additionally records the scores at each generation step
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```

The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`]. As we can
see in the documentation of that class below, this means it has the following attributes:

- `sequences`: the generated sequences of tokens
- `scores` (optional): the prediction scores of the language modeling head, for each generation step
- `hidden_states` (optional): the hidden states of the model, for each generation step
- `attentions` (optional): the attention weights of the model, for each generation step

Here we have the `scores` since we passed along `output_scores=True`, but we don't have `hidden_states` and
`attentions` because we didn't pass `output_hidden_states=True` or `output_attentions=True`.

You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get `None`. Here, for instance, `generation_output.scores` contains all the prediction scores of the
language modeling head, and `generation_output.attentions` is `None`.
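
Continuing the example above, here is a quick sketch of that attribute access (the exact sequence length depends on the generation settings, so the shape shown in the comment is only indicative):

```python
# `sequences` always holds the generated token ids
print(generation_output.sequences.shape)  # e.g. torch.Size([1, 20])

# `scores` is a tuple with one entry per generated token, since we passed `output_scores=True`
print(len(generation_output.scores))
print(generation_output.scores[0].shape)  # (batch_size, vocab_size)

# attributes that were not requested come back as `None`
print(generation_output.attentions)  # None
```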

When using our `generation_output` object as a tuple, it only keeps the attributes that don't have `None` values.
Here, for instance, it has two elements, `sequences` then `scores`, so

```python
generation_output[:2]
```

will return the tuple `(generation_output.sequences, generation_output.scores)`.

When using our `generation_output` object as a dictionary, it only keeps the attributes that don't have `None`
values. Here, for instance, it has two keys that are `sequences` and `scores`.
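
As a small sketch of that dictionary interface, still using the `generation_output` object from the example above:

```python
# only the non-`None` attributes show up as keys
print(list(generation_output.keys()))  # ['sequences', 'scores']

# string indexing returns the same object as attribute access
print(generation_output["sequences"] is generation_output.sequences)  # True
```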

We document here all output types.

### GreedySearchOutput

[[autodoc]] generation.GreedySearchDecoderOnlyOutput

[[autodoc]] generation.GreedySearchEncoderDecoderOutput

[[autodoc]] generation.FlaxGreedySearchOutput

### SampleOutput

[[autodoc]] generation.SampleDecoderOnlyOutput

[[autodoc]] generation.SampleEncoderDecoderOutput

[[autodoc]] generation.FlaxSampleOutput

### BeamSearchOutput

[[autodoc]] generation.BeamSearchDecoderOnlyOutput

[[autodoc]] generation.BeamSearchEncoderDecoderOutput

### BeamSampleOutput

[[autodoc]] generation.BeamSampleDecoderOnlyOutput

[[autodoc]] generation.BeamSampleEncoderDecoderOutput

## LogitsProcessor

A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
generation.
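
As an illustration, here is a minimal sketch of a custom processor; `BlockTokenLogitsProcessor` is a hypothetical name written for this example, not part of the library:

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, LogitsProcessor, LogitsProcessorList


class BlockTokenLogitsProcessor(LogitsProcessor):
    """Hypothetical processor that masks out one token id at every generation step."""

    def __init__(self, blocked_token_id: int):
        self.blocked_token_id = blocked_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # setting a score to -inf removes the token from the candidates
        scores[:, self.blocked_token_id] = -float("inf")
        return scores


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")

# custom processors are passed to `generate` through the `logits_processor` argument,
# where they run in addition to the processors `generate` builds internally
logits_processor = LogitsProcessorList([BlockTokenLogitsProcessor(tokenizer.eos_token_id)])
generation_output = model.generate(**inputs, logits_processor=logits_processor)
```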

[[autodoc]] LogitsProcessor
    - __call__

[[autodoc]] LogitsProcessorList
    - __call__

[[autodoc]] LogitsWarper
    - __call__

[[autodoc]] MinLengthLogitsProcessor
    - __call__

[[autodoc]] MinNewTokensLengthLogitsProcessor
    - __call__

[[autodoc]] TemperatureLogitsWarper
    - __call__

[[autodoc]] RepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] TopPLogitsWarper
    - __call__

[[autodoc]] TopKLogitsWarper
    - __call__

[[autodoc]] TypicalLogitsWarper
    - __call__

[[autodoc]] NoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] NoBadWordsLogitsProcessor
    - __call__

[[autodoc]] PrefixConstrainedLogitsProcessor
    - __call__

[[autodoc]] HammingDiversityLogitsProcessor
    - __call__

[[autodoc]] ForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] ForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] InfNanRemoveLogitsProcessor
    - __call__

[[autodoc]] TFLogitsProcessor
    - __call__

[[autodoc]] TFLogitsProcessorList
    - __call__

[[autodoc]] TFLogitsWarper
    - __call__

[[autodoc]] TFTemperatureLogitsWarper
    - __call__

[[autodoc]] TFTopPLogitsWarper
    - __call__

[[autodoc]] TFTopKLogitsWarper
    - __call__

[[autodoc]] TFMinLengthLogitsProcessor
    - __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
    - __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
    - __call__

[[autodoc]] TFForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] TFForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessor
    - __call__

[[autodoc]] FlaxLogitsProcessorList
    - __call__

[[autodoc]] FlaxLogitsWarper
    - __call__

[[autodoc]] FlaxTemperatureLogitsWarper
    - __call__

[[autodoc]] FlaxTopPLogitsWarper
    - __call__

[[autodoc]] FlaxTopKLogitsWarper
    - __call__

[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
    - __call__

[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
    - __call__

[[autodoc]] FlaxMinLengthLogitsProcessor
    - __call__

## StoppingCriteria

A [`StoppingCriteria`] can be used to change when to stop generation (other than an EOS token).
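
As a small sketch of how this plugs into `generate`, using the built-in [`MaxTimeCriteria`] (the two-second budget is an arbitrary value chosen for the example):

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel, MaxTimeCriteria, StoppingCriteriaList

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")

# stop generating once roughly two seconds of wall-clock time have elapsed,
# whether or not an EOS token has been produced
stopping_criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=2.0)])
generation_output = model.generate(**inputs, stopping_criteria=stopping_criteria)
```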

[[autodoc]] StoppingCriteria
    - __call__

[[autodoc]] StoppingCriteriaList
    - __call__

[[autodoc]] MaxLengthCriteria
    - __call__

[[autodoc]] MaxTimeCriteria
    - __call__

## Constraints

A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
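
For instance, a minimal sketch using [`PhrasalConstraint`] to force a phrase into the output (constrained generation runs a beam search, hence `num_beams`; the phrase itself is arbitrary):

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel, PhrasalConstraint

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is", return_tensors="pt")

# force the token sequence for "very cute" to appear somewhere in the generated text
constraint = PhrasalConstraint(tokenizer("very cute", add_special_tokens=False).input_ids)
generation_output = model.generate(**inputs, constraints=[constraint], num_beams=4)
```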

[[autodoc]] Constraint

[[autodoc]] PhrasalConstraint

[[autodoc]] DisjunctiveConstraint

[[autodoc]] ConstraintListState

## BeamSearch

[[autodoc]] BeamScorer
    - process
    - finalize

[[autodoc]] BeamSearchScorer
    - process
    - finalize

[[autodoc]] ConstrainedBeamSearchScorer
    - process
    - finalize

## Utilities

[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering