File size: 7,462 Bytes
59f1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# mypy: allow-untyped-defs

"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.



This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
It includes:

- A custom TestCase class that handles Dynamo-specific setup/teardown
- Test running utilities with dependency checking
- Automatic reset of Dynamo state between tests
- Proper handling of gradient mode state

"""

import contextlib
import importlib
import inspect
import logging
import os
import re
import sys
import unittest
from typing import Union

import torch
import torch.testing
from torch._logging._internal import trace_log
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


# Module-level logger; TestCase.tearDown uses it to warn when a test leaves
# grad mode changed.
log = logging.getLogger(__name__)


def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
    from torch.testing._internal.common_utils import run_tests

    if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF:
        return  # skip testing

    if (
        not torch.xpu.is_available()
        and IS_WINDOWS
        and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0"
    ):
        return

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda":
            if not torch.cuda.is_available():
                return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()


class TestCase(TorchTestCase):
    """Base test case for Dynamo tests.

    For the lifetime of the class, Dynamo ``config`` is patched with
    ``raise_on_ctx_manager_usage=True``, ``suppress_errors=False`` and
    ``log_compilation_metrics=False``.

    Around every test, Dynamo state and ``utils.counters`` are reset, a
    ``logging.NullHandler`` is attached to ``trace_log``, recorded counters are
    printed on teardown, and the grad mode active when the test started is
    restored (with a warning) if the test changed it.
    """

    # Holds the class-scoped config patch; closed in tearDownClass.
    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls) -> None:
        try:
            cls._exit_stack.close()
        finally:
            # Run the base-class teardown even if undoing the config patch fails.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            ),
        )

    def setUp(self) -> None:
        # Captured before anything else so tearDown can restore it.
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()
        self.handler = logging.NullHandler()
        trace_log.addHandler(self.handler)

    def tearDown(self) -> None:
        try:
            trace_log.removeHandler(self.handler)
            for name, counter in utils.counters.items():
                print(name, counter.most_common())
            reset()
            utils.counters.clear()
        finally:
            try:
                super().tearDown()
            finally:
                # Restore grad mode even if an earlier teardown step raised, so a
                # misbehaving test cannot leak its grad mode into later tests.
                if self._prior_is_grad_enabled != torch.is_grad_enabled():
                    log.warning("Running test changed grad mode")
                    torch.set_grad_enabled(self._prior_is_grad_enabled)


class CPythonTestCase(TestCase):
    """Test class for CPython tests located in ``test/dynamo/cpython/<X_YY>/*``.

    The ``<X_YY>`` path component (for example ``3_13``) is the minimum Python
    version the tests require. ``setUpClass`` skips the class on an older
    interpreter, or when no such component is present in the test file's path.

    This class enables specific features that are disabled by default, such as
    tracing through unittest methods (``config.enable_trace_unittest``).
    """

    _stack: contextlib.ExitStack
    dynamo_strict_nopython = True

    # Versioned CPython test directory, e.g. ".../dynamo/cpython/3_13". Either
    # path separator is accepted so the lookup does not depend on ``os.sep``.
    _CPYTHON_VERSION_DIR_RE = re.compile(r"dynamo[\\/]cpython[\\/](\d_\d{2})")

    # Restore original unittest methods to simplify tracing CPython test cases.
    assertEqual = unittest.TestCase.assertEqual  # type: ignore[assignment]
    assertNotEqual = unittest.TestCase.assertNotEqual  # type: ignore[assignment]
    assertTrue = unittest.TestCase.assertTrue
    assertFalse = unittest.TestCase.assertFalse
    assertIs = unittest.TestCase.assertIs
    assertIsNot = unittest.TestCase.assertIsNot
    assertIsNone = unittest.TestCase.assertIsNone
    assertIsNotNone = unittest.TestCase.assertIsNotNone
    assertIn = unittest.TestCase.assertIn
    assertNotIn = unittest.TestCase.assertNotIn
    assertIsInstance = unittest.TestCase.assertIsInstance
    assertNotIsInstance = unittest.TestCase.assertNotIsInstance
    assertAlmostEqual = unittest.TestCase.assertAlmostEqual
    assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
    assertGreater = unittest.TestCase.assertGreater
    assertGreaterEqual = unittest.TestCase.assertGreaterEqual
    assertLess = unittest.TestCase.assertLess
    assertLessEqual = unittest.TestCase.assertLessEqual
    assertRegex = unittest.TestCase.assertRegex
    assertNotRegex = unittest.TestCase.assertNotRegex
    assertCountEqual = unittest.TestCase.assertCountEqual
    assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
    assertSequenceEqual = unittest.TestCase.assertSequenceEqual
    assertListEqual = unittest.TestCase.assertListEqual
    assertTupleEqual = unittest.TestCase.assertTupleEqual
    assertSetEqual = unittest.TestCase.assertSetEqual
    assertDictEqual = unittest.TestCase.assertDictEqual
    assertRaises = unittest.TestCase.assertRaises
    assertRaisesRegex = unittest.TestCase.assertRaisesRegex
    assertWarns = unittest.TestCase.assertWarns
    assertWarnsRegex = unittest.TestCase.assertWarnsRegex
    assertLogs = unittest.TestCase.assertLogs
    fail = unittest.TestCase.fail
    failureException = unittest.TestCase.failureException

    def compile_fn(self, fn, backend, nopython):
        # We want to compile only the test function, excluding any setup code
        # from unittest: the bound test method is wrapped and re-installed on
        # the instance, while ``fn`` itself is returned unchanged.
        method = getattr(self, self._testMethodName)
        method = torch._dynamo.optimize(backend, nopython=nopython)(method)
        setattr(self, self._testMethodName, method)
        return fn

    @classmethod
    def _cpython_version_dir(cls, path: str) -> Union[str, None]:
        """Return the ``X_YY`` version component of a CPython test path, or None."""
        match = cls._CPYTHON_VERSION_DIR_RE.search(path)
        return match.group(1) if match else None

    def _dynamo_test_key(self):
        """Prefix the base key with the CPython version and test file name.

        Produces ``CPython<XYY>-<test_file>-<base key>``. Falls back to the
        base key when the test file is not under a versioned cpython directory.
        """
        suffix = super()._dynamo_test_key()
        test_path = inspect.getfile(self.__class__)
        version_dir = self._cpython_version_dir(test_path)
        if version_dir is None:
            return suffix
        test_file = os.path.basename(test_path).split(".")[0]
        py_ver = version_dir.replace("_", "")
        return f"CPython{py_ver}-{test_file}-{suffix}"

    @classmethod
    def tearDownClass(cls) -> None:
        try:
            cls._stack.close()
        finally:
            # Run the base-class teardown even if undoing the config patch fails.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls) -> None:
        # Skip the class if the running Python is older than the version the
        # tests were taken from, or if no version is encoded in the path.
        test_path = inspect.getfile(cls)
        version_dir = cls._cpython_version_dir(test_path)
        if version_dir is None:
            raise unittest.SkipTest(
                f"Test requires a specific Python version but not found in path {test_path}"
            )
        test_py_ver = tuple(map(int, version_dir.split("_")))
        py_ver = sys.version_info[:2]
        if py_ver < test_py_ver:
            expected = ".".join(map(str, test_py_ver))
            got = ".".join(map(str, py_ver))
            raise unittest.SkipTest(
                f"Test requires Python {expected} but got Python {got}"
            )

        super().setUpClass()
        cls._stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                enable_trace_unittest=True,
            ),
        )