File size: 981 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import unittest

from transformers.testing_utils import Expectations


class ExpectationsTest(unittest.TestCase):
    """Checks which value ``Expectations.find_expectation`` returns for a key."""

    def test_expectations(self):
        """Look up several ``(device_type, major)`` keys and verify the result.

        Also verifies that the lookup raises ``ValueError`` when the table
        only holds a ``("cuda", 8)`` entry and the key is ``("xpu", None)``.
        """
        # NOTE(review): some lookups below expect a value other than the
        # entry that looks like the closest match (e.g. ("rocm", None)
        # expects 4 although a ("rocm", None): 5 entry exists), so the result
        # presumably depends on the order of these entries -- keep the order
        # unchanged and confirm against Expectations.find_expectation.
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
                ("xpu", 3): 7,
            }
        )

        # (expected value, lookup key) pairs.
        cases = [
            # npu has no matches so should find default expectation
            (1, ("npu", None)),
            (7, ("xpu", 3)),
            (2, ("cuda", 8)),
            (3, ("cuda", 7)),
            (4, ("rocm", 9)),
            (4, ("rocm", None)),
            (2, ("cuda", 2)),
        ]
        for expected, key in cases:
            # assertEqual (rather than a bare ``assert``) keeps the check
            # active under ``python -O`` and names the failing key.
            self.assertEqual(
                expectations.find_expectation(key),
                expected,
                msg=f"unexpected expectation for key {key!r}",
            )

        expectations = Expectations({("cuda", 8): 1})
        with self.assertRaises(ValueError):
            expectations.find_expectation(("xpu", None))