# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device


if is_torch_available():
    import torch

    from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
    from transformers.generation_beam_search import BeamSearchScorer
    from transformers.generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        HammingDiversityLogitsProcessor,
        InfNanRemoveLogitsProcessor,
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
    from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
    from transformers.generation_utils import (
        BeamSampleDecoderOnlyOutput,
        BeamSampleEncoderDecoderOutput,
        BeamSearchDecoderOnlyOutput,
        BeamSearchEncoderDecoderOutput,
        GreedySearchDecoderOnlyOutput,
        GreedySearchEncoderDecoderOutput,
        SampleDecoderOnlyOutput,
        SampleEncoderDecoderOutput,
    )

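# Mixin that adds `generate()` consistency tests to a model's test class. The
# host class supplies `model_tester` (which builds configs and dummy inputs) and
# `all_generative_model_classes`. A minimal sketch of the intended usage (the
# model names here are illustrative, not taken from this file):
#
#   @require_torch
#   class MyModelTest(GenerationTesterMixin, unittest.TestCase):
#       all_generative_model_classes = (MyModelForCausalLM,) if is_torch_available() else ()
#       model_tester = ...  # set up in setUp() of the real test classes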
class GenerationTesterMixin:
    model_tester = None
    all_generative_model_classes = ()
    input_name = "input_ids"

    def _get_input_ids_and_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_ids = inputs_dict[self.input_name]
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # cut to half length & take a max batch_size of 2
        max_batch_size = 2
        sequence_length = input_ids.shape[-1] // 2
        input_ids = input_ids[:max_batch_size, :sequence_length]
        attention_mask = attention_mask[:max_batch_size, :sequence_length]

        # generate max 3 tokens
        max_length = input_ids.shape[-1] + 3
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id
        return config, input_ids, attention_mask, max_length

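    # Builds a `LogitsProcessorList` and the matching `generate()` kwargs so the
    # same constraints can be exercised through both code paths.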
    @staticmethod
    def _get_logits_processor_and_kwargs(
        input_length,
        eos_token_id,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        max_length=None,
        diversity_penalty=None,
    ):
        process_kwargs = {
            "min_length": input_length + 1,
            "bad_words_ids": [[1, 0]],
            "no_repeat_ngram_size": 2,
            "repetition_penalty": 1.2,
        }
        logits_processor = LogitsProcessorList(
            (
                [
                    HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
                ]
                if diversity_penalty is not None
                else []
            )
            + (
                [
                    MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
                ]
                if eos_token_id is not None
                else []
            )
            + (
                [
                    ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
                ]
                if forced_bos_token_id is not None
                else []
            )
            + (
                [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
                if forced_eos_token_id is not None
                else []
            )
            + [
                NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
                NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
                RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
            ]
        )
        return process_kwargs, logits_processor

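    # Warpers reshape the sampling distribution; `min_tokens_to_keep=2` under beam
    # search guarantees every beam keeps at least two candidates to expand.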
    @staticmethod
    def _get_warper_and_kwargs(num_beams):
        warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
        logits_warper = LogitsProcessorList(
            [
                TemperatureLogitsWarper(warp_kwargs["temperature"]),
                TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
                TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
            ]
        )
        return warp_kwargs, logits_warper

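    # The next two helpers mirror each other: one configures plain beam search,
    # the other its diverse (grouped) variant.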
    @staticmethod
    def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": 2,
            "num_return_sequences": num_return_sequences,
        }
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=beam_kwargs["num_beams"],
            device=torch_device,
            length_penalty=beam_kwargs["length_penalty"],
            do_early_stopping=beam_kwargs["early_stopping"],
            num_beam_hyps_to_keep=num_return_sequences,
        )
        return beam_kwargs, beam_scorer

    @staticmethod
    def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": 2,
            "num_return_sequences": num_return_sequences,
            "num_beam_groups": 2,
            "diversity_penalty": 2.0,
        }
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=beam_kwargs["num_beams"],
            device=torch_device,
            length_penalty=beam_kwargs["length_penalty"],
            do_early_stopping=beam_kwargs["early_stopping"],
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=beam_kwargs["num_beam_groups"],
        )
        return beam_kwargs, beam_scorer

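    # Runs the encoder once up front so the bare decoding methods (which, unlike
    # `generate()`, do not call the encoder themselves) can be invoked directly.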
    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
    ):
        encoder = model.get_encoder()
        encoder_outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
            num_interleave, dim=0
        )
        input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
        attention_mask = None
        return encoder_outputs, input_ids, attention_mask

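    # Runs greedy decoding twice, through `generate()` and through
    # `greedy_search()` with identical constraints, and returns both outputs.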
    def _greedy_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        if model.config.is_encoder_decoder:
            max_length = 4
        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            eos_token_id=model.config.eos_token_id,
            forced_bos_token_id=model.config.forced_bos_token_id,
            forced_eos_token_id=model.config.forced_eos_token_id,
            max_length=max_length,
        )

        kwargs = {}

        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_length=max_length,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **logits_process_kwargs,
        )

        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs

        with torch.no_grad():
            output_greedy = model.greedy_search(
                input_ids,
                max_length=max_length,
                attention_mask=attention_mask,
                logits_processor=logits_processor,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_greedy, output_generate

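    # Same pattern for multinomial sampling; seeding with `torch.manual_seed(0)`
    # before each call makes `generate()` and `sample()` draw the same tokens.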
    def _sample_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        num_return_sequences,
        logits_processor,
        logits_warper,
        logits_warper_kwargs,
        process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        torch.manual_seed(0)
        output_generate = model.generate(
            input_ids,
            do_sample=True,
            num_beams=1,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            attention_mask=attention_mask,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **logits_warper_kwargs,
            **process_kwargs,
        )

        torch.manual_seed(0)
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=num_return_sequences,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
        else:
            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)

        # prevent flaky generation test failures
        logits_processor.append(InfNanRemoveLogitsProcessor())

        with torch.no_grad():
            output_sample = model.sample(
                input_ids_clone,
                attention_mask=attention_mask_clone,
                max_length=max_length,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_sample, output_generate

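    # Beam search driver. Note the return order is (generate, beam_search), the
    # reverse of the greedy/sample helpers above.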
    def _beam_search_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        beam_scorer,
        beam_kwargs,
        logits_processor,
        logits_process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_process_kwargs,
        )

        # beam_search does not automatically interleave the batch dim for `num_beams`
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
        else:
            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)

        with torch.no_grad():
            output_beam_search = model.beam_search(
                input_ids_clone,
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask_clone,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_generate, output_beam_search

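    # Beam sampling: seeded like `_sample_generate`, interleaved like
    # `_beam_search_generate`, but with `num_beams * num_return_sequences` copies.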
    def _beam_sample_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        num_return_sequences,
        beam_scorer,
        beam_kwargs,
        logits_warper,
        logits_warper_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        torch.manual_seed(0)
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_warper_kwargs,
        )
        # beam_sample does not automatically interleave the batch dim for `num_beams * num_return_sequences`
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams * num_return_sequences,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
        else:
            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)

        # prevent flaky generation test failures
        logits_processor = LogitsProcessorList()
        logits_processor.append(InfNanRemoveLogitsProcessor())

        torch.manual_seed(0)
        with torch.no_grad():
            output_beam_sample = model.beam_sample(
                input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )

        return output_generate, output_beam_sample

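    # Diverse (grouped) beam search driver, same contract as `_beam_search_generate`.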
    def _group_beam_search_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        beam_scorer,
        beam_kwargs,
        logits_processor,
        logits_process_kwargs,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        output_generate = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_length=max_length,
            output_scores=output_scores,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **beam_kwargs,
            **logits_process_kwargs,
        )

        # group_beam_search does not automatically interleave the batch dim for `num_beams`
        kwargs = {}
        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                num_interleave=beam_scorer.num_beams,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs
            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
        else:
            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)

        with torch.no_grad():
            output_group_beam_search = model.group_beam_search(
                input_ids_clone,
                beam_scorer,
                max_length=max_length,
                attention_mask=attention_mask_clone,
                logits_processor=logits_processor,
                output_scores=output_scores,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
            )
        return output_generate, output_group_beam_search

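    # The tests proper: each decoding strategy is checked for exact agreement
    # between `generate()` and the corresponding low-level method.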
    def test_greedy_generate(self):
        # check `generate()` and `greedy_search()` are equal
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
            )
            self.assertListEqual(output_greedy.tolist(), output_generate.tolist())

    def test_greedy_generate_dict_outputs(self):
        for model_class in self.all_generative_model_classes:
            # disable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            config.use_cache = False
            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
                self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
                self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())

            for output in (output_greedy, output_generate):
                self._check_outputs(output, input_ids, model.config)

    def test_greedy_generate_dict_outputs_use_cache(self):
        for model_class in self.all_generative_model_classes:
            # enable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            if not hasattr(config, "use_cache"):
                # only relevant if model has "use_cache"
                return

            config.use_cache = True
            config.is_decoder = True
            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())

            for output in (output_greedy, output_generate):
                self._check_outputs(output, input_ids, model.config, use_cache=True)

    def test_sample_generate(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            model = model_class(config).to(torch_device).eval()

            if model.config.is_encoder_decoder:
                max_length = 4

            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                model.config.eos_token_id,
                forced_bos_token_id=model.config.forced_bos_token_id,
                forced_eos_token_id=model.config.forced_eos_token_id,
                max_length=max_length,
            )
            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            # check `generate()` and `sample()` are equal
            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=1,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
            self.assertListEqual(output_sample.tolist(), output_generate.tolist())

            # check `generate()` and `sample()` yield equal results for `num_return_sequences`
            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=3,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
            self.assertListEqual(output_sample.tolist(), output_generate.tolist())

    def test_sample_generate_dict_output(self):
        for model_class in self.all_generative_model_classes:
            # disable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            config.use_cache = False
            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                model.config.eos_token_id,
                forced_bos_token_id=model.config.forced_bos_token_id,
                forced_eos_token_id=model.config.forced_eos_token_id,
                max_length=max_length,
            )
            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            output_sample, output_generate = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=2,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
                self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
                self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist())

            for output in (output_sample, output_generate):
                self._check_outputs(output, input_ids, model.config, num_return_sequences=2)

    def test_beam_search_generate(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # It is important to set `eos_token_id` to None here: if sequences could
            # finish early, beams of different lengths would be compared and the
            # equality check between `generate()` and `beam_search()` could fail.
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
            )
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

            # check `generate()` and `beam_search()` are equal
            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

            # check `generate()` and `beam_search()` are equal for `num_return_sequences`
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )

            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

    def test_beam_search_generate_dict_output(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # disable cache
            config.use_cache = False

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
            )
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
            output_generate, output_beam_search = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
            self.assertTrue(
                torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
            )
            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

            for output in (output_beam_search, output_generate):
                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)

    def test_beam_search_generate_dict_outputs_use_cache(self):
        for model_class in self.all_generative_model_classes:
            # enable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            if not hasattr(config, "use_cache"):
                # only relevant if model has "use_cache"
                return

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
            )

            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

            config.use_cache = True
            config.is_decoder = True
            model = model_class(config).to(torch_device).eval()
            output_beam, output_generate = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist())

            for output in (output_beam, output_generate):
                self._check_outputs(
                    output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
                )

    def test_beam_sample_generate(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            model = model_class(config).to(torch_device).eval()

            # check `generate()` and `beam_sample()` are equal;
            # `num_return_sequences` is folded into the scorer's batch size below
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0] * num_return_sequences, max_length
            )
            beam_kwargs["num_return_sequences"] = num_return_sequences

            output_generate, output_beam_sample = self._beam_sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())

    def test_beam_sample_generate_dict_output(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # disable cache
            config.use_cache = False

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0] * num_return_sequences, max_length
            )
            beam_kwargs["num_return_sequences"] = num_return_sequences

            # note: the helper returns (generate, beam_sample), in that order
            output_generate, output_beam_sample = self._beam_sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
                self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
                self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
            self.assertTrue(
                torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3)
            )
            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

            for output in (output_beam_sample, output_generate):
                self._check_outputs(
                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
                )

    def test_generate_without_input_ids(self):
        config, _, _, max_length = self._get_input_ids_and_config()

        # if there is no bos token id => cannot generate from None
        if config.bos_token_id is None:
            return

        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            output_ids_generate = model.generate(
                do_sample=False,
                max_length=max_length,
                remove_invalid_values=True,
            )

            self.assertIsNotNone(output_ids_generate)

    def test_group_beam_search_generate(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
                diversity_penalty=2.0,
            )

            # check `generate()` and `group_beam_search()` are equal
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
            output_generate, output_group_beam_search = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())

            # check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
            num_return_sequences = 2
            if model.config.is_encoder_decoder:
                max_length = 4
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )
            output_generate, output_group_beam_search = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
            self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())

    def test_group_beam_search_generate_dict_output(self):
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            config.use_cache = False

            # It is important to set `eos_token_id` to None here (see
            # `test_beam_search_generate` for why).
            config.eos_token_id = None
            config.forced_eos_token_id = None

            model = model_class(config).to(torch_device).eval()
            if model.config.is_encoder_decoder:
                max_length = 4

            logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                input_ids.shape[-1],
                config.eos_token_id,
                config.forced_bos_token_id,
                config.forced_eos_token_id,
                max_length,
                diversity_penalty=2.0,
            )

            num_return_sequences = 1
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )
            output_generate, output_group_beam_search = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

            self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())
            self.assertTrue(
                torch.allclose(
                    output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3
                )
            )
            self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
            self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

            for output in (output_group_beam_search, output_generate):
                self._check_outputs(
                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
                )

    def test_generate_with_head_masking(self):
        """Test designed for encoder-decoder models to ensure the attention head masking is used."""
        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
        for model_class in self.all_generative_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            model = model_class(config).to(torch_device)
            # We want to test only encoder-decoder models
            if not config.is_encoder_decoder:
                continue

            head_masking = {
                "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
                "decoder_head_mask": torch.zeros(
                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
                ),
                "cross_attn_head_mask": torch.zeros(
                    config.decoder_layers, config.decoder_attention_heads, device=torch_device
                ),
            }

            signature = inspect.signature(model.forward)
            # We want to test only models where encoder/decoder head masking is implemented
            if not set(head_masking.keys()) < set(signature.parameters.keys()):
                continue

            for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
                out = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    num_beams=1,
                    output_attentions=True,
                    return_dict_in_generate=True,
                    remove_invalid_values=True,
                    **{name: mask},
                )
                # We check the state of decoder_attentions and cross_attentions just from the last step
                attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
                self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)

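    # Shared shape and type checks for the dict-style outputs returned when
    # `return_dict_in_generate=True`.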
    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, seq_length = input_ids.shape
        num_sequences_in_output = batch_size * num_return_sequences
        gen_len = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
        )

        # scores
        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

        # Attentions
        if config.is_encoder_decoder:
            # encoder
            self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
            # decoder
            self._check_attentions_for_generate(
                num_sequences_in_output,
                output.decoder_attentions,
                min_length=1,
                max_length=output.sequences.shape[-1],
                config=config,
                use_cache=use_cache,
            )
        else:
            # the first forward pass processes the full prompt whether or not caching
            # is used, so skip it here when the cache is on
            attentions = output.attentions if not use_cache else output.attentions[1:]
            min_length = seq_length if not use_cache else seq_length + 1
            self._check_attentions_for_generate(
                num_sequences_in_output,
                attentions=attentions,
                min_length=min_length,
                max_length=output.sequences.shape[-1],
                config=config,
                use_cache=use_cache,
            )

        # Hidden states
        if config.is_encoder_decoder:
            # encoder
            self._check_encoder_hidden_states_for_generate(
                output.encoder_hidden_states, batch_size, config, seq_length
            )

            # decoder
            self._check_hidden_states_for_generate(
                num_sequences_in_output,
                output.decoder_hidden_states,
                min_length=1,
                max_length=output.sequences.shape[-1],
                config=config,
                use_cache=use_cache,
            )
        else:
            # same caching caveat as for the attentions above
            hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
            min_length = seq_length if not use_cache else seq_length + 1
            self._check_hidden_states_for_generate(
                num_sequences_in_output,
                hidden_states,
                min_length=min_length,
                max_length=output.sequences.shape[-1],
                config=config,
                use_cache=use_cache,
            )

    def _check_scores(self, batch_size, scores, length, config):
        expected_shape = (batch_size, config.vocab_size)
        self.assertIsInstance(scores, tuple)
        self.assertEqual(len(scores), length)
        self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            tgt_len = min_length + idx if not use_cache else 1
            src_len = min_length + idx

            expected_shape = (
                batch_size * num_beam_groups,
                config.num_attention_heads,
                tgt_len,
                src_len,
            )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
        encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [layer_attentions.shape for layer_attentions in attentions],
            [encoder_expected_shape] * len(attentions),
        )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            seq_len = min_length + idx if not use_cache else 1
            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )

    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
            [encoder_expected_shape] * len(hidden_states),
        )


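# Unit test for the standalone filtering helper. A rough sketch of the behavior
# under test (shapes below are illustrative, not from this file):
#
#   logits = torch.randn(batch_size, vocab_size)
#   filtered = top_k_top_p_filtering(logits, top_k=10, top_p=0.6)
#   # entries outside the top-k / top-p nucleus are set to -inf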
@require_torch
class UtilsFunctionsTest(unittest.TestCase):

    # tests whether the top_k_top_p function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = torch.tensor(
            [
                [
                    8.2220991, -0.5620044, 5.23229752, 4.0386393, -6.8798378, -0.54785802, -3.2012153, 2.92777176,
                    1.88171953, 7.35341276, 8.43207833, -9.85711836, -5.96209236, -1.13039161, -7.1115294, -0.8369633,
                    -5.3186408, 7.06427407, 0.81369344, -0.82023817, -5.9179796, 0.58813443, -6.99778438, 4.71551189,
                    -0.18771637, 7.44020759, 9.38450987, 2.12662941, -9.32562038, 2.35652522,
                ],
                [
                    0.58425518, 4.53139238, -5.57510464, -6.28030699, -7.19529503, -4.02122551, 1.39337037,
                    -6.06707057, 1.59480517, -9.643119, 0.03907799, 0.67231762, -8.88206726, 6.27115922, 2.28520723,
                    4.82767506, 4.30421368, 8.8275313, 5.44029958, -4.4735794, 7.38579536, -2.91051663, 2.61946077,
                    -2.5674762, -9.48959302, -4.02922645, -1.35416918, 9.67702323, -5.89478553, 1.85370467,
                ],
            ],
            dtype=torch.float,
            device=torch_device,
        )

        # indices of the logits that must survive the top-k/top-p filter
        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )

        # the surviving values themselves (rounded)
        non_inf_expected_output = torch.tensor(
            [8.2221, 8.4321, 7.4402, 9.3845, 6.2712, 8.8275, 7.3858, 9.6770],
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        non_inf_output = output[output != -float("inf")].to(device=torch_device)
        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))


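# End-to-end `generate()` tests against BART checkpoints: one full-size @slow
# test plus fast tests on a tiny random checkpoint, largely covering deprecated
# `max_length` call patterns.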
@require_torch
class GenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_diverse_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""

        bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        outputs = bart_model.generate(
            input_ids,
            num_beams=4,
            num_return_sequences=2,
            num_beam_groups=4,
            diversity_penalty=2.0,
            remove_invalid_values=True,
        )

        generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        self.assertListEqual(
            generated_text,
            [
                "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
                "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
            ],
        )

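    # Passing `max_length` directly to the decoding methods is deprecated in
    # favor of `StoppingCriteria`; the tests below assert that the old call
    # pattern still works but emits a `UserWarning`.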
    def test_max_length_backward_compat_greedy(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        max_length = 20
        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        with self.assertWarns(UserWarning):
            bart_model.greedy_search(
                input_ids,
                max_length=max_length,
                pad_token_id=bart_model.config.pad_token_id,
                eos_token_id=bart_model.config.eos_token_id,
                **model_kwargs,
            )

    def test_max_length_backward_compat_sample(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        max_length = 20
        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )
        with torch.no_grad():
            with self.assertWarns(UserWarning):
                bart_model.sample(
                    input_ids,
                    max_length=max_length,
                    pad_token_id=bart_model.config.pad_token_id,
                    eos_token_id=bart_model.config.eos_token_id,
                    **model_kwargs,
                )

    def test_max_length_backward_compat_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1
        max_length = 20
        num_beams = 2

        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
        )
        with self.assertWarns(UserWarning):
            _ = bart_model.beam_search(
                input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
            )

    def test_max_length_backward_compat_group_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1
        max_length = 20
        num_beams = 6
        num_beam_groups = 3
        num_return_sequences = num_beams * batch_size

        input_ids = input_ids.expand(6, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        diverse_beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        with self.assertWarns(UserWarning):
            bart_model.group_beam_search(
                input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
            )

    def test_max_length_warning_if_different(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1

        max_length = 20
        num_beams = 6
        num_beam_groups = 3
        num_return_sequences = num_beams * batch_size
        stopping_criteria_max_length = 18
        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])

        # Greedy
        input_ids = input_ids.expand(6, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        with self.assertWarns(UserWarning):
            bart_model.greedy_search(
                input_ids,
                max_length=max_length,
                pad_token_id=bart_model.config.pad_token_id,
                stopping_criteria=stopping_criteria,
                eos_token_id=bart_model.config.eos_token_id,
                **model_kwargs,
            )

        # Sample
        with self.assertWarns(UserWarning):
            with torch.no_grad():
                bart_model.sample(
                    input_ids,
                    max_length=max_length,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=bart_model.config.pad_token_id,
                    eos_token_id=bart_model.config.eos_token_id,
                    **model_kwargs,
                )

        # Beam
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
        )
        with self.assertWarns(UserWarning):
            with torch.no_grad():
                bart_model.beam_search(
                    input_ids,
                    num_beams=num_beams,
                    stopping_criteria=stopping_criteria,
                    max_length=max_length,
                    beam_scorer=beam_scorer,
                    **model_kwargs,
                )

        # Grouped beam search
        diverse_beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        with self.assertWarns(UserWarning):
            bart_model.group_beam_search(
                input_ids,
                diverse_beam_scorer,
                stopping_criteria=stopping_criteria,
                num_beams=num_beams,
                max_length=max_length,
                **model_kwargs,
            )

    def test_beam_search_warning_if_max_length_is_passed(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

        batch_size = 1
        num_beams = 3

        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
        input_ids = input_ids.expand(num_beams, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})

        stopping_criteria_max_length = 18
        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])

        with self.assertWarns(UserWarning):
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=torch_device,
                max_length=10,
            )

        generated_ids = bart_model.beam_search(
            input_ids,
            num_beams=num_beams,
            stopping_criteria=stopping_criteria,
            beam_scorer=beam_scorer,
            **model_kwargs,
        )

        beam_scorer_no_max_len = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=torch_device,
        )

        generated_ids_no_max_len = bart_model.beam_search(
            input_ids,
            num_beams=num_beams,
            stopping_criteria=stopping_criteria,
            beam_scorer=beam_scorer_no_max_len,
            **model_kwargs,
        )

        # `max_length` passed to BeamSearchScorer should not influence the "real" max_length
        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())

    def test_max_new_tokens(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        self.assertEqual(list(input_ids.shape), [1, 15])

        # Encoder-decoder call
        max_new_tokens = 3
        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
        # 1 BOS token + 3 new tokens
        self.assertEqual(list(outputs.shape), [1, 4])

        # Decoder-only call
        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
        # 15 prompt tokens + 3 new tokens
        self.assertEqual(list(outputs.shape), [1, 18])

        # `max_new_tokens` and `max_length` serve the same purpose and must not be used together
        with self.assertWarns(UserWarning):
            outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)