| | import unittest |
| |
|
| | from transformers.testing_utils import Expectations |
| |
|
| |
|
| | class ExpectationsTest(unittest.TestCase): |
| | def test_expectations(self): |
| | |
| | |
| | expectations = Expectations( |
| | { |
| | (None, None): 1, |
| | ("cuda", 8): 2, |
| | ("cuda", 7): 3, |
| | ("rocm", 8): 4, |
| | ("rocm", None): 5, |
| | ("cpu", None): 6, |
| | ("xpu", 3): 7, |
| | } |
| | ) |
| |
|
| | def check(expected_id, device_prop): |
| | found_id = expectations.find_expectation(device_prop) |
| | assert found_id == expected_id, f"Expected {expected_id} for {device_prop}, found {found_id}" |
| |
|
| | |
| | check(1, ("npu", None, None)) |
| | check(7, ("xpu", 3, None)) |
| | check(2, ("cuda", 8, None)) |
| | check(3, ("cuda", 7, None)) |
| | check(4, ("rocm", 9, None)) |
| | check(4, ("rocm", None, None)) |
| | check(2, ("cuda", 2, None)) |
| |
|
| | |
| | expectations = Expectations({("cuda", 8): 1}) |
| | with self.assertRaises(ValueError): |
| | expectations.find_expectation(("xpu", None)) |
| |
|