File size: 4,851 Bytes
a9bd396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import unittest
from unittest.mock import patch

from transformers.testing_utils import require_kernels


@require_kernels
class HubKernelsTests(unittest.TestCase):
    """Check that the ``USE_HUB_KERNELS`` environment variable gates hub kernel usage.

    ``transformers.integrations.hub_kernels`` is reloaded inside each test so that it
    re-reads the (patched) environment. A cleanup reloads it once more after each test,
    once the environment and any ``kernels`` mocks have been restored, so module-level
    state changed by one test does not leak into later tests.
    """

    def _reload_hub_kernels(self):
        """Reload ``hub_kernels`` under the current environment and return the module.

        Also registers a cleanup that reloads the module again after the test finishes.
        """
        import importlib

        from transformers.integrations import hub_kernels

        # Cleanups run after the test method has returned, i.e. after `patch.dict`
        # contexts and `@patch` decorators have restored os.environ and `kernels`.
        self.addCleanup(importlib.reload, hub_kernels)
        return importlib.reload(hub_kernels)

    def test_disable_hub_kernels(self):
        """
        Test that _kernels_enabled is False when USE_HUB_KERNELS=OFF
        """
        with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}):
            hub_kernels = self._reload_hub_kernels()

            # Verify that kernels are disabled
            self.assertFalse(hub_kernels._kernels_enabled)

    def test_enable_hub_kernels_default(self):
        """
        Test that _kernels_enabled is True when USE_HUB_KERNELS is not provided (default behavior)
        """
        with patch.dict(os.environ):
            # Unset USE_HUB_KERNELS; patch.dict restores the original environment on exit.
            os.environ.pop("USE_HUB_KERNELS", None)
            hub_kernels = self._reload_hub_kernels()

            # Verify that kernels are enabled by default
            self.assertTrue(hub_kernels._kernels_enabled)

    def test_enable_hub_kernels_on(self):
        """
        Test that _kernels_enabled is True when USE_HUB_KERNELS=ON
        """
        with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
            hub_kernels = self._reload_hub_kernels()

            # Verify that kernels are enabled
            self.assertTrue(hub_kernels._kernels_enabled)

    @patch("kernels.use_kernel_forward_from_hub")
    def test_use_kernel_forward_from_hub_not_called_when_disabled(self, mocked_use_kernel_forward):
        """
        Test that kernels.use_kernel_forward_from_hub is not called when USE_HUB_KERNELS is disabled
        """
        with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}):
            # Reload while the mock is active so hub_kernels resolves the patched function.
            hub_kernels = self._reload_hub_kernels()

            # Call the function with a test layer name
            decorator = hub_kernels.use_kernel_forward_from_hub("DummyLayer")

            # Verify that the kernels function was never called
            mocked_use_kernel_forward.assert_not_called()

            # Verify that we get a no-op decorator
            class FooClass:
                pass

            result = decorator(FooClass)
            self.assertIs(result, FooClass)

    @patch("kernels.use_kernel_forward_from_hub")
    def test_use_kernel_forward_from_hub_called_when_enabled_default(self, mocked_use_kernel_forward):
        """
        Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS is not set (default)
        """
        with patch.dict(os.environ):
            # Unset USE_HUB_KERNELS; patch.dict restores the original environment on exit.
            os.environ.pop("USE_HUB_KERNELS", None)
            hub_kernels = self._reload_hub_kernels()

            # Call the function with a test layer name
            hub_kernels.use_kernel_forward_from_hub("FooLayer")

            # Verify that the kernels function was called once with the correct argument
            mocked_use_kernel_forward.assert_called_once_with("FooLayer")

    @patch("kernels.use_kernel_forward_from_hub")
    def test_use_kernel_forward_from_hub_called_when_enabled_on(self, mocked_use_kernel_forward):
        """
        Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS=ON
        """
        with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
            hub_kernels = self._reload_hub_kernels()

            # Call the function with a test layer name
            hub_kernels.use_kernel_forward_from_hub("FooLayer")

            # Verify that the kernels function was called once with the correct argument
            mocked_use_kernel_forward.assert_called_once_with("FooLayer")