| import inspect |
| import warnings |
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.distributed as dist |
| from torch import nn |
|
|
| from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint |
| from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer |
| from transformers.generation_logits_process import ( |
| EncoderNoRepeatNGramLogitsProcessor, |
| ExponentialDecayLengthPenalty, |
| ForcedBOSTokenLogitsProcessor, |
| ForcedEOSTokenLogitsProcessor, |
| HammingDiversityLogitsProcessor, |
| InfNanRemoveLogitsProcessor, |
| LogitNormalization, |
| LogitsProcessorList, |
| MinLengthLogitsProcessor, |
| NoBadWordsLogitsProcessor, |
| NoRepeatNGramLogitsProcessor, |
| PrefixConstrainedLogitsProcessor, |
| RepetitionPenaltyLogitsProcessor, |
| TemperatureLogitsWarper, |
| TopKLogitsWarper, |
| TopPLogitsWarper, |
| TypicalLogitsWarper, |
| ) |
| from transformers.generation_stopping_criteria import ( |
| MaxLengthCriteria, |
| MaxTimeCriteria, |
| StoppingCriteria, |
| StoppingCriteriaList, |
| validate_stopping_criteria, |
| ) |
| from transformers.pytorch_utils import torch_int_div |
| from transformers.utils import ModelOutput |
|
|
| from transformers.generation_utils import ( |
| SampleOutput, |
| BeamSearchOutput, |
| BeamSampleOutput, |
| GreedySearchOutput, GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, |
| BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSampleDecoderOnlyOutput, |
| BeamSampleEncoderDecoderOutput, SampleEncoderDecoderOutput, |
| ) |
| from utils import get_jump_chunks |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| class GenerationMixin: |
| """ |
| A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. |
| |
| The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: |
| - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and |
| `do_sample=False`. |
| - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and |
| `do_sample=True`. |
| - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and |
| `do_sample=False`. |
| - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if |
| `num_beams>1` and `do_sample=True`. |
| - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if |
| `num_beams>1` and `num_beam_groups>1`. |
| - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], |
| if `constraints!=None` or `force_words_ids!=None`. |
| """ |
|
|
| def _prepare_model_inputs( |
| self, |
| inputs: Optional[torch.Tensor] = None, |
| bos_token_id: Optional[int] = None, |
| model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
| ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: |
| """ |
| This function extracts the model-specific `inputs` for generation. |
| """ |
| |
| |
| if ( |
| self.config.is_encoder_decoder |
| and hasattr(self, "encoder") |
| and self.encoder.main_input_name != self.main_input_name |
| ): |
| input_name = self.encoder.main_input_name |
| else: |
| input_name = self.main_input_name |
|
|
| model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} |
|
|
| |
| |
| inputs_kwarg = model_kwargs.pop(input_name, None) |
| if inputs_kwarg is not None and inputs is not None: |
| raise ValueError( |
| f"`inputs`: {inputs}` were passed alongside " |
| f"{input_name} which is not allowed." |
| f"Make sure to either pass {inputs} or {input_name}=..." |
| ) |
| elif inputs_kwarg is not None: |
| inputs = inputs_kwarg |
|
|
| |
| if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): |
| inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" |
|
|
| |
| if not self.config.is_encoder_decoder and input_name != "input_ids": |
| raise ValueError( |
| f"If {input_name} is passed as model-specific keyword " |
| "input then model has to be an encoder-decoder and not a " |
| f"{self.__class__.__name__}." |
| ) |
|
|
| |
| if inputs is None: |
| inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) |
|
|
| return inputs, input_name, model_kwargs |
|
|
| def _can_retrieve_inputs_from_name( |
| self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] |
| ) -> torch.Tensor: |
| """ |
| If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved |
| from name |
| """ |
| can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( |
| inspect.signature(self.forward).parameters.keys() |
| ) |
|
|
| if can_retrieve_inputs and inputs is not None: |
| raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") |
|
|
| return can_retrieve_inputs |
|
|
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Default hook that builds the model inputs for one generation step.

    This base version ignores `kwargs` and wraps `input_ids` in a dict under the key `"input_ids"`.
    Subclasses of [`PreTrainedModel`] override it for custom behavior.
    """
    return dict(input_ids=input_ids)
|
|
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Default hook for adjusting logits in the generate method.

    This base version is the identity: `logits` is returned untouched and `kwargs` is ignored.
    Subclasses of [`PreTrainedModel`] override it for custom behavior.
    """
    return logits
|
|
def _prepare_input_ids_for_generation(
    self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
) -> torch.LongTensor:
    """
    Build default `input_ids` when the caller supplied none.

    For encoder-decoder models with `encoder_outputs` available, the result has the shape of
    `encoder_outputs.last_hidden_state` without its last dimension and is filled with -100.
    Otherwise it is a `(1, 1)` tensor holding `bos_token_id`.

    Raises:
        ValueError: If the BOS path is taken and `bos_token_id` is `None`.
    """
    has_encoder_states = self.config.is_encoder_decoder and encoder_outputs is not None
    if has_encoder_states:
        # Placeholder ids; presumably -100 marks them as "not real tokens" -- TODO confirm against callers.
        leading_dims = encoder_outputs.last_hidden_state.shape[:-1]
        return torch.full(leading_dims, -100, dtype=torch.long, device=self.device)

    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    single_token = torch.ones((1, 1), dtype=torch.long, device=self.device)
    return single_token * bos_token_id
|
|
| def _prepare_attention_mask_for_generation( |
| self, |
| inputs: torch.Tensor, |
| pad_token_id: int, |
| eos_token_id: int, |
| ) -> torch.LongTensor: |
| is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] |
| is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) |
| is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( |
| (eos_token_id is not None) and (pad_token_id != eos_token_id) |
| ) |
| |
| if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: |
| return inputs.ne(pad_token_id).long() |
| else: |
| return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) |
|
|
| def _prepare_encoder_decoder_kwargs_for_generation( |
| self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None |
| ) -> Dict[str, Any]: |
| |
| encoder = self.get_encoder() |
|
|
| |
| irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] |
| encoder_kwargs = { |
| argument: value |
| for argument, value in model_kwargs.items() |
| if not any(argument.startswith(p) for p in irrelevant_prefix) |
| } |
| print('encoder_kwargs:', encoder_kwargs) |
|
|
| |
| model_input_name = model_input_name if model_input_name is not None else self.main_input_name |
| encoder_kwargs["return_dict"] = True |
| encoder_kwargs[model_input_name] = inputs_tensor |
| model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) |
|
|
| return model_kwargs |
|
|
| def _prepare_decoder_input_ids_for_generation( |
| self, |
| batch_size: int, |
| decoder_start_token_id: int = None, |
| bos_token_id: int = None, |
| model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
| device: torch.device = None, |
| ) -> torch.LongTensor: |
|
|
| if model_kwargs is not None and "decoder_input_ids" in model_kwargs: |
| return model_kwargs.pop("decoder_input_ids") |
| else: |
| decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) |
| if device is None: |
| device = self.device |
| return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id |
|
|
| def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: |
| decoder_start_token_id = ( |
| decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id |
| ) |
| bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id |
|
|
| if decoder_start_token_id is not None: |
| return decoder_start_token_id |
| elif ( |
| hasattr(self.config, "decoder") |
| and hasattr(self.config.decoder, "decoder_start_token_id") |
| and self.config.decoder.decoder_start_token_id is not None |
| ): |
| return self.config.decoder.decoder_start_token_id |
| elif bos_token_id is not None: |
| return bos_token_id |
| elif ( |
| hasattr(self.config, "decoder") |
| and hasattr(self.config.decoder, "bos_token_id") |
| and self.config.decoder.bos_token_id is not None |
| ): |
| return self.config.decoder.bos_token_id |
| raise ValueError( |
| "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." |
| ) |
|
|
@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[ModelOutput] = None,
    **model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every batch row `expand_size` times, keeping copies of the same row adjacent.

    The same row selection is applied to `input_ids`, to `model_kwargs["token_type_ids"]` if present, to
    `attention_mask` if given, and (for encoder-decoder models) to `encoder_outputs.last_hidden_state`,
    which is written back into `encoder_outputs` in place.

    Returns:
        `(input_ids, model_kwargs)` where `model_kwargs` carries the expanded `attention_mask` and
        `encoder_outputs` when applicable.

    Raises:
        ValueError: If `is_encoder_decoder` is True but `encoder_outputs` is `None`.
    """
    # Row index with each position repeated back to back: 0, 0, ..., 1, 1, ...
    row_index = torch.arange(input_ids.shape[0]).repeat_interleave(expand_size).to(input_ids.device)
    input_ids = input_ids.index_select(0, row_index)

    if "token_type_ids" in model_kwargs:
        model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].index_select(0, row_index)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, row_index)

    if not is_encoder_decoder:
        return input_ids, model_kwargs

    if encoder_outputs is None:
        raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
    hidden_state = encoder_outputs.last_hidden_state
    # The encoder output may live on another device than `input_ids`, so the index follows it.
    encoder_outputs["last_hidden_state"] = hidden_state.index_select(0, row_index.to(hidden_state.device))
    model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
|
|
@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry state from the last forward pass into the kwargs of the next one.

    - `model_kwargs["past"]` is set from the first of `past_key_values`, `mems`, `past_buckets_states`
      found in `outputs`, or to `None` if none is present.
    - `token_type_ids`, if present, grows by one column that repeats its last column.
    - For models that are not encoder-decoders, `attention_mask`, if present, grows by one column of ones.

    Returns:
        The same `model_kwargs` dict, updated in place.
    """
    cached_state = None
    for cache_field in ("past_key_values", "mems", "past_buckets_states"):
        if cache_field in outputs:
            cached_state = getattr(outputs, cache_field)
            break
    model_kwargs["past"] = cached_state

    if "token_type_ids" in model_kwargs:
        previous_types = model_kwargs["token_type_ids"]
        last_column = previous_types[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([previous_types, last_column], dim=-1)

    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        previous_mask = model_kwargs["attention_mask"]
        extra_column = previous_mask.new_ones((previous_mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([previous_mask, extra_column], dim=-1)

    return model_kwargs
|
|
| def _reorder_cache(self, past, beam_idx): |
| raise NotImplementedError( |
| f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" |
| ) |
|
|
def _get_logits_warper(
    self,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    temperature: Optional[float] = None,
    num_beams: Optional[int] = None,
    renormalize_logits: Optional[bool] = None,
) -> LogitsProcessorList:
    """
    Return a [`LogitsProcessorList`] holding the [`LogitsWarper`] instances used for multinomial sampling.

    `top_k`, `top_p`, `typical_p` and `temperature` fall back to the attribute of the same name on
    `self.config` when `None`. Warpers are appended in this order, each only when its setting is active:
    temperature (`!= 1.0`), top-k (`!= 0`), top-p (`< 1.0`), typical (`< 1.0`), and finally
    [`LogitNormalization`] when `renormalize_logits` is `True`.

    Args:
        num_beams: Number of beams. With more than one beam the filtering warpers keep at least two tokens,
            otherwise at least one. `None` falls back to `self.config.num_beams` if that attribute exists,
            and is treated as a single beam when still undefined.
    """
    # fall back to the config for everything the caller left unset
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    typical_p = typical_p if typical_p is not None else self.config.typical_p
    temperature = temperature if temperature is not None else self.config.temperature

    # `num_beams` defaults to None in the signature; `None > 1` would raise TypeError, so resolve it once.
    if num_beams is None:
        num_beams = getattr(self.config, "num_beams", None)
    min_tokens_to_keep = 2 if (num_beams is not None and num_beams > 1) else 1

    warpers = LogitsProcessorList()

    # The checks are written so that the config's default values add no warper.
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    if typical_p is not None and typical_p < 1.0:
        warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=min_tokens_to_keep))
    # normalisation, when requested, always comes last
    if renormalize_logits is True:
        warpers.append(LogitNormalization())
    return warpers
|
|
def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    encoder_no_repeat_ngram_size: int,
    input_ids_seq_length: int,
    encoder_input_ids: torch.LongTensor,
    bad_words_ids: List[List[int]],
    min_length: int,
    max_length: int,
    eos_token_id: int,
    forced_bos_token_id: int,
    forced_eos_token_id: int,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
    num_beam_groups: int,
    diversity_penalty: float,
    remove_invalid_values: bool,
    exponential_decay_length_penalty: Tuple,
    logits_processor: Optional[LogitsProcessorList],
    renormalize_logits: Optional[bool],
) -> LogitsProcessorList:
    """
    Return a [`LogitsProcessorList`] with every [`LogitsProcessor`] that applies to the given settings.

    Settings passed as `None` fall back to the attribute of the same name on `self.config` (this applies to
    the repetition / n-gram / bad-words / EOS / diversity / forced-token / invalid-value / exponential-decay
    settings). Built-in processors are appended in a fixed order, then merged with the caller's
    `logits_processor` through `_merge_criteria_processor_list`; [`LogitNormalization`] is appended last
    when `renormalize_logits` is `True`.

    Raises:
        ValueError: If `encoder_no_repeat_ngram_size` is active on a model that is not an encoder-decoder.
    """
    processors = LogitsProcessorList()

    def _from_config_if_unset(explicit_value, config_attribute):
        # The config attribute is only read when the caller left the value unset.
        if explicit_value is not None:
            return explicit_value
        return getattr(self.config, config_attribute)

    # resolve defaults from the config
    repetition_penalty = _from_config_if_unset(repetition_penalty, "repetition_penalty")
    no_repeat_ngram_size = _from_config_if_unset(no_repeat_ngram_size, "no_repeat_ngram_size")
    encoder_no_repeat_ngram_size = _from_config_if_unset(
        encoder_no_repeat_ngram_size, "encoder_no_repeat_ngram_size"
    )
    bad_words_ids = _from_config_if_unset(bad_words_ids, "bad_words_ids")
    eos_token_id = _from_config_if_unset(eos_token_id, "eos_token_id")
    diversity_penalty = _from_config_if_unset(diversity_penalty, "diversity_penalty")
    forced_bos_token_id = _from_config_if_unset(forced_bos_token_id, "forced_bos_token_id")
    forced_eos_token_id = _from_config_if_unset(forced_eos_token_id, "forced_eos_token_id")
    remove_invalid_values = _from_config_if_unset(remove_invalid_values, "remove_invalid_values")
    exponential_decay_length_penalty = _from_config_if_unset(
        exponential_decay_length_penalty, "exponential_decay_length_penalty"
    )

    # instantiate the built-in processors; the order of the appends below is significant
    if diversity_penalty is not None and diversity_penalty > 0.0:
        hamming = HammingDiversityLogitsProcessor(
            diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
        )
        processors.append(hamming)
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
        if not self.config.is_encoder_decoder:
            raise ValueError(
                "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
            )
        processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > 0:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
    if prefix_allowed_tokens_fn is not None:
        beams_per_group = num_beams // num_beam_groups
        processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, beams_per_group))
    if forced_bos_token_id is not None:
        processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
    if forced_eos_token_id is not None:
        processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
    if remove_invalid_values is True:
        processors.append(InfNanRemoveLogitsProcessor())
    if exponential_decay_length_penalty is not None:
        decay = ExponentialDecayLengthPenalty(
            exponential_decay_length_penalty, eos_token_id, input_ids_seq_length
        )
        processors.append(decay)

    processors = self._merge_criteria_processor_list(processors, logits_processor)

    # normalisation, when requested, always comes last
    if renormalize_logits is True:
        processors.append(LogitNormalization())
    return processors
|
|
def _get_stopping_criteria(
    self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
    """
    Build the stopping criteria for a generation call.

    A [`MaxLengthCriteria`] is added when `max_length` is set and a [`MaxTimeCriteria`] when `max_time` is
    set; the result is then merged with the caller's `stopping_criteria` via
    `_merge_criteria_processor_list`.
    """
    built_in = StoppingCriteriaList()
    if max_length is not None:
        built_in.append(MaxLengthCriteria(max_length=max_length))
    if max_time is not None:
        built_in.append(MaxTimeCriteria(max_time=max_time))
    return self._merge_criteria_processor_list(built_in, stopping_criteria)
|
|
def _merge_criteria_processor_list(
    self,
    default_list: Union[LogitsProcessorList, StoppingCriteriaList],
    custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
    """
    Append the caller's custom entries to the built-in ones.

    Args:
        default_list: Entries created from `generate` arguments or config defaults. Extended in place.
        custom_list: Entries supplied by the caller. `None` or an empty list leaves `default_list` untouched.

    Returns:
        `default_list`, followed by the entries of `custom_list`.

    Raises:
        ValueError: If a custom entry has exactly the same type as one of the default entries.
    """
    # Both callers annotate the custom list as Optional, so None must behave like "nothing to merge".
    if not custom_list:
        return default_list
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                raise ValueError(
                    f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, "
                    f"but it has already been created with the values {default}. {default} has been created by passing the "
                    "corresponding arguments to generate or by the model's config default values. "
                    f"If you just want to change the default values of {object_type} consider passing them as arguments "
                    f"to `generate` instead of using a custom {object_type}."
                )
    default_list.extend(custom_list)
    return default_list
|
|
def compute_transition_beam_scores(
    self,
    sequences: torch.Tensor,
    scores: Tuple[torch.Tensor],
    beam_indices: torch.Tensor,
    eos_token_id: int = None,
):
    """
    Compute the per-step transition scores of `sequences` from generation `scores` and `beam_indices`.

    The per-step score tensors are stacked and flattened so that each row corresponds to one
    (beam row, token id) pair and each column to one generation step. The score of every generated token is
    then gathered at row `beam_index * vocab_size + token_id`. When an EOS token id is known (argument, or
    `self.config.eos_token_id`), every position after the first EOS of a sequence is set to 0.0.

    Returns:
        A tensor with one row per sequence and one column per generation step.
    """
    # One row per (beam row, token id) pair, one column per generation step.
    step_count = len(scores)
    flat_scores = torch.stack(scores).reshape(step_count, -1).transpose(0, 1)

    # Only the trailing tokens of `sequences` have a score; anything before `cut_idx` is skipped.
    cut_idx = sequences.shape[-1] - flat_scores.shape[-1]
    generated_tokens = sequences[:, cut_idx:]

    # Offset each token id by the start of its beam's block of rows in `flat_scores`.
    row_offsets = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
    transition_scores = flat_scores.gather(0, generated_tokens + row_offsets)

    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id
    if eos_token_id is None:
        return transition_scores

    eos_hits = generated_tokens == eos_token_id
    # An EOS in the last column has nothing after it; clearing it keeps the roll below from wrapping it
    # around to the first column.
    eos_hits[:, -1] = False
    # Shift by one so the EOS step itself keeps its score, then flag everything after the first EOS.
    after_eos = eos_hits.roll(1, -1).cumsum(-1).bool()
    transition_scores.masked_fill_(after_eos, 0.0)

    return transition_scores
|
|
| |
def remove_subsets(self, l):
    """
    Drop every entry of `l` whose elements form a proper subset of another entry's elements.

    Entries are compared as sets, so element order and repetitions inside an entry are ignored. An entry
    is removed only when some other entry strictly contains it; entries with identical element sets
    (exact duplicates or reorderings of each other) are all kept. The input list is not modified.

    Args:
        l: List of sequences of hashable items (e.g. lists of token ids).

    Returns:
        A new list with the surviving entries in their original order.
    """
    # Build each set once instead of once per pairwise comparison.
    entry_sets = [set(entry) for entry in l]
    # `<` is the proper-subset test: equal sets never eliminate each other, which is what made the
    # previous `issubset(...) and m != n` check drop both of two reordered entries.
    return [
        entry
        for entry, members in zip(l, entry_sets)
        if not any(members < other for other in entry_sets)
    ]
|
|
| |
@torch.no_grad()
def cs_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    contexts: List[str] = None,
    model_input: Dict = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    encoder_no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    max_time: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    num_beam_groups: Optional[int] = None,
    diversity_penalty: Optional[float] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
    renormalize_logits: Optional[bool] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
    constraints: Optional[List[Constraint]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    remove_invalid_values: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
    use_kg: bool = False,
    relation_mapper_builder=None,
    tokenizer=None,
    max_neig_per_concept=1,
    **model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
    """
    Generate for a whole batch, optionally constrained by knowledge-graph neighbours of context concepts.

    With `use_kg=True`, concepts are extracted from every entry of `contexts` through
    `relation_mapper_builder`, their related concepts are looked up, and for each context concept one
    [`DisjunctiveConstraint`] is built from up to `max_neig_per_concept` tokenized related words. The
    constraints of all contexts are collected into one flat list that is handed to `generate`.

    Only these arguments are forwarded to `generate`: `min_length`, `do_sample`, `early_stopping`,
    `temperature`, `top_k`, `top_p`, `no_repeat_ngram_size`, `num_return_sequences`,
    `return_dict_in_generate`, `output_attentions`, `output_scores` and `**model_kwargs`. Every other
    parameter of the signature, including the caller's `inputs` and `constraints`, is accepted but not used.

    Args:
        contexts: One text per batch row; read only when `use_kg` is True.
        model_input: Mapping holding `"input_ids"` and, optionally, `"input_commonsense_relations"`, which is
            moved to the device of `input_ids` and forwarded as `relation_inputs`.
        use_kg: Whether to build constraints. When False, or when no constraint could be built, `generate`
            is called with `constraints=None`.
        relation_mapper_builder: Object providing `get_concepts_from_context`, `swow_knowledge` and
            `knowledge`.
        tokenizer: Callable returning an object with `input_ids` for a list of words.
        max_neig_per_concept: Maximum number of alternatives kept per constraint.

    Returns:
        Whatever `generate` returns.
    """
    input_ids = model_input["input_ids"]
    if "input_commonsense_relations" in model_input:
        model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(input_ids.device)

    all_constraints = None
    if use_kg:
        all_constraints = []
        for context in contexts:
            concepts_from_context = relation_mapper_builder.get_concepts_from_context(
                context=context, clear_common_wds=True, alignment=1
            )
            useful_concepts = [
                relation_mapper_builder.swow_knowledge.get_related_concepts(concept)
                for concept in concepts_from_context
            ]
            # NOTE(review): this list has one entry per concept, so it is empty only when no concept was
            # found, in which case the fallback is empty too. Confirm whether a per-concept fallback
            # was intended.
            if not useful_concepts:
                useful_concepts = [
                    relation_mapper_builder.knowledge.get_related_concepts(concept)
                    for concept in concepts_from_context
                ]
            useful_concepts = [[f"{phrase}" for phrase in concepts] for concepts in useful_concepts]

            for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
                # skip neighbours that are already contained in the context concept itself
                flexible_words = [word for word in neighbour_concepts if word not in context_concept]
                if not flexible_words:
                    continue
                flexible_words_ids: List[List[int]] = tokenizer(flexible_words, add_special_tokens=False).input_ids
                flexible_words_ids = self.remove_subsets(flexible_words_ids)
                flexible_words_ids = flexible_words_ids[:max_neig_per_concept]
                # nothing left after slicing (e.g. `max_neig_per_concept=0`): no constraint to build
                if not flexible_words_ids:
                    continue
                all_constraints.append(DisjunctiveConstraint(flexible_words_ids))

        # Mirror `cs_simple_generate`: an empty constraint list is passed on as "no constraints".
        if not all_constraints:
            all_constraints = None

    generated_answers_encoded = self.generate(
        input_ids=input_ids,
        constraints=all_constraints,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=return_dict_in_generate,
        output_attentions=output_attentions,
        output_scores=output_scores,
        **model_kwargs,
    )
    return generated_answers_encoded
|
|
| |
@torch.no_grad()
def cs_simple_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    neighbours_contexts: List[List[str]] = None,
    model_input: Dict = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    encoder_no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    max_time: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    num_beam_groups: Optional[int] = None,
    diversity_penalty: Optional[float] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
    renormalize_logits: Optional[bool] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
    constraints: Optional[List[Constraint]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    remove_invalid_values: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
    use_kg: bool = False,
    relation_mapper_builder=None,
    tokenizer=None,
    max_concepts=2,
    **model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
    """
    Generate one batch row at a time, each row with its own (optional) disjunctive constraints.

    With `use_kg=True`, the neighbour concepts of every row (those longer than three characters, prefixed
    with a space) are split with `get_jump_chunks`; for each of the first `max_concepts` chunks a
    [`DisjunctiveConstraint`] over the first token of every word in the chunk is built. Rows without any
    constraint, and every row when `use_kg` is False, are generated without constraints.

    Only these arguments are forwarded to `generate`: `min_length`, `do_sample`, `early_stopping`,
    `temperature`, `top_k`, `top_p`, `no_repeat_ngram_size`, `num_return_sequences`,
    `return_dict_in_generate`, `output_attentions`, `output_scores` and `**model_kwargs`. Every other
    parameter of the signature, including the caller's `inputs` and `constraints`, is accepted but not used.

    Args:
        neighbours_contexts: One list of neighbour concepts per batch row; read only when `use_kg` is True.
        model_input: Mapping holding `"input_ids"` and, optionally, `"attention_mask"` and
            `"input_commonsense_relations"`; the row-wise slices of the latter two are forwarded as
            `attention_mask` and `relation_inputs`.
        tokenizer: Callable returning an object with `input_ids` for a list of words; its `pad_token_id` is
            used to pad the results.
        max_concepts: Maximum number of constraints per row.

    Returns:
        A tensor on the device of `input_ids` holding the first element (`gen[0]`) of every per-row
        `generate` result, right-padded with `tokenizer.pad_token_id`.
    """
    input_ids = model_input["input_ids"]

    all_constraints = None
    if use_kg:
        all_constraints = []
        for context_neighbours in neighbours_contexts:
            context_neighbours = [f" {concept}" for concept in context_neighbours if len(concept) > 3]
            n_size_chuncks = len(context_neighbours) // max_concepts
            n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1
            sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks))
            row_constraints = []
            for sub_concepts in sub_concepts_collection[:max_concepts]:
                flexible_words_ids: List[List[int]] = tokenizer(sub_concepts, add_special_tokens=False).input_ids
                # keep only the first token of every word; words that tokenize to nothing are skipped
                flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids if word_ids]
                # de-duplicate the single-token alternatives
                disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
                if not any(disjunctive_set):
                    continue
                row_constraints.append(DisjunctiveConstraint(disjunctive_set))
            if not any(row_constraints):
                row_constraints = None
            all_constraints.append(row_constraints)

    # Without constraints (KG disabled, or no neighbour lists) every batch row is still generated, just
    # unconstrained; iterating over `None` here used to raise a TypeError.
    if not all_constraints:
        all_constraints = [None] * input_ids.shape[0]

    relation_inputs = model_input.get("input_commonsense_relations")
    attention_mask = model_input.get("attention_mask")

    generated_answers_encoded = []
    for i, row_constraints in enumerate(all_constraints):
        if "input_commonsense_relations" in model_input:
            model_kwargs["relation_inputs"] = relation_inputs[i].unsqueeze(0).to(input_ids.device)
        # NOTE(review): the original layout did not show whether this assignment was nested under the
        # relations check; the mask is sliced whenever `model_input` provides one because `input_ids` is
        # sliced per row as well.
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask[i].unsqueeze(0).to(input_ids.device)
        gen = self.generate(
            input_ids=input_ids[i].unsqueeze(0),
            constraints=row_constraints,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            return_dict_in_generate=return_dict_in_generate,
            output_attentions=output_attentions,
            output_scores=output_scores,
            **model_kwargs,
        )
        # only the first element of each result is kept
        generated_answers_encoded.append(gen[0].detach().cpu())

    padded = pad_sequence(generated_answers_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)
    return padded.to(input_ids.device)
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| inputs: Optional[torch.Tensor] = None, |
| max_length: Optional[int] = None, |
| min_length: Optional[int] = None, |
| do_sample: Optional[bool] = None, |
| early_stopping: Optional[bool] = None, |
| num_beams: Optional[int] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| typical_p: Optional[float] = None, |
| repetition_penalty: Optional[float] = None, |
| bad_words_ids: Optional[Iterable[int]] = None, |
| force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, |
| bos_token_id: Optional[int] = None, |
| pad_token_id: Optional[int] = None, |
| eos_token_id: Optional[int] = None, |
| length_penalty: Optional[float] = None, |
| no_repeat_ngram_size: Optional[int] = None, |
| encoder_no_repeat_ngram_size: Optional[int] = None, |
| num_return_sequences: Optional[int] = None, |
| max_time: Optional[float] = None, |
| max_new_tokens: Optional[int] = None, |
| decoder_start_token_id: Optional[int] = None, |
| use_cache: Optional[bool] = None, |
| num_beam_groups: Optional[int] = None, |
| diversity_penalty: Optional[float] = None, |
| prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
| logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), |
| renormalize_logits: Optional[bool] = None, |
| stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), |
| constraints: Optional[List[Constraint]] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_scores: Optional[bool] = None, |
| return_dict_in_generate: Optional[bool] = None, |
| forced_bos_token_id: Optional[int] = None, |
| forced_eos_token_id: Optional[int] = None, |
| remove_invalid_values: Optional[bool] = None, |
| synced_gpus: Optional[bool] = False, |
| exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, |
| **model_kwargs, |
| ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: |
| r""" |
| |
| Generates sequences of token ids for models with a language modeling head. The method supports the following |
| generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: |
| |
| - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and |
| `do_sample=False`. |
| - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and |
| `do_sample=True`. |
| - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and |
| `do_sample=False`. |
| - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if |
| `num_beams>1` and `do_sample=True`. |
| - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if |
| `num_beams>1` and `num_beam_groups>1`. |
| - *constrained beam-search decoding* by calling |
| [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or |
| `force_words_ids!=None`. |
| |
| <Tip warning={true}> |
| |
| Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as |
| defined in the model's config (`config.json`) which in turn defaults to the |
| [`~modeling_utils.PretrainedConfig`] of the model. |
| |
| </Tip> |
| |
| Most of these parameters are explained in more detail in [this blog |
| post](https://huggingface.co/blog/how-to-generate). |
| |
| Parameters: |
| inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): |
| The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the |
| method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` |
| should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of |
| `input_ids`, `input_values`, `input_features`, or `pixel_values`. |
| max_length (`int`, *optional*, defaults to `model.config.max_length`): |
| The maximum length of the sequence to be generated. |
| max_new_tokens (`int`, *optional*, defaults to None): |
| The maximum number of tokens to generate, ignoring the current number of tokens. Use either |
| `max_new_tokens` or `max_length` but not both, they serve the same purpose. |
| min_length (`int`, *optional*, defaults to 10): |
| The minimum length of the sequence to be generated. |
| do_sample (`bool`, *optional*, defaults to `False`): |
| Whether or not to use sampling ; use greedy decoding otherwise. |
| early_stopping (`bool`, *optional*, defaults to `False`): |
| Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. |
| num_beams (`int`, *optional*, defaults to 1): |
| Number of beams for beam search. 1 means no beam search. |
| temperature (`float`, *optional*, defaults to 1.0): |
| The value used to modulate the next token probabilities. |
| top_k (`int`, *optional*, defaults to 50): |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. |
| top_p (`float`, *optional*, defaults to 1.0): |
| If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher |
| are kept for generation. |
| repetition_penalty (`float`, *optional*, defaults to 1.0): |
| The parameter for repetition penalty. 1.0 means no penalty. See [this |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. |
| pad_token_id (`int`, *optional*): |
| The id of the *padding* token. |
| bos_token_id (`int`, *optional*): |
| The id of the *beginning-of-sequence* token. |
| eos_token_id (`int`, *optional*): |
| The id of the *end-of-sequence* token. |
| length_penalty (`float`, *optional*, defaults to 1.0): |
| Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. |
| 0.0 means no penalty. Set to values < 0.0 in order to encourage the model to generate longer |
| sequences, to a value > 0.0 in order to encourage the model to produce shorter sequences. |
| no_repeat_ngram_size (`int`, *optional*, defaults to 0): |
| If set to int > 0, all ngrams of that size can only occur once. |
| encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): |
| If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the |
| `decoder_input_ids`. |
| bad_words_ids(`List[List[int]]`, *optional*): |
| List of token ids that are not allowed to be generated. In order to get the token ids of the words that |
| should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, |
| add_special_tokens=False).input_ids`. |
| force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): |
| List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple |
| list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, |
| this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), |
| where one can allow different forms of each word. |
| num_return_sequences(`int`, *optional*, defaults to 1): |
| The number of independently computed returned sequences for each element in the batch. |
| max_time(`float`, *optional*, defaults to None): |
| The maximum amount of time you allow the computation to run for in seconds. generation will still |
| finish the current pass after allocated time has been passed. |
| attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens |
| that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape |
| as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) |
| decoder_start_token_id (`int`, *optional*): |
| If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. |
| use_cache: (`bool`, *optional*, defaults to `True`): |
| Whether or not the model should use the past last key/values attentions (if applicable to the model) to |
| speed up decoding. |
| num_beam_groups (`int`, *optional*, defaults to 1): |
| Number of groups to divide `num_beams` into in order to ensure diversity among different groups of |
| beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. |
| diversity_penalty (`float`, *optional*, defaults to 0.0): |
| This value is subtracted from a beam's score if it generates a token same as any beam from other group |
| at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is |
| enabled. |
| prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): |
| If provided, this function constrains the beam search to allowed tokens only at each step. If not |
| provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and |
| `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned |
| on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful |
| for constrained generation conditioned on the prefix, as described in [Autoregressive Entity |
| Retrieval](https://arxiv.org/abs/2010.00904). |
| logits_processor (`LogitsProcessorList`, *optional*): |
| Custom logits processors that complement the default logits processors built from arguments and a |
| model's config. If a logit processor is passed that is already created with the arguments or a model's |
| config an error is thrown. This feature is intended for advanced users. |
| renormalize_logits: (`bool`, *optional*, defaults to `False`): |
| Whether to renormalize the logits after applying all the logits processors or warpers (including the |
| custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the |
| score logits are normalized but some logit processors or warpers break the normalization. |
| stopping_criteria (`StoppingCriteriaList`, *optional*): |
| Custom stopping criteria that complement the default stopping criteria built from arguments and a |
| model's config. If a stopping criteria is passed that is already created with the arguments or a |
| model's config an error is thrown. This feature is intended for advanced users. |
| constraints (`List[Constraint]`, *optional*): |
| Custom constraints that can be added to the generation to ensure that the output will contain the use |
| of certain tokens as defined by `Constraint` objects, in the most sensible way possible. |
| output_attentions (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more details. |
| output_hidden_states (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| for more details. |
| output_scores (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
| return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| forced_bos_token_id (`int`, *optional*): |
| The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful |
| for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be |
| the target language token. |
| forced_eos_token_id (`int`, *optional*): |
| The id of the token to force as the last generated token when `max_length` is reached. |
| remove_invalid_values (`bool`, *optional*): |
| Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to |
| crash. Note that using `remove_invalid_values` can slow down generation. |
| synced_gpus (`bool`, *optional*, defaults to `False`): |
| Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
| exponential_decay_length_penalty (`tuple(int, float)`, *optional*): |
| This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been |
| generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates |
| where penalty starts and `decay_factor` represents the factor of exponential decay |
| |
| model_kwargs: |
| Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model |
| is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs |
| should be prefixed with *decoder_*. |
| |
| Return: |
| [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` |
| or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. |
| |
| If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible |
| [`~utils.ModelOutput`] types are: |
| |
| - [`~generation_utils.GreedySearchDecoderOnlyOutput`], |
| - [`~generation_utils.SampleDecoderOnlyOutput`], |
| - [`~generation_utils.BeamSearchDecoderOnlyOutput`], |
| - [`~generation_utils.BeamSampleDecoderOnlyOutput`] |
| |
| If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible |
| [`~utils.ModelOutput`] types are: |
| |
| - [`~generation_utils.GreedySearchEncoderDecoderOutput`], |
| - [`~generation_utils.SampleEncoderDecoderOutput`], |
| - [`~generation_utils.BeamSearchEncoderDecoderOutput`], |
| - [`~generation_utils.BeamSampleEncoderDecoderOutput`] |
| |
| Examples: |
| |
| Greedy Decoding: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
| >>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
| |
| >>> prompt = "Today I believe we can finally" |
| >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids |
| |
| >>> # generate up to 30 tokens |
| >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] |
| ``` |
| |
| Multinomial Sampling: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, AutoModelForCausalLM |
| >>> import torch |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
| >>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
| |
| >>> prompt = "Today I believe we can finally" |
| >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids |
| |
| >>> # sample up to 30 tokens |
| >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT |
| >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] |
| ``` |
| |
| Beam-search decoding: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") |
| >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") |
| |
| >>> sentence = "Paris is one of the densest populated areas in Europe." |
| >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids |
| |
| >>> outputs = model.generate(input_ids) |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] |
| ```""" |
| |
| bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id |
| num_beams = num_beams if num_beams is not None else self.config.num_beams |
| length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty |
| early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping |
| num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups |
| do_sample = do_sample if do_sample is not None else self.config.do_sample |
| num_return_sequences = ( |
| num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences |
| ) |
|
|
| pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
| eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
|
| if eos_token_id is None and hasattr(self.config, "decoder"): |
| eos_token_id = self.config.decoder.eos_token_id |
|
|
| if pad_token_id is None and eos_token_id is not None: |
| |
| print(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") |
| pad_token_id = eos_token_id |
|
|
| output_scores = output_scores if output_scores is not None else self.config.output_scores |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict_in_generate = ( |
| return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
| ) |
|
|
| |
| |
| |
| |
| |
| inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) |
| batch_size = inputs_tensor.shape[0] |
|
|
| |
| model_kwargs["output_attentions"] = output_attentions |
| model_kwargs["output_hidden_states"] = output_hidden_states |
| model_kwargs["use_cache"] = use_cache |
|
|
| accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) |
| requires_attention_mask = "encoder_outputs" not in model_kwargs |
|
|
| if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: |
| model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( |
| inputs_tensor, pad_token_id, eos_token_id |
| ) |
|
|
| if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: |
| |
| |
| model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( |
| inputs_tensor, model_kwargs, model_input_name |
| ) |
|
|
| |
| if self.config.is_encoder_decoder: |
| input_ids = self._prepare_decoder_input_ids_for_generation( |
| batch_size, |
| decoder_start_token_id=decoder_start_token_id, |
| bos_token_id=bos_token_id, |
| model_kwargs=model_kwargs, |
| device=inputs_tensor.device, |
| ) |
| else: |
| |
| input_ids = inputs_tensor |
|
|
| input_ids_seq_length = input_ids.shape[-1] |
|
|
| |
| |
| if max_length is None and max_new_tokens is not None: |
| max_length = max_new_tokens + input_ids_seq_length |
| elif max_length is not None and max_new_tokens is not None: |
| |
| warnings.warn( |
| "Both `max_length` and `max_new_tokens` have been set " |
| f"but they serve the same purpose. `max_length` {max_length} " |
| f"will take priority over `max_new_tokens` {max_new_tokens}.", |
| UserWarning, |
| ) |
| |
| max_length = max_length if max_length is not None else self.config.max_length |
| min_length = min_length if min_length is not None else self.config.min_length |
|
|
| if min_length is not None and min_length > max_length: |
| raise ValueError( |
| f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " |
| f"length ({max_length})" |
| ) |
| if input_ids_seq_length >= max_length: |
| input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" |
| print( |
| f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " |
| "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." |
| ) |
|
|
| |
| is_constraint_gen_mode = constraints is not None or force_words_ids is not None |
| is_greedy_gen_mode = ( |
| (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode |
| ) |
| is_sample_gen_mode = ( |
| (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode |
| ) |
| is_beam_gen_mode = ( |
| (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode |
| ) |
| is_beam_sample_gen_mode = ( |
| (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode |
| ) |
| is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode |
|
|
| if num_beam_groups > num_beams: |
| raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") |
| if is_group_beam_gen_mode and do_sample is True: |
| raise ValueError( |
| "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." |
| ) |
|
|
| |
| logits_processor = self._get_logits_processor( |
| repetition_penalty=repetition_penalty, |
| no_repeat_ngram_size=no_repeat_ngram_size, |
| encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, |
| input_ids_seq_length=input_ids_seq_length, |
| encoder_input_ids=inputs_tensor, |
| bad_words_ids=bad_words_ids, |
| min_length=min_length, |
| max_length=max_length, |
| eos_token_id=eos_token_id, |
| forced_bos_token_id=forced_bos_token_id, |
| forced_eos_token_id=forced_eos_token_id, |
| prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, |
| num_beams=num_beams, |
| num_beam_groups=num_beam_groups, |
| diversity_penalty=diversity_penalty, |
| remove_invalid_values=remove_invalid_values, |
| exponential_decay_length_penalty=exponential_decay_length_penalty, |
| logits_processor=logits_processor, |
| renormalize_logits=renormalize_logits, |
| ) |
|
|
| |
| stopping_criteria = self._get_stopping_criteria( |
| max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria |
| ) |
|
|
| |
| if is_greedy_gen_mode: |
| if num_return_sequences > 1: |
| raise ValueError( |
| f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." |
| ) |
|
|
| |
| return self.greedy_search( |
| input_ids, |
| logits_processor=logits_processor, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| elif is_sample_gen_mode: |
| |
| logits_warper = self._get_logits_warper( |
| top_k=top_k, |
| top_p=top_p, |
| typical_p=typical_p, |
| temperature=temperature, |
| num_beams=num_beams, |
| renormalize_logits=renormalize_logits, |
| ) |
|
|
| |
| input_ids, model_kwargs = self._expand_inputs_for_generation( |
| input_ids, |
| expand_size=num_return_sequences, |
| is_encoder_decoder=self.config.is_encoder_decoder, |
| **model_kwargs, |
| ) |
|
|
| |
| return self.sample( |
| input_ids, |
| logits_processor=logits_processor, |
| logits_warper=logits_warper, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| elif is_beam_gen_mode: |
| if num_return_sequences > num_beams: |
| raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") |
|
|
| if stopping_criteria.max_length is None: |
| raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
| |
| beam_scorer = BeamSearchScorer( |
| batch_size=batch_size, |
| num_beams=num_beams, |
| device=inputs_tensor.device, |
| length_penalty=length_penalty, |
| do_early_stopping=early_stopping, |
| num_beam_hyps_to_keep=num_return_sequences, |
| ) |
| |
| input_ids, model_kwargs = self._expand_inputs_for_generation( |
| input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
| ) |
| |
| return self.beam_search( |
| input_ids, |
| beam_scorer, |
| logits_processor=logits_processor, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| elif is_beam_sample_gen_mode: |
| |
| logits_warper = self._get_logits_warper( |
| top_k=top_k, |
| top_p=top_p, |
| typical_p=typical_p, |
| temperature=temperature, |
| num_beams=num_beams, |
| renormalize_logits=renormalize_logits, |
| ) |
|
|
| if stopping_criteria.max_length is None: |
| raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
| |
| beam_scorer = BeamSearchScorer( |
| batch_size=batch_size * num_return_sequences, |
| num_beams=num_beams, |
| device=inputs_tensor.device, |
| length_penalty=length_penalty, |
| do_early_stopping=early_stopping, |
| ) |
|
|
| |
| input_ids, model_kwargs = self._expand_inputs_for_generation( |
| input_ids, |
| expand_size=num_beams * num_return_sequences, |
| is_encoder_decoder=self.config.is_encoder_decoder, |
| **model_kwargs, |
| ) |
|
|
| |
| return self.beam_sample( |
| input_ids, |
| beam_scorer, |
| logits_processor=logits_processor, |
| logits_warper=logits_warper, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| elif is_group_beam_gen_mode: |
| if num_return_sequences > num_beams: |
| raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") |
|
|
| if num_beams % num_beam_groups != 0: |
| raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") |
|
|
| if stopping_criteria.max_length is None: |
| raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
| |
| beam_scorer = BeamSearchScorer( |
| batch_size=batch_size, |
| num_beams=num_beams, |
| max_length=stopping_criteria.max_length, |
| device=inputs_tensor.device, |
| length_penalty=length_penalty, |
| do_early_stopping=early_stopping, |
| num_beam_hyps_to_keep=num_return_sequences, |
| num_beam_groups=num_beam_groups, |
| ) |
| |
| input_ids, model_kwargs = self._expand_inputs_for_generation( |
| input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
| ) |
| |
| return self.group_beam_search( |
| input_ids, |
| beam_scorer, |
| logits_processor=logits_processor, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| elif is_constraint_gen_mode: |
| if num_return_sequences > num_beams: |
| raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") |
|
|
| if stopping_criteria.max_length is None: |
| raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
| if num_beams <= 1: |
| raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.") |
|
|
| if do_sample: |
| raise ValueError("`do_sample` needs to be false for constrained generation.") |
|
|
| if num_beam_groups is not None and num_beam_groups > 1: |
| raise ValueError("`num_beam_groups` not supported yet for constrained generation.") |
|
|
| final_constraints = [] |
| if constraints is not None: |
| final_constraints = constraints |
|
|
| if force_words_ids is not None: |
|
|
| def typeerror(): |
| raise ValueError( |
| "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" |
| f"of positive integers, but is {force_words_ids}." |
| ) |
|
|
| if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: |
| typeerror() |
|
|
| for word_ids in force_words_ids: |
| if isinstance(word_ids[0], list): |
| if not isinstance(word_ids, list) or len(word_ids) == 0: |
| typeerror() |
| if any(not isinstance(token_ids, list) for token_ids in word_ids): |
| typeerror() |
| if any( |
| any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) |
| for token_ids in word_ids |
| ): |
| typeerror() |
|
|
| constraint = DisjunctiveConstraint(word_ids) |
| else: |
| if not isinstance(word_ids, list) or len(word_ids) == 0: |
| typeerror() |
| if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): |
| typeerror() |
|
|
| constraint = PhrasalConstraint(word_ids) |
| final_constraints.append(constraint) |
|
|
| |
| constrained_beam_scorer = ConstrainedBeamSearchScorer( |
| constraints=final_constraints, |
| batch_size=batch_size, |
| num_beams=num_beams, |
| device=inputs_tensor.device, |
| length_penalty=length_penalty, |
| do_early_stopping=early_stopping, |
| num_beam_hyps_to_keep=num_return_sequences, |
| ) |
| |
| input_ids, model_kwargs = self._expand_inputs_for_generation( |
| input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
| ) |
| |
| return self.constrained_beam_search( |
| input_ids, |
| constrained_beam_scorer=constrained_beam_scorer, |
| logits_processor=logits_processor, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| output_scores=output_scores, |
| return_dict_in_generate=return_dict_in_generate, |
| synced_gpus=synced_gpus, |
| **model_kwargs, |
| ) |
|
|
| def greedy_search( |
| self, |
| input_ids: torch.LongTensor, |
| logits_processor: Optional[LogitsProcessorList] = None, |
| stopping_criteria: Optional[StoppingCriteriaList] = None, |
| max_length: Optional[int] = None, |
| pad_token_id: Optional[int] = None, |
| eos_token_id: Optional[int] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_scores: Optional[bool] = None, |
| return_dict_in_generate: Optional[bool] = None, |
| synced_gpus: Optional[bool] = False, |
| **model_kwargs, |
| ) -> Union[GreedySearchOutput, torch.LongTensor]: |
| r""" |
| Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be |
| used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
| |
| Parameters: |
| |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| The sequence used as a prompt for the generation. |
| logits_processor (`LogitsProcessorList`, *optional*): |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
| used to modify the prediction scores of the language modeling head applied at each generation step. |
| stopping_criteria (`StoppingCriteriaList`, *optional*): |
| An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
| used to tell if the generation loop should stop. |
| |
| max_length (`int`, *optional*, defaults to 20): |
| **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
| tokens. The maximum length of the sequence to be generated. |
| pad_token_id (`int`, *optional*): |
| The id of the *padding* token. |
| eos_token_id (`int`, *optional*): |
| The id of the *end-of-sequence* token. |
| output_attentions (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more details. |
| output_hidden_states (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| for more details. |
| output_scores (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
| return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| synced_gpus (`bool`, *optional*, defaults to `False`): |
| Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
| model_kwargs: |
| Additional model specific keyword arguments will be forwarded to the `forward` function of the model. |
| If model is an encoder-decoder model the kwargs should include `encoder_outputs`. |
| |
| Return: |
| [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] |
| or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
| [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
| `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if |
| `model.config.is_encoder_decoder=True`. |
| |
| Examples: |
| |
| ```python |
| >>> from transformers import ( |
| ... AutoTokenizer, |
| ... AutoModelForCausalLM, |
| ... LogitsProcessorList, |
| ... MinLengthLogitsProcessor, |
| ... StoppingCriteriaList, |
| ... MaxLengthCriteria, |
| ... ) |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
| >>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
| |
| >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token |
| >>> model.config.pad_token_id = model.config.eos_token_id |
| |
| >>> input_prompt = "It might be possible to" |
| >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids |
| |
| >>> # instantiate logits processors |
| >>> logits_processor = LogitsProcessorList( |
| ... [ |
| ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), |
| ... ] |
| ... ) |
| >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) |
| |
| >>> outputs = model.greedy_search( |
| ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria |
| ... ) |
| |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ["It might be possible to get a better understanding of the nature of the problem, but it's not"] |
| ```""" |
| |
| logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
| stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
| if max_length is not None: |
| warnings.warn( |
| "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", |
| UserWarning, |
| ) |
| stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
| pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
| eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
| output_scores = output_scores if output_scores is not None else self.config.output_scores |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict_in_generate = ( |
| return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
| ) |
|
|
| |
| scores = () if (return_dict_in_generate and output_scores) else None |
| decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
| cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
| decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
| |
| if return_dict_in_generate and self.config.is_encoder_decoder: |
| encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
| encoder_hidden_states = ( |
| model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
| ) |
|
|
| |
| unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
| cur_len = input_ids.shape[-1] |
|
|
| this_peer_finished = False |
| while True: |
|
|
| if synced_gpus: |
| |
| |
| this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
| |
| dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
| |
| if this_peer_finished_flag.item() == 0.0: |
| break |
|
|
| |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
| |
| outputs = self( |
| **model_inputs, |
| return_dict=True, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| ) |
|
|
| if synced_gpus and this_peer_finished: |
| cur_len = cur_len + 1 |
| continue |
|
|
| next_token_logits = outputs.logits[:, -1, :] |
|
|
| |
| if return_dict_in_generate: |
| if output_scores: |
| scores += (next_token_logits,) |
| if output_attentions: |
| decoder_attentions += ( |
| (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
| ) |
| if self.config.is_encoder_decoder: |
| cross_attentions += (outputs.cross_attentions,) |
|
|
| if output_hidden_states: |
| decoder_hidden_states += ( |
| (outputs.decoder_hidden_states,) |
| if self.config.is_encoder_decoder |
| else (outputs.hidden_states,) |
| ) |
|
|
| |
| next_tokens_scores = logits_processor(input_ids, next_token_logits) |
|
|
| |
| next_tokens = torch.argmax(next_tokens_scores, dim=-1) |
|
|
| |
| if eos_token_id is not None: |
| if pad_token_id is None: |
| raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
| next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |
|
|
| |
| input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
| model_kwargs = self._update_model_kwargs_for_generation( |
| outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
| ) |
| cur_len = cur_len + 1 |
|
|
| |
| if eos_token_id is not None: |
| unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) |
|
|
| |
| if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): |
| if not synced_gpus: |
| break |
| else: |
| this_peer_finished = True |
|
|
| if return_dict_in_generate: |
| if self.config.is_encoder_decoder: |
| return GreedySearchEncoderDecoderOutput( |
| sequences=input_ids, |
| scores=scores, |
| encoder_attentions=encoder_attentions, |
| encoder_hidden_states=encoder_hidden_states, |
| decoder_attentions=decoder_attentions, |
| cross_attentions=cross_attentions, |
| decoder_hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return GreedySearchDecoderOnlyOutput( |
| sequences=input_ids, |
| scores=scores, |
| attentions=decoder_attentions, |
| hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return input_ids |
|
|
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        logits_warper (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
            to warp the prediction score distribution of the language modeling head applied before multinomial
            sampling at each generation step.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token. Falls back to `self.config.pad_token_id` when not given.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token. Falls back to `self.config.eos_token_id` when not given.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: If an `eos_token_id` is in effect (argument or config) but no `pad_token_id` is available.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForCausalLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     TopKLogitsWarper,
    ...     TemperatureLogitsWarper,
    ...     StoppingCriteriaList,
    ...     MaxLengthCriteria,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

    >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
    >>> model.config.pad_token_id = model.config.eos_token_id

    >>> input_prompt = "Today is a beautiful day, and"
    >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )
    >>> # instantiate logits warpers
    >>> logits_warper = LogitsProcessorList(
    ...     [
    ...         TopKLogitsWarper(50),
    ...         TemperatureLogitsWarper(0.7),
    ...     ]
    ... )

    >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

    >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
    >>> outputs = model.sample(
    ...     input_ids,
    ...     logits_processor=logits_processor,
    ...     logits_warper=logits_warper,
    ...     stopping_criteria=stopping_criteria,
    ... )

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
    ```"""

    # Resolve unset arguments: empty processor/criteria lists, otherwise the model config values.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # Finished rows are filled with `pad_token_id`, so it must exist whenever EOS handling is active.
    # The condition never changes during decoding, so check it once here instead of after a forward pass.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    # Per-step accumulators; only populated when a ModelOutput is requested.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder-side outputs are read once from the pre-computed `encoder_outputs`.
    encoder_attentions = None
    encoder_hidden_states = None
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 1 marks a row that is still generating, 0 a row that has already emitted EOS.
    unfinished_sequences = input_ids.new_ones((input_ids.shape[0],))

    this_peer_finished = False  # only relevant when `synced_gpus` is set
    while True:
        if synced_gpus:
            # Each process contributes 1.0 while it is still generating and 0.0 once it is done;
            # a summed value of 0.0 therefore means every process has finished.
            this_peer_finished_flag = torch.tensor(
                0.0 if this_peer_finished else 1.0, device=input_ids.device
            )
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # Forward pass producing the logits for the next position.
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            # This process is done; it only ran the forward pass above to stay in step with the others.
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # Apply the processors first, then the warpers, to obtain the sampling scores.
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Record the requested per-step outputs.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Draw one token per row from the softmax distribution over the warped scores.
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Rows that already finished receive the padding token instead of the sampled one.
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the new tokens and let the model update its generation kwargs for the next step.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # A row becomes finished as soon as it samples EOS.
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # Stop when every row is finished or a stopping criterion fires.
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
|
|
def beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        beam_scorer (`BeamScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token. Falls back to `self.config.pad_token_id` when not given.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token. Falls back to `self.config.eos_token_id` when not given.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: If the first dimension of `input_ids` is not `num_beams * batch_size`, where both values are
            taken from `beam_scorer`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForSeq2SeqLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     BeamSearchScorer,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    >>> encoder_input_str = "translate English to German: How old are you?"
    >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


    >>> # lets run beam search using 3 beams
    >>> num_beams = 3
    >>> # define decoder start token ids
    >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    >>> input_ids = input_ids * model.config.decoder_start_token_id

    >>> # add encoder_outputs to model keyword arguments
    >>> model_kwargs = {
    ...     "encoder_outputs": model.get_encoder()(
    ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    ...     )
    ... }

    >>> # instantiate beam scorer
    >>> beam_scorer = BeamSearchScorer(
    ...     batch_size=1,
    ...     num_beams=num_beams,
    ...     device=model.device,
    ... )

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Wie alt bist du?']
    ```"""

    # Resolve unset arguments: empty processor/criteria lists, otherwise the model config values.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # The scorer holds one hypothesis container per batch item.
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # Per-step accumulators; only populated when a ModelOutput is requested.
    scores = () if (return_dict_in_generate and output_scores) else None
    # One (initially empty) history of selected source-beam indices per beam row.
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder-side outputs are read once from the pre-computed `encoder_outputs`.
    encoder_attentions = None
    encoder_hidden_states = None
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # Only the first beam of each batch item starts at score 0; the others start at -1e9.
    # NOTE(review): presumably this keeps identical initial beams from yielding duplicate
    # candidates in the first top-k — confirm against the caller's beam expansion.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # only relevant when `synced_gpus` is set
    while True:
        if synced_gpus:
            # Each process contributes 1.0 while it is still generating and 0.0 once it is done;
            # a summed value of 0.0 therefore means every process has finished.
            this_peer_finished_flag = torch.tensor(
                0.0 if this_peer_finished else 1.0, device=input_ids.device
            )
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            # This process is done; it only ran the forward pass above to stay in step with the others.
            cur_len = cur_len + 1
            continue

        next_token_logits = outputs.logits[:, -1, :]
        # Let the model adjust the raw logits for the current length before normalising them.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)

        # Processed per-token log-probabilities, then add the running score of each beam.
        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # Record the requested per-step outputs (scores are stored without the running beam score).
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Flatten the beams of each batch item so top-k ranks every (beam, token) pair together.
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Keep 2 * num_beams candidates per batch item.
        # NOTE(review): presumably the surplus leaves enough candidates once EOS ones are
        # set aside by the scorer — confirm against `BeamScorer.process`.
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # Split each flat index into its source beam (quotient) and token id (remainder).
        next_indices = torch_int_div(next_tokens, vocab_size)
        next_tokens = next_tokens % vocab_size

        # Hand the candidates to the scorer, which returns the surviving beams for the next step.
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )

        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # Reorder the rows to follow the selected source beams, then append the chosen tokens.
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # The cached state must follow the same beam reordering as `input_ids`.
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        # Each row inherits the index history of its source beam, extended by that source index.
        if return_dict_in_generate and output_scores:
            beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))

        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    # Let the scorer assemble the final sequences from the last step's state.
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        else:
            num_return_sequences = beam_scorer.num_beam_hyps_to_keep
            # Keep the first `num_return_sequences` histories of every batch item and flatten
            # the per-item groups back into a single tuple.
            beam_indices = tuple(
                (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
            )
            beam_indices = sum(beam_indices, ())

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
|
|
| def beam_sample( |
| self, |
| input_ids: torch.LongTensor, |
| beam_scorer: BeamScorer, |
| logits_processor: Optional[LogitsProcessorList] = None, |
| stopping_criteria: Optional[StoppingCriteriaList] = None, |
| logits_warper: Optional[LogitsProcessorList] = None, |
| max_length: Optional[int] = None, |
| pad_token_id: Optional[int] = None, |
| eos_token_id: Optional[int] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_scores: Optional[bool] = None, |
| return_dict_in_generate: Optional[bool] = None, |
| synced_gpus: Optional[bool] = False, |
| **model_kwargs, |
| ) -> Union[BeamSampleOutput, torch.LongTensor]: |
| r""" |
| Generates sequences of token ids for models with a language modeling head using **beam search multinomial |
| sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
| |
| Parameters: |
| |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| The sequence used as a prompt for the generation. |
| beam_scorer (`BeamScorer`): |
| A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and |
| sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. |
| logits_processor (`LogitsProcessorList`, *optional*): |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
| used to modify the prediction scores of the language modeling head applied at each generation step. |
| stopping_criteria (`StoppingCriteriaList`, *optional*): |
| An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
| used to tell if the generation loop should stop. |
| logits_warper (`LogitsProcessorList`, *optional*): |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used |
| to warp the prediction score distribution of the language modeling head applied before multinomial |
| sampling at each generation step. |
| max_length (`int`, *optional*, defaults to 20): |
| **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
| tokens. The maximum length of the sequence to be generated. |
| pad_token_id (`int`, *optional*): |
| The id of the *padding* token. |
| eos_token_id (`int`, *optional*): |
| The id of the *end-of-sequence* token. |
| output_attentions (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more details. |
| output_hidden_states (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| for more details. |
| output_scores (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
| return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| synced_gpus (`bool`, *optional*, defaults to `False`): |
| Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
| model_kwargs: |
| Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
| an encoder-decoder model the kwargs should include `encoder_outputs`. |
| |
| Return: |
| [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or |
| `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
| [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
| `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if |
| `model.config.is_encoder_decoder=True`. |
| |
| Examples: |
| |
| ```python |
| >>> from transformers import ( |
| ... AutoTokenizer, |
| ... AutoModelForSeq2SeqLM, |
| ... LogitsProcessorList, |
| ... MinLengthLogitsProcessor, |
| ... TopKLogitsWarper, |
| ... TemperatureLogitsWarper, |
| ... BeamSearchScorer, |
| ... ) |
| >>> import torch |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
| >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
| |
| >>> encoder_input_str = "translate English to German: How old are you?" |
| >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
| |
| >>> # lets run beam search using 3 beams |
| >>> num_beams = 3 |
| >>> # define decoder start token ids |
| >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
| >>> input_ids = input_ids * model.config.decoder_start_token_id |
| |
| >>> # add encoder_outputs to model keyword arguments |
| >>> model_kwargs = { |
| ... "encoder_outputs": model.get_encoder()( |
| ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
| ... ) |
| ... } |
| |
| >>> # instantiate beam scorer |
| >>> beam_scorer = BeamSearchScorer( |
| ... batch_size=1, |
| ... max_length=model.config.max_length, |
| ... num_beams=num_beams, |
| ... device=model.device, |
| ... ) |
| |
| >>> # instantiate logits processors |
| >>> logits_processor = LogitsProcessorList( |
| ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] |
| ... ) |
| >>> # instantiate logits warpers |
| >>> logits_warper = LogitsProcessorList( |
| ... [ |
| ... TopKLogitsWarper(50), |
| ... TemperatureLogitsWarper(0.7), |
| ... ] |
| ... ) |
| |
| >>> outputs = model.beam_sample( |
| ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs |
| ... ) |
| |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ['Wie alt bist du?'] |
| ```""" |
| |
| logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
| stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
| if max_length is not None: |
| warnings.warn( |
| "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
| UserWarning, |
| ) |
| stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
| pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
| eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
| output_scores = output_scores if output_scores is not None else self.config.output_scores |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict_in_generate = ( |
| return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
| ) |
|
|
| batch_size = len(beam_scorer._beam_hyps) |
| num_beams = beam_scorer.num_beams |
|
|
| batch_beam_size, cur_len = input_ids.shape |
|
|
| |
| scores = () if (return_dict_in_generate and output_scores) else None |
| beam_indices = ( |
| tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None |
| ) |
| decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
| cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
| decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
| |
| if return_dict_in_generate and self.config.is_encoder_decoder: |
| encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
| encoder_hidden_states = ( |
| model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
| ) |
|
|
| beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
| beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
| this_peer_finished = False |
| while True: |
|
|
| if synced_gpus: |
| |
| |
| this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
| |
| dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
| |
| if this_peer_finished_flag.item() == 0.0: |
| break |
|
|
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
| outputs = self( |
| **model_inputs, |
| return_dict=True, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| ) |
|
|
| if synced_gpus and this_peer_finished: |
| cur_len = cur_len + 1 |
| continue |
|
|
| next_token_logits = outputs.logits[:, -1, :] |
|
|
| |
| |
| next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
| next_token_scores = nn.functional.log_softmax( |
| next_token_logits, dim=-1 |
| ) |
|
|
| next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
| next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
| next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
| |
| if return_dict_in_generate: |
| if output_scores: |
| scores += (logits_warper(input_ids, next_token_scores_processed),) |
| if output_attentions: |
| decoder_attentions += ( |
| (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
| ) |
| if self.config.is_encoder_decoder: |
| cross_attentions += (outputs.cross_attentions,) |
|
|
| if output_hidden_states: |
| decoder_hidden_states += ( |
| (outputs.decoder_hidden_states,) |
| if self.config.is_encoder_decoder |
| else (outputs.hidden_states,) |
| ) |
|
|
| |
| vocab_size = next_token_scores.shape[-1] |
| next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
| probs = nn.functional.softmax(next_token_scores, dim=-1) |
|
|
| next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) |
| next_token_scores = torch.gather(next_token_scores, -1, next_tokens) |
|
|
| next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) |
| next_tokens = torch.gather(next_tokens, -1, _indices) |
|
|
| next_indices = torch_int_div(next_tokens, vocab_size) |
| next_tokens = next_tokens % vocab_size |
|
|
| |
| beam_outputs = beam_scorer.process( |
| input_ids, |
| next_token_scores, |
| next_tokens, |
| next_indices, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| ) |
| beam_scores = beam_outputs["next_beam_scores"] |
| beam_next_tokens = beam_outputs["next_beam_tokens"] |
| beam_idx = beam_outputs["next_beam_indices"] |
|
|
| input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
|
| model_kwargs = self._update_model_kwargs_for_generation( |
| outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
| ) |
| if model_kwargs["past"] is not None: |
| model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) |
|
|
| if return_dict_in_generate and output_scores: |
| beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) |
|
|
| |
| cur_len = cur_len + 1 |
|
|
| if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
| if not synced_gpus: |
| break |
| else: |
| this_peer_finished = True |
|
|
| sequence_outputs = beam_scorer.finalize( |
| input_ids, |
| beam_scores, |
| next_tokens, |
| next_indices, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| max_length=stopping_criteria.max_length, |
| ) |
|
|
| if return_dict_in_generate: |
| if not output_scores: |
| sequence_outputs["sequence_scores"] = None |
| else: |
| num_return_sequences = beam_scorer.num_beam_hyps_to_keep |
| |
| beam_indices = tuple( |
| (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) |
| ) |
| beam_indices = sum(beam_indices, ()) |
|
|
| if self.config.is_encoder_decoder: |
| return BeamSampleEncoderDecoderOutput( |
| sequences=sequence_outputs["sequences"], |
| sequences_scores=sequence_outputs["sequence_scores"], |
| scores=scores, |
| beam_indices=beam_indices, |
| encoder_attentions=encoder_attentions, |
| encoder_hidden_states=encoder_hidden_states, |
| decoder_attentions=decoder_attentions, |
| cross_attentions=cross_attentions, |
| decoder_hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return BeamSampleDecoderOnlyOutput( |
| sequences=sequence_outputs["sequences"], |
| sequences_scores=sequence_outputs["sequence_scores"], |
| scores=scores, |
| beam_indices=beam_indices, |
| attentions=decoder_attentions, |
| hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return sequence_outputs["sequences"] |
|
|
def group_beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
):
    r"""
    Generates sequences of token ids for models with a language modeling head using **diverse beam search
    decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    The beams of every batch item are split into `beam_scorer.num_beam_groups` groups of
    `num_beams // num_beam_groups` beams. At each step the groups are processed one after another, and the
    tokens already selected by earlier groups in the same step are handed to the `logits_processor` (as
    `current_tokens`, together with `beam_group_idx`).

    Parameters:

        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            The sequence used as a prompt for the generation. The first dimension must equal
            `beam_scorer.num_beams` times the batch size known to `beam_scorer`.
        beam_scorer (`BeamScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

        model_kwargs:
            Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
            model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`. Both output classes carry `beam_indices` when scores are
        requested.

    Raises:
        ValueError: If the first dimension of `input_ids` is not `num_beams * batch_size`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForSeq2SeqLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     HammingDiversityLogitsProcessor,
    ...     BeamSearchScorer,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    >>> encoder_input_str = "translate English to German: How old are you?"
    >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


    >>> # lets run diverse beam search using 6 beams
    >>> num_beams = 6
    >>> # define decoder start token ids
    >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    >>> input_ids = input_ids * model.config.decoder_start_token_id

    >>> # add encoder_outputs to model keyword arguments
    >>> model_kwargs = {
    ...     "encoder_outputs": model.get_encoder()(
    ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    ...     )
    ... }

    >>> # instantiate beam scorer
    >>> beam_scorer = BeamSearchScorer(
    ...     batch_size=1,
    ...     max_length=model.config.max_length,
    ...     num_beams=num_beams,
    ...     device=model.device,
    ...     num_beam_groups=3,
    ... )

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
    ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.group_beam_search(
    ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
    ... )

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Wie alt bist du?']
    ```"""
    # Resolve optional arguments, falling back to empty lists / the model config.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # Beam geometry is read from the scorer.
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups
    device = input_ids.device

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # One tuple of per-beam index histories for every group; only tracked when scores are returned.
    if return_dict_in_generate and output_scores:
        beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
    else:
        beam_indices = None

    # Accumulators for the optional per-step outputs.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # For encoder-decoder models the encoder outputs are taken from `model_kwargs["encoder_outputs"]`.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # Only the first beam of each group starts at score 0; every other beam starts at -1e9 so that the
    # (initially identical) beams of a group do not all propose the same candidates on the first step.
    beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
    beam_scores[:, ::num_sub_beams] = 0
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # The rows of the flattened (batch_size * num_beams) layout that belong to each group never change
    # during decoding, so compute them once instead of on every step.
    group_layout = []
    for beam_group_idx in range(num_beam_groups):
        group_start_idx = beam_group_idx * num_sub_beams
        group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
        batch_group_indices = [
            batch_idx * num_beams + idx
            for batch_idx in range(batch_size)
            for idx in range(group_start_idx, group_end_idx)
        ]
        group_layout.append((group_start_idx, group_end_idx - group_start_idx, batch_group_indices))

    this_peer_finished = False  # only used when synced_gpus is set
    while True:

        if synced_gpus:
            # Every peer contributes 1.0 while it is still generating and 0.0 once finished; a sum of 0.0
            # across peers therefore means all of them are done and the loop can stop.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # Token chosen for every beam in this step (filled group by group below).
        current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

        # Source row for every beam after this step's re-ranking; used to reorder the cache.
        reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

        # A single forward pass serves all groups.
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            # This peer is done: it only keeps issuing forward passes and skips all bookkeeping.
            cur_len = cur_len + 1
            continue

        if output_scores:
            processed_score = torch.zeros_like(outputs.logits[:, -1, :])

        for beam_group_idx, (group_start_idx, group_size, batch_group_indices) in enumerate(group_layout):
            group_input_ids = input_ids[batch_group_indices]

            # Last-position logits of the rows belonging to this group.
            next_token_logits = outputs.logits[batch_group_indices, -1, :]

            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * group_size, vocab_size)
            vocab_size = next_token_scores.shape[-1]

            # `current_tokens` holds what the earlier groups picked in this step.
            next_token_scores_processed = logits_processor(
                group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
            )
            next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
            next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

            if output_scores:
                processed_score[batch_group_indices] = next_token_scores_processed

            # Flatten beams x vocabulary per batch item and keep the best 2 * group_size candidates.
            next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
            )

            # Split each flat candidate index into (beam within the group, token id).
            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            if return_dict_in_generate and output_scores:
                beam_indices[beam_group_idx] = tuple(
                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                )

            # NOTE: this assignment is in place; on the first step `input_ids` is still the tensor that
            # was passed in by the caller.
            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            current_tokens[batch_group_indices] = group_input_ids[:, -1]

            # Map group-local beam indices back to rows of the full layout:
            # (beam_idx // group_size) is the batch item, (beam_idx % group_size) the offset inside the group.
            reordering_indices[batch_group_indices] = (
                num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
            )

        # Record the optional per-step outputs.
        if return_dict_in_generate:
            if output_scores:
                scores += (processed_score,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    # `next_tokens` / `next_indices` are the values left over from the last processed group.
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        else:
            # Concatenate the per-group histories, then keep `num_return_sequences` entries per batch item.
            # NOTE(review): the concatenation is ordered by group while the slicing below strides by
            # `num_beams` per batch item -- confirm this is intended when batch_size > 1.
            beam_indices = sum(beam_indices, ())
            num_return_sequences = beam_scorer.num_beam_hyps_to_keep
            beam_indices = tuple(
                (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
            )
            beam_indices = sum(beam_indices, ())

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            # `beam_indices` is returned here as well, matching the encoder-decoder branch.
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
|
|
| def constrained_beam_search( |
| self, |
| input_ids: torch.LongTensor, |
| constrained_beam_scorer: ConstrainedBeamSearchScorer, |
| logits_processor: Optional[LogitsProcessorList] = None, |
| stopping_criteria: Optional[StoppingCriteriaList] = None, |
| max_length: Optional[int] = None, |
| pad_token_id: Optional[int] = None, |
| eos_token_id: Optional[int] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_scores: Optional[bool] = None, |
| return_dict_in_generate: Optional[bool] = None, |
| synced_gpus: Optional[bool] = None, |
| **model_kwargs, |
| ) -> Union[BeamSearchOutput, torch.LongTensor]: |
|
|
| r""" |
| Generates sequences of token ids for models with a language modeling head using **constrained beam search |
| decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
| |
| Parameters: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| The sequence used as a prompt for the generation. |
| constrained_beam_scorer (`ConstrainedBeamSearchScorer`): |
| A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and |
| sorted during generation, while satisfying a list of positive constraints. For more information, the |
| documentation of [`ConstrainedBeamSearchScorer`] should be read. |
| logits_processor (`LogitsProcessorList`, *optional*): |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
| used to modify the prediction scores of the language modeling head applied at each generation step. |
| stopping_criteria (`StoppingCriteriaList`, *optional*): |
| An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
| used to tell if the generation loop should stop. |
| logits_warper (`LogitsProcessorList`, *optional*): |
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used |
| to warp the prediction score distribution of the language modeling head applied before multinomial |
| sampling at each generation step. |
| max_length (`int`, *optional*, defaults to 20): |
| **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
| tokens. The maximum length of the sequence to be generated. |
| pad_token_id (`int`, *optional*): |
| The id of the *padding* token. |
| eos_token_id (`int`, *optional*): |
| The id of the *end-of-sequence* token. |
| output_attentions (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more details. |
| output_hidden_states (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| for more details. |
| output_scores (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
| return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| synced_gpus (`bool`, *optional*, defaults to `False`): |
| Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
| model_kwargs: |
| Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
| an encoder-decoder model the kwargs should include `encoder_outputs`. |
| |
| Return: |
| [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or |
| `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
| [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
| `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if |
| `model.config.is_encoder_decoder=True`. |
| |
| |
| Examples: |
| |
| ```python |
| >>> from transformers import ( |
| ... AutoTokenizer, |
| ... AutoModelForSeq2SeqLM, |
| ... LogitsProcessorList, |
| ... MinLengthLogitsProcessor, |
| ... ConstrainedBeamSearchScorer, |
| ... PhrasalConstraint, |
| ... ) |
| >>> import torch |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
| >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
| |
| >>> encoder_input_str = "translate English to German: How old are you?" |
| >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
| |
| |
| >>> # lets run beam search using 3 beams |
| >>> num_beams = 3 |
| >>> # define decoder start token ids |
| >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
| >>> input_ids = input_ids * model.config.decoder_start_token_id |
| |
| >>> # add encoder_outputs to model keyword arguments |
| >>> model_kwargs = { |
| ... "encoder_outputs": model.get_encoder()( |
| ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
| ... ) |
| ... } |
| |
| >>> constraint_str = "Sie" |
| >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token |
| >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] |
| |
| |
| >>> # instantiate beam scorer |
| >>> beam_scorer = ConstrainedBeamSearchScorer( |
| ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints |
| ... ) |
| |
| >>> # instantiate logits processors |
| >>> logits_processor = LogitsProcessorList( |
| ... [ |
| ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), |
| ... ] |
| ... ) |
| |
| >>> outputs = model.constrained_beam_search( |
| ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs |
| ... ) |
| |
| >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| ['Wie alt sind Sie?'] |
| ```""" |
| |
| logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
| stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
| if max_length is not None: |
| warnings.warn( |
| "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
| UserWarning, |
| ) |
| stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
| if len(stopping_criteria) == 0: |
| warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) |
| pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
| eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
| output_scores = output_scores if output_scores is not None else self.config.output_scores |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict_in_generate = ( |
| return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
| ) |
|
|
| |
| scores = () if (return_dict_in_generate and output_scores) else None |
| decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
| cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
| decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
| |
| if return_dict_in_generate and self.config.is_encoder_decoder: |
| encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
| encoder_hidden_states = ( |
| model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
| ) |
|
|
| batch_size = len(constrained_beam_scorer._beam_hyps) |
| num_beams = constrained_beam_scorer.num_beams |
|
|
| batch_beam_size, cur_len = input_ids.shape |
|
|
| if num_beams * batch_size != batch_beam_size: |
| raise ValueError( |
| f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
| ) |
|
|
| beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
| beam_scores[:, 1:] = -1e9 |
| beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
| this_peer_finished = False |
| while True: |
|
|
| if synced_gpus: |
| |
| |
| this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
| |
| dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
| |
| if this_peer_finished_flag.item() == 0.0: |
| break |
|
|
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
| outputs = self( |
| **model_inputs, |
| return_dict=True, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| ) |
|
|
| if synced_gpus and this_peer_finished: |
| cur_len = cur_len + 1 |
| continue |
|
|
| next_token_logits = outputs.logits[:, -1, :] |
| |
| |
| next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
| next_token_scores = nn.functional.log_softmax( |
| next_token_logits, dim=-1 |
| ) |
|
|
| next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
|
| scores_for_all_vocab = next_token_scores_processed.clone() |
|
|
| next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
| |
| if return_dict_in_generate: |
| if output_scores: |
| scores += (next_token_scores,) |
| if output_attentions: |
| decoder_attentions += ( |
| (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
| ) |
| if self.config.is_encoder_decoder: |
| cross_attentions += (outputs.cross_attentions,) |
|
|
| if output_hidden_states: |
| decoder_hidden_states += ( |
| (outputs.decoder_hidden_states,) |
| if self.config.is_encoder_decoder |
| else (outputs.hidden_states,) |
| ) |
|
|
| |
| vocab_size = next_token_scores.shape[-1] |
| next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
| next_token_scores, next_tokens = torch.topk( |
| next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True |
| ) |
|
|
| next_indices = (next_tokens / vocab_size).long() |
| next_tokens = next_tokens % vocab_size |
|
|
| |
| beam_outputs = constrained_beam_scorer.process( |
| input_ids, |
| next_token_scores, |
| next_tokens, |
| next_indices, |
| scores_for_all_vocab, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| ) |
| beam_scores = beam_outputs["next_beam_scores"] |
| beam_next_tokens = beam_outputs["next_beam_tokens"] |
| beam_idx = beam_outputs["next_beam_indices"] |
|
|
| input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
| model_kwargs = self._update_model_kwargs_for_generation( |
| outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
| ) |
| if model_kwargs["past"] is not None: |
| model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) |
|
|
| |
| cur_len = cur_len + 1 |
|
|
| if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): |
| if not synced_gpus: |
| break |
| else: |
| this_peer_finished = True |
|
|
| sequence_outputs = constrained_beam_scorer.finalize( |
| input_ids, |
| beam_scores, |
| next_tokens, |
| next_indices, |
| pad_token_id=pad_token_id, |
| eos_token_id=eos_token_id, |
| max_length=stopping_criteria.max_length, |
| ) |
|
|
| if return_dict_in_generate: |
| if not output_scores: |
| sequence_outputs["sequence_scores"] = None |
| if self.config.is_encoder_decoder: |
| return BeamSearchEncoderDecoderOutput( |
| sequences=sequence_outputs["sequences"], |
| sequences_scores=sequence_outputs["sequence_scores"], |
| scores=scores, |
| encoder_attentions=encoder_attentions, |
| encoder_hidden_states=encoder_hidden_states, |
| decoder_attentions=decoder_attentions, |
| cross_attentions=cross_attentions, |
| decoder_hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return BeamSearchDecoderOnlyOutput( |
| sequences=sequence_outputs["sequences"], |
| sequences_scores=sequence_outputs["sequence_scores"], |
| scores=scores, |
| attentions=decoder_attentions, |
| hidden_states=decoder_hidden_states, |
| ) |
| else: |
| return sequence_outputs["sequences"] |
|
|
|
|
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering). Values <= 0 skip the
            top-k step.
        top_p (`float`, *optional*, defaults to 1.0):
            Only keep the top tokens with cumulative probability >= top_p (nucleus filtering). The nucleus warper
            is applied whenever `0 <= top_p <= 1.0`; any value outside that range skips the nucleus step. Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            Value handed to both warpers as their `filter_value`, so the top-k and top-p steps mask filtered
            entries consistently.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    Returns:
        The logits after every enabled filtering step has been applied. When both steps are skipped the input
        object is returned unchanged.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k_warper = TopKLogitsWarper(
            top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep
        )
        # The warpers take `(input_ids, scores)`; no input ids are available here, so `None` is passed.
        logits = top_k_warper(None, logits)

    if 0 <= top_p <= 1.0:
        # Forward `filter_value` here as well; previously only the top-k step honoured a custom value.
        top_p_warper = TopPLogitsWarper(
            top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep
        )
        logits = top_p_warper(None, logits)

    return logits
|
|