| import unittest | |
| from transformers.testing_utils import Expectations | |
class ExpectationsTest(unittest.TestCase):
    """Tests for how `Expectations.find_expectation` selects an entry for a device key."""

    def test_expectations(self):
        """Check the value `find_expectation` returns for a range of keys.

        Keys are 2-tuples of (device type, major version). ``(None, None)``
        is used here as the catch-all default entry.
        """
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
                ("xpu", 3): 7,
            }
        )

        def check(value, key):
            # assertEqual instead of a bare `assert`: bare asserts are removed
            # under `python -O`, which would make these checks pass vacuously.
            # The message names the key that was looked up.
            self.assertEqual(
                expectations.find_expectation(key),
                value,
                msg=f"find_expectation({key!r})",
            )

        # npu has no matches so should find default expectation
        check(1, ("npu", None))
        # Exact key matches.
        check(7, ("xpu", 3))
        check(2, ("cuda", 8))
        check(3, ("cuda", 7))
        # No rocm entry has major 9; both ("rocm", 8) and ("rocm", None) share
        # the device type, and the test pins that ("rocm", 8) is chosen.
        # NOTE(review): presumably a tie broken by dict insertion order -- confirm
        # against Expectations' scoring.
        check(4, ("rocm", 9))
        # An exact key ("rocm", None) -> 5 exists, yet 4 is expected: a None
        # major is apparently not treated as matching. NOTE(review): confirm
        # this is intended library behavior rather than an accident.
        check(4, ("rocm", None))
        # No cuda entry has major 2; the test pins that ("cuda", 8) wins over
        # ("cuda", 7) (presumably insertion-order tie-break -- confirm).
        check(2, ("cuda", 2))

        # Separate name so the `check` closure above keeps referring to the
        # full mapping rather than silently picking up this one.
        cuda_only = Expectations({("cuda", 8): 1})
        # No entry shares the "xpu" device type and there is no (None, None)
        # default, so the lookup is expected to raise.
        with self.assertRaises(ValueError):
            cuda_only.find_expectation(("xpu", None))