| |
| |
| |
| |
| @@ -27,7 +27,6 @@ |
| from huggingface_hub import hf_hub_download |
| from parameterized import parameterized |
| |
| -import transformers |
| from transformers import WhisperConfig |
| from transformers.testing_utils import ( |
| is_flaky, |
| @@ -41,7 +40,7 @@ |
| slow, |
| torch_device, |
| ) |
| -from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available |
| +from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available |
| from transformers.utils.import_utils import is_datasets_available |
| |
| from ...generation.test_utils import GenerationTesterMixin |
| @@ -1432,33 +1431,22 @@ def test_generate_compilation_all_outputs(self): |
| @require_torch |
| @require_torchaudio |
| class WhisperModelIntegrationTests(unittest.TestCase): |
| - def setUp(self): |
| - self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate |
| - |
| - def tearDown(self): |
| - transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate |
| - |
| - @cached_property |
| - def default_processor(self): |
| - return WhisperProcessor.from_pretrained("openai/whisper-base") |
| + _dataset = None |
| + |
| + @classmethod |
| + def _load_dataset(cls): |
| + # Lazy loading of the dataset. The result is cached on the class attribute `_dataset`, so it is loaded at most once per pytest process. |
| + if cls._dataset is None: |
| + cls._dataset = datasets.load_dataset( |
| + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" |
| + ) |
| |
| def _load_datasamples(self, num_samples): |
| - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
| - # automatic decoding with librispeech |
| + self._load_dataset() |
| + ds = self._dataset |
| speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] |
| - |
| return [x["array"] for x in speech_samples] |
| |
| - def _patch_generation_mixin_generate(self, check_args_fn=None): |
| - test = self |
| - |
| - def generate(self, *args, **kwargs): |
| - if check_args_fn is not None: |
| - check_args_fn(*args, **kwargs) |
| - return test._unpatched_generation_mixin_generate(self, *args, **kwargs) |
| - |
| - transformers.GenerationMixin.generate = generate |
| - |
| @slow |
| def test_tiny_logits_librispeech(self): |
| torch_device = "cpu" |
| @@ -1586,8 +1574,6 @@ def test_large_logits_librispeech(self): |
| |
| @slow |
| def test_tiny_en_generation(self): |
| - torch_device = "cpu" |
| - set_seed(0) |
| processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") |
| model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") |
| model.to(torch_device) |
| @@ -1605,8 +1591,6 @@ def test_tiny_en_generation(self): |
| |
| @slow |
| def test_tiny_generation(self): |
| - torch_device = "cpu" |
| - set_seed(0) |
| processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") |
| model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") |
| model.to(torch_device) |
| @@ -1623,8 +1607,6 @@ def test_tiny_generation(self): |
| |
| @slow |
| def test_large_generation(self): |
| - torch_device = "cpu" |
| - set_seed(0) |
| processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") |
| model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") |
| model.to(torch_device) |
| @@ -1643,7 +1625,6 @@ def test_large_generation(self): |
| |
| @slow |
| def test_large_generation_multilingual(self): |
| - set_seed(0) |
| processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") |
| model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") |
| model.to(torch_device) |
| @@ -1710,8 +1691,6 @@ def test_large_batched_generation(self): |
| |
| @slow |
| def test_large_batched_generation_multilingual(self): |
| - torch_device = "cpu" |
| - set_seed(0) |
| processor = WhisperProcessor.from_pretrained("openai/whisper-large") |
| model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large") |
| model.to(torch_device) |
| @@ -2727,11 +2706,6 @@ def test_whisper_longform_single_batch_beam(self): |
| "renormalize_logits": True, # necessary to match OAI beam search implementation |
| } |
| |
| - def check_gen_kwargs(inputs, generation_config, *args, **kwargs): |
| - self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"]) |
| - |
| - self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs) |
| - |
| torch.manual_seed(0) |
| result = model.generate(input_features, **gen_kwargs) |
| decoded = processor.batch_decode(result, skip_special_tokens=True) |
|
|