| | |
| |
|
| | from transformers import ( |
| | GenerationConfig, |
| | GenerationMixin, |
| | LogitsProcessorList, |
| | StoppingCriteriaList, |
| | DisjunctiveConstraint, |
| | BeamSearchScorer, |
| | PhrasalConstraint, |
| | ConstrainedBeamSearchScorer, |
| | PreTrainedModel, |
| | ) |
| | import numpy as np |
| | import random |
| | import warnings |
| | import inspect |
| | from transformers.generation.utils import GenerateOutput, SampleOutput, logger |
| | import torch |
| | from typing import Callable, List, Optional, Union |
| | from torch import nn |
| | import torch.distributed as dist |
| | import copy |
| |
|
| |
|
def setup_seed(seed):
    """Seed every relevant RNG (``random``, NumPy, PyTorch CPU/CUDA) for reproducibility.

    Passing ``-1`` is a no-op and leaves all RNG state untouched.
    Also forces cuDNN into deterministic mode as a side effect.
    """
    if seed == -1:
        # Sentinel: caller explicitly opted out of re-seeding.
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade speed for run-to-run determinism in cuDNN kernels.
    torch.backends.cudnn.deterministic = True
| |
|
| |
|
class StreamGenerationConfig(GenerationConfig):
    """`GenerationConfig` extended with a `do_stream` flag that selects the
    token-by-token streaming sampling path (`sample_stream`) in
    `NewGenerationMixin.generate`.
    """

    def __init__(self, **kwargs):
        # Pop `do_stream` *before* delegating to the parent constructor.
        # The original code popped it afterwards, which meant the base
        # `GenerationConfig.__init__` received the non-standard key and
        # could store/serialize/warn about an attribute it does not own.
        self.do_stream = kwargs.pop("do_stream", False)
        super().__init__(**kwargs)
| |
|
| |
|
| | class NewGenerationMixin(GenerationMixin): |
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[StreamGenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = False,
        seed=0,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

        For an overview of generation strategies and code examples, check out the [following
        guide](./generation_strategies).

        </Tip>

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            seed (`int`, *optional*, defaults to 0):
                Seed forwarded to `setup_seed` before generation starts. Pass `-1` to leave the
                global RNG state untouched.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.

            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GreedySearchDecoderOnlyOutput`],
                - [`~generation.SampleDecoderOnlyOutput`],
                - [`~generation.BeamSearchDecoderOnlyOutput`],
                - [`~generation.BeamSampleDecoderOnlyOutput`]

            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GreedySearchEncoderDecoderOutput`],
                - [`~generation.SampleEncoderDecoderOutput`],
                - [`~generation.BeamSearchEncoderDecoderOutput`],
                - [`~generation.BeamSampleEncoderDecoderOutput`]
        """
        # NOTE(review): default seed=0 re-seeds the *global* RNGs on every call;
        # callers that want non-deterministic sampling must pass seed=-1.
        setup_seed(seed)
        # 1. Handle `generation_config` and kwargs that might update it, and
        # validate that this model class supports `.generate()` at all.
        self._validate_model_class()

        # Priority: explicit `generation_config` argument > model's stored
        # `self.generation_config` default.
        if generation_config is None:
            # Legacy path: users may have modified the *model* config to
            # control generation; rebuild a (Stream)GenerationConfig from it
            # and warn about the deprecated strategy.
            if self.generation_config._from_model_config:
                new_generation_config = StreamGenerationConfig.from_model_config(
                    self.config
                )
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use a generation configuration file (see"
                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        # Work on a deep copy so `kwargs` overrides never mutate the model's
        # stored configuration between calls.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(
            **kwargs
        )  # all kwargs not consumed by the config become model kwargs

        # 2. Set generation parameters if not already defined.
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        # Open-ended generation with no pad token: fall back to (the first)
        # eos token as pad, warning if no attention mask was given either.
        if (
            generation_config.pad_token_id is None
            and generation_config.eos_token_id is not None
        ):
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(
                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
            )
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs:
        #  - `inputs_tensor` has to be defined
        #  - `model_input_name` is defined if a model-specific keyword input
        #    was passed, otherwise it is None
        #  - all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        # 4. Define other model kwargs forwarded to every `forward` call.
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        model_kwargs["use_cache"] = generation_config.use_cache

        # Only synthesize an attention mask when the model's forward actually
        # accepts one and no precomputed encoder outputs were supplied.
        accepts_attention_mask = "attention_mask" in set(
            inspect.signature(self.forward).parameters.keys()
        )
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if (
            model_kwargs.get("attention_mask", None) is None
            and requires_attention_mask
            and accepts_attention_mask
        ):
            model_kwargs[
                "attention_mask"
            ] = self._prepare_attention_mask_for_generation(
                inputs_tensor,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
            )

        # Decoder-only models should use left-padding for generation: a pad
        # token in the *last* position means the prompt was right-padded.
        if not self.config.is_encoder_decoder:
            if (
                generation_config.pad_token_id is not None
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
                > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # For encoder-decoder models, run the encoder once here and stash
            # `encoder_outputs` into `model_kwargs` for the decoding loop.
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation.
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )
        else:
            # For decoder-only models, `inputs_tensor` already is `input_ids`.
            input_ids = inputs_tensor

        # 6. Prepare `max_length` depending on other stopping criteria and
        # reconcile the deprecated `max_length` with `max_new_tokens`.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif has_default_max_length and generation_config.max_new_tokens is not None:
            # `max_new_tokens` counts from the end of the prompt, so the
            # absolute `max_length` includes the prompt length.
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
        elif (
            not has_default_max_length and generation_config.max_new_tokens is not None
        ):
            raise ValueError(
                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                " limit to the generated output length. Remove one of those arguments. Please refer to the"
                " documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )

        if (
            generation_config.min_length is not None
            and generation_config.min_length > generation_config.max_length
        ):
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 7. Determine the generation mode from the (mutually exclusive)
        # combinations of config flags. Exactly one mode flag should end up True.
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )

        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )

        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and generation_config.do_stream is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        # NOTE(review): unlike `is_sample_gen_mode`, the stream mode does not
        # require `do_sample=True` -- `do_stream=True` alone selects it.
        # Confirm this is intended (streaming always samples).
        is_sample_gen_stream_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_stream is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_group_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups > 1)
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

        if generation_config.num_beam_groups > generation_config.num_beams:
            raise ValueError(
                "`num_beam_groups` has to be smaller or equal to `num_beams`"
            )
        if is_group_beam_gen_mode and generation_config.do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # Warn (but do not fail) on a model/input device-type mismatch.
        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )
        # 8. Prepare distribution pre-processing samplers (merging any
        # user-supplied processors with the config-derived defaults).
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        # 9. Prepare stopping criteria (same merge for user-supplied criteria).
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        # 10. Dispatch to the selected generation mode.
        if is_greedy_gen_mode:
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                    " greedy search."
                )

            # Run greedy search.
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_contrastive_search_gen_mode:
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                    " contrastive search."
                )

            return self.contrastive_search(
                input_ids,
                top_k=generation_config.top_k,
                penalty_alpha=generation_config.penalty_alpha,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 11. Prepare logits warper (temperature / top-k / top-p etc.).
            logits_warper = self._get_logits_warper(generation_config)

            # 12. Expand input_ids with `num_return_sequences` additional
            # sequences per batch item.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. Run (non-streaming) multinomial sampling.
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )
        elif is_sample_gen_stream_mode:
            # 11. Prepare logits warper.
            logits_warper = self._get_logits_warper(generation_config)

            # 12. Expand input_ids with `num_return_sequences` additional
            # sequences per batch item.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. Run the streaming sampler (yields tokens incrementally).
            return self.sample_stream(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
                )

            if stopping_criteria.max_length is None:
                raise ValueError(
                    "`max_length` needs to be a stopping_criteria for now."
                )

            # 11. Prepare the beam search scorer.
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
            )
            # 12. Interleave input_ids with `num_beams` additional sequences per batch.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. Run beam search.
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 11. Prepare logits warper.
            logits_warper = self._get_logits_warper(generation_config)

            if stopping_criteria.max_length is None:
                raise ValueError(
                    "`max_length` needs to be a stopping_criteria for now."
                )
            # 12. Prepare the beam search scorer. Note the batch size is
            # multiplied by `num_return_sequences` for beam *sampling*.
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * generation_config.num_return_sequences,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
            )

            # 13. Interleave input_ids with `num_beams * num_return_sequences`
            # additional sequences per batch item.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams
                * generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 14. Run beam sampling.
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
                )

            if generation_config.num_beams % generation_config.num_beam_groups != 0:
                raise ValueError(
                    "`num_beams` should be divisible by `num_beam_groups` for group beam search."
                )

            if stopping_criteria.max_length is None:
                raise ValueError(
                    "`max_length` needs to be a stopping_criteria for now."
                )

            has_default_typical_p = (
                kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
            )
            if not has_default_typical_p:
                raise ValueError(
                    "Decoder argument `typical_p` is not supported with beam groups."
                )

            # 11. Prepare the (diverse) beam search scorer.
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                num_beam_groups=generation_config.num_beam_groups,
            )
            # 12. Interleave input_ids with `num_beams` additional sequences per batch.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. Run diverse (group) beam search.
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
                )

            if stopping_criteria.max_length is None:
                raise ValueError(
                    "`max_length` needs to be a stopping_criteria for now."
                )

            if generation_config.num_beams <= 1:
                raise ValueError(
                    "`num_beams` needs to be greater than 1 for constrained generation."
                )

            if generation_config.do_sample:
                raise ValueError(
                    "`do_sample` needs to be false for constrained generation."
                )

            if (
                generation_config.num_beam_groups is not None
                and generation_config.num_beam_groups > 1
            ):
                raise ValueError(
                    "`num_beam_groups` not supported yet for constrained generation."
                )

            # Collect explicit `constraints` plus constraints derived from
            # `force_words_ids` into one list.
            final_constraints = []
            if generation_config.constraints is not None:
                final_constraints = generation_config.constraints

            if generation_config.force_words_ids is not None:

                def typeerror():
                    # Shared error helper for all malformed `force_words_ids` shapes.
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {generation_config.force_words_ids}."
                    )

                if (
                    not isinstance(generation_config.force_words_ids, list)
                    or len(generation_config.force_words_ids) == 0
                ):
                    typeerror()

                for word_ids in generation_config.force_words_ids:
                    if isinstance(word_ids[0], list):
                        # Nested list of lists -> disjunctive constraint
                        # (any one of the alternatives must appear).
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(
                            not isinstance(token_ids, list) for token_ids in word_ids
                        ):
                            typeerror()
                        if any(
                            any(
                                (not isinstance(token_id, int) or token_id < 0)
                                for token_id in token_ids
                            )
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        # Flat list of token ids -> single phrasal constraint.
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(
                            (not isinstance(token_id, int) or token_id < 0)
                            for token_id in word_ids
                        ):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 11. Prepare the constrained beam search scorer.
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
            )
            # 12. Interleave input_ids with `num_beams` additional sequences per batch.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. Run constrained beam search.
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )
| |
|
| | @torch.no_grad() |
| | def sample_stream( |
| | self, |
| | input_ids: torch.LongTensor, |
| | logits_processor: Optional[LogitsProcessorList] = None, |
| | stopping_criteria: Optional[StoppingCriteriaList] = None, |
| | logits_warper: Optional[LogitsProcessorList] = None, |
| | max_length: Optional[int] = None, |
| | pad_token_id: Optional[int] = None, |
| | eos_token_id: Optional[Union[int, List[int]]] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | output_scores: Optional[bool] = None, |
| | return_dict_in_generate: Optional[bool] = None, |
| | synced_gpus: Optional[bool] = False, |
| | **model_kwargs, |
| | ) -> Union[SampleOutput, torch.LongTensor]: |
| | r""" |
| | Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and |
| | can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
| | |
| | <Tip warning={true}> |
| | |
| | In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. |
| | For an overview of generation strategies and code examples, check the [following |
| | guide](./generation_strategies). |
| | |
| | </Tip> |
| | |
| | Parameters: |
| | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| | The sequence used as a prompt for the generation. |
| | logits_processor (`LogitsProcessorList`, *optional*): |
| | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
| | used to modify the prediction scores of the language modeling head applied at each generation step. |
| | stopping_criteria (`StoppingCriteriaList`, *optional*): |
| | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
| | used to tell if the generation loop should stop. |
| | logits_warper (`LogitsProcessorList`, *optional*): |
| | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used |
| | to warp the prediction score distribution of the language modeling head applied before multinomial |
| | sampling at each generation step. |
| | max_length (`int`, *optional*, defaults to 20): |
| | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
| | tokens. The maximum length of the sequence to be generated. |
| | pad_token_id (`int`, *optional*): |
| | The id of the *padding* token. |
| | eos_token_id (`int`, *optional*): |
| | The id of the *end-of-sequence* token. |
| | output_attentions (`bool`, *optional*, defaults to `False`): |
| | Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| | returned tensors for more details. |
| | output_hidden_states (`bool`, *optional*, defaults to `False`): |
| | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| | for more details. |
| | output_scores (`bool`, *optional*, defaults to `False`): |
| | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
| | return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
| | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| | synced_gpus (`bool`, *optional*, defaults to `False`): |
| | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
| | model_kwargs: |
| | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
| | an encoder-decoder model the kwargs should include `encoder_outputs`. |
| | |
| | Return: |
| | [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: |
| | A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
| | [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
| | `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if |
| | `model.config.is_encoder_decoder=True`. |
| | |
| | Examples: |
| | |
| | ```python |
| | >>> from transformers import ( |
| | ... AutoTokenizer, |
| | ... AutoModelForCausalLM, |
| | ... LogitsProcessorList, |
| | ... MinLengthLogitsProcessor, |
| | ... TopKLogitsWarper, |
| | ... TemperatureLogitsWarper, |
| | ... StoppingCriteriaList, |
| | ... MaxLengthCriteria, |
| | ... ) |
| | >>> import torch |
| | |
| | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
| | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
| | |
| | >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token |
| | >>> model.config.pad_token_id = model.config.eos_token_id |
| | >>> model.generation_config.pad_token_id = model.config.eos_token_id |
| | |
| | >>> input_prompt = "Today is a beautiful day, and" |
| | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids |
| | |
| | >>> # instantiate logits processors |
| | >>> logits_processor = LogitsProcessorList( |
| | ... [ |
| | ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), |
| | ... ] |
| | ... ) |
| | >>> # instantiate logits processors |
| | >>> logits_warper = LogitsProcessorList( |
| | ... [ |
| | ... TopKLogitsWarper(50), |
| | ... TemperatureLogitsWarper(0.7), |
| | ... ] |
| | ... ) |
| | |
| | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) |
| | |
| | >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT |
| | >>> outputs = model.sample( |
| | ... input_ids, |
| | ... logits_processor=logits_processor, |
| | ... logits_warper=logits_warper, |
| | ... stopping_criteria=stopping_criteria, |
| | ... ) |
| | |
| | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| | ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] |
| | ```""" |
| | |
| | logits_processor = ( |
| | logits_processor if logits_processor is not None else LogitsProcessorList() |
| | ) |
| | stopping_criteria = ( |
| | stopping_criteria |
| | if stopping_criteria is not None |
| | else StoppingCriteriaList() |
| | ) |
| | if max_length is not None: |
| | warnings.warn( |
| | "`max_length` is deprecated in this function, use" |
| | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
| | UserWarning, |
| | ) |
| | stopping_criteria = validate_stopping_criteria( |
| | stopping_criteria, max_length |
| | ) |
| | logits_warper = ( |
| | logits_warper if logits_warper is not None else LogitsProcessorList() |
| | ) |
| | pad_token_id = ( |
| | pad_token_id |
| | if pad_token_id is not None |
| | else self.generation_config.pad_token_id |
| | ) |
| | eos_token_id = ( |
| | eos_token_id |
| | if eos_token_id is not None |
| | else self.generation_config.eos_token_id |
| | ) |
| | if isinstance(eos_token_id, int): |
| | eos_token_id = [eos_token_id] |
| | output_scores = ( |
| | output_scores |
| | if output_scores is not None |
| | else self.generation_config.output_scores |
| | ) |
| | output_attentions = ( |
| | output_attentions |
| | if output_attentions is not None |
| | else self.generation_config.output_attentions |
| | ) |
| | output_hidden_states = ( |
| | output_hidden_states |
| | if output_hidden_states is not None |
| | else self.generation_config.output_hidden_states |
| | ) |
| | return_dict_in_generate = ( |
| | return_dict_in_generate |
| | if return_dict_in_generate is not None |
| | else self.generation_config.return_dict_in_generate |
| | ) |
| |
|
| | |
| | scores = () if (return_dict_in_generate and output_scores) else None |
| | decoder_attentions = ( |
| | () if (return_dict_in_generate and output_attentions) else None |
| | ) |
| | cross_attentions = ( |
| | () if (return_dict_in_generate and output_attentions) else None |
| | ) |
| | decoder_hidden_states = ( |
| | () if (return_dict_in_generate and output_hidden_states) else None |
| | ) |
| |
|
| | |
| | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
| |
|
| | this_peer_finished = False |
| | |
| | while True: |
| | if synced_gpus: |
| | |
| | |
| | this_peer_finished_flag = torch.tensor( |
| | 0.0 if this_peer_finished else 1.0 |
| | ).to(input_ids.device) |
| | |
| | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
| | |
| | if this_peer_finished_flag.item() == 0.0: |
| | break |
| |
|
| | |
| | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
| |
|
| | |
| | outputs = self( |
| | **model_inputs, |
| | return_dict=True, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | ) |
| |
|
| | if synced_gpus and this_peer_finished: |
| | continue |
| |
|
| | next_token_logits = outputs.logits[:, -1, :] |
| |
|
| | |
| | next_token_scores = logits_processor(input_ids, next_token_logits) |
| | next_token_scores = logits_warper(input_ids, next_token_scores) |
| |
|
| | |
| | if return_dict_in_generate: |
| | if output_scores: |
| | scores += (next_token_scores,) |
| | if output_attentions: |
| | decoder_attentions += ( |
| | (outputs.decoder_attentions,) |
| | if self.config.is_encoder_decoder |
| | else (outputs.attentions,) |
| | ) |
| | if self.config.is_encoder_decoder: |
| | cross_attentions += (outputs.cross_attentions,) |
| |
|
| | if output_hidden_states: |
| | decoder_hidden_states += ( |
| | (outputs.decoder_hidden_states,) |
| | if self.config.is_encoder_decoder |
| | else (outputs.hidden_states,) |
| | ) |
| |
|
| | |
| | probs = nn.functional.softmax(next_token_scores, dim=-1) |
| | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) |
| |
|
| | |
| | if eos_token_id is not None: |
| | if pad_token_id is None: |
| | raise ValueError( |
| | "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." |
| | ) |
| | next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( |
| | 1 - unfinished_sequences |
| | ) |
| | yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) |
| | |
| | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
| | model_kwargs = self._update_model_kwargs_for_generation( |
| | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
| | ) |
| |
|
| | |
| | if eos_token_id is not None: |
| | unfinished_sequences = unfinished_sequences.mul( |
| | (sum(next_tokens != i for i in eos_token_id)).long() |
| | ) |
| |
|
| | |
| | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): |
| | if not synced_gpus: |
| | break |
| | else: |
| | this_peer_finished = True |
| |
|
| |
|
def init_stream_support():
    """Monkey-patch ``PreTrainedModel`` with streaming generation entry points.

    After calling this once, every loaded transformers model exposes
    ``generate_stream`` and ``sample_stream`` without any subclassing.
    """
    # Attach the streaming-aware methods from NewGenerationMixin onto the
    # base class so all existing and future model instances pick them up.
    setattr(PreTrainedModel, "generate_stream", NewGenerationMixin.generate)
    setattr(PreTrainedModel, "sample_stream", NewGenerationMixin.sample_stream)
| |
|
| |
|
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Patch the base class so `model.generate` is the streaming-aware version
    # (accepts `seed` and `do_stream`) and `sample_stream` is available.
    PreTrainedModel.generate = NewGenerationMixin.generate
    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream

    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = model.to("cuda:0")
    model = model.eval()

    prompt_text = "hello? \n"
    input_ids = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False
    ).input_ids
    input_ids = input_ids.to("cuda:0")

    with torch.no_grad():
        # Non-streaming call: returns a (batch, seq_len) LongTensor.
        result = model.generate(
            input_ids,
            max_new_tokens=200,
            do_sample=True,
            top_k=30,
            top_p=0.85,
            temperature=0.35,
            repetition_penalty=1.2,
            early_stopping=True,
            seed=0,
        )
        # BUG FIX: `result` is 2-D (batch first); `tokenizer.decode` expects a
        # single token-id sequence, so decode the first row, not the tensor.
        print(tokenizer.decode(result[0], skip_special_tokens=True))

        # Streaming call: with do_stream=True the patched generate returns a
        # generator yielding (next_tokens, last_hidden_state) pairs per step.
        generator = model.generate(
            input_ids,
            max_new_tokens=200,
            do_sample=True,
            top_k=30,
            top_p=0.85,
            temperature=0.35,
            repetition_penalty=1.2,
            early_stopping=True,
            seed=0,
            do_stream=True,
        )
        stream_result = ""
        # BUG FIX: each item is a (tokens, hidden) tuple — decode only the
        # token tensor, not the whole tuple.
        for tokens, _hidden in generator:
            stream_result += tokenizer.decode(tokens, skip_special_tokens=True)
            print(stream_result)
| |
|