| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import tempfile |
| import unittest |
|
|
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| from transformers.testing_utils import ( |
| is_torch_available, |
| require_optimum, |
| require_torch, |
| slow, |
| ) |
|
|
|
|
| if is_torch_available(): |
| import torch |
|
|
|
|
@require_torch
@require_optimum
@slow
class BetterTransformerIntegrationTest(unittest.TestCase):
    """Integration tests for converting a model to and from BetterTransformer.

    Exercises ``to_bettertransformer`` / ``reverse_bettertransformer`` on a small seq2seq
    checkpoint and checks how BetterTransformer mode interacts with ``save_pretrained``.
    """

    # Checkpoint shared by every test in this class (previously duplicated in each test).
    MODEL_ID = "hf-internal-testing/tiny-random-t5"

    @staticmethod
    def _has_bettertransformer_modules(model):
        """Return True if any submodule of ``model`` has "BetterTransformer" in its class name."""
        return any("BetterTransformer" in module.__class__.__name__ for _, module in model.named_modules())

    def test_transform_and_reverse(self):
        r"""
        Classic tests to simply check if the conversion has been successful.

        Converts the model to BetterTransformer, generates a reference output, reverts the
        conversion, then checks that a save/reload round trip yields a model without
        BetterTransformer modules that generates the same output.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.MODEL_ID)

        inp = tokenizer("This is me", return_tensors="pt")

        model = model.to_bettertransformer()
        self.assertTrue(self._has_bettertransformer_modules(model))

        # Reference output, produced while the model is in BetterTransformer mode.
        output = model.generate(**inp)

        model = model.reverse_bettertransformer()
        self.assertFalse(self._has_bettertransformer_modules(model))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # Reload while the temporary directory still exists, in case weights are
            # read lazily from disk.
            model_reloaded = AutoModelForSeq2SeqLM.from_pretrained(tmpdirname)
            self.assertFalse(self._has_bettertransformer_modules(model_reloaded))

            output_from_pretrained = model_reloaded.generate(**inp)
            torch.testing.assert_close(output, output_from_pretrained)

    def test_error_save_pretrained(self):
        r"""
        The save_pretrained method should raise a ValueError if the model is in BetterTransformer mode.
        All should be good if the model is reversed.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained(self.MODEL_ID)
        model = model.to_bettertransformer()

        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError):
                model.save_pretrained(tmpdirname)

            # After reverting, the BetterTransformer modules must be gone and saving
            # to the same directory must succeed.
            model = model.reverse_bettertransformer()
            self.assertFalse(self._has_bettertransformer_modules(model))
            model.save_pretrained(tmpdirname)
|
|