diff --git a/.gitattributes b/.gitattributes index eb711db8e1e8071e662fc9168420d027af0102ef..fcfca0322cd34739ebe732a8e6c18311df99e157 100644 --- a/.gitattributes +++ b/.gitattributes @@ -38,3 +38,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/algorith tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Symtab.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/Scanners.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/DFA.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b49320c22c1144ca2928fc52d97062ab3affdb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py @@ -0,0 +1,119 @@ +import difflib +import glob +import gzip +import os +import sys +import tempfile +import unittest + +import Cython.Build.Dependencies +import Cython.Utils +from Cython.TestUtils import CythonTest + + +class TestCyCache(CythonTest): + + def setUp(self): + CythonTest.setUp(self) + self.temp_dir = tempfile.mkdtemp( + prefix='cycache-test', + dir='TEST_TMP' if os.path.isdir('TEST_TMP') else None) + self.src_dir = tempfile.mkdtemp(prefix='src', dir=self.temp_dir) + self.cache_dir = tempfile.mkdtemp(prefix='cache', dir=self.temp_dir) + + def cache_files(self, file_glob): + return glob.glob(os.path.join(self.cache_dir, file_glob)) + + def 
fresh_cythonize(self, *args, **kwargs): + Cython.Utils.clear_function_caches() + Cython.Build.Dependencies._dep_tree = None # discard method caches + Cython.Build.Dependencies.cythonize(*args, **kwargs) + + def test_cycache_switch(self): + content1 = 'value = 1\n' + content2 = 'value = 2\n' + a_pyx = os.path.join(self.src_dir, 'a.pyx') + a_c = a_pyx[:-4] + '.c' + + with open(a_pyx, 'w') as f: + f.write(content1) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + self.assertEqual(1, len(self.cache_files('a.c*'))) + with open(a_c) as f: + a_contents1 = f.read() + os.unlink(a_c) + + with open(a_pyx, 'w') as f: + f.write(content2) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + with open(a_c) as f: + a_contents2 = f.read() + os.unlink(a_c) + + self.assertNotEqual(a_contents1, a_contents2, 'C file not changed!') + self.assertEqual(2, len(self.cache_files('a.c*'))) + + with open(a_pyx, 'w') as f: + f.write(content1) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + self.assertEqual(2, len(self.cache_files('a.c*'))) + with open(a_c) as f: + a_contents = f.read() + self.assertEqual( + a_contents, a_contents1, + msg='\n'.join(list(difflib.unified_diff( + a_contents.split('\n'), a_contents1.split('\n')))[:10])) + + def test_cycache_uses_cache(self): + a_pyx = os.path.join(self.src_dir, 'a.pyx') + a_c = a_pyx[:-4] + '.c' + with open(a_pyx, 'w') as f: + f.write('pass') + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + a_cache = os.path.join(self.cache_dir, os.listdir(self.cache_dir)[0]) + with gzip.GzipFile(a_cache, 'wb') as gzipfile: + gzipfile.write('fake stuff'.encode('ascii')) + os.unlink(a_c) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + with open(a_c) as f: + a_contents = f.read() + self.assertEqual(a_contents, 'fake stuff', + 'Unexpected contents: %s...' 
% a_contents[:100]) + + def test_multi_file_output(self): + a_pyx = os.path.join(self.src_dir, 'a.pyx') + a_c = a_pyx[:-4] + '.c' + a_h = a_pyx[:-4] + '.h' + a_api_h = a_pyx[:-4] + '_api.h' + with open(a_pyx, 'w') as f: + f.write('cdef public api int foo(int x): return x\n') + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + expected = [a_c, a_h, a_api_h] + for output in expected: + self.assertTrue(os.path.exists(output), output) + os.unlink(output) + self.fresh_cythonize(a_pyx, cache=self.cache_dir) + for output in expected: + self.assertTrue(os.path.exists(output), output) + + def test_options_invalidation(self): + hash_pyx = os.path.join(self.src_dir, 'options.pyx') + hash_c = hash_pyx[:-len('.pyx')] + '.c' + + with open(hash_pyx, 'w') as f: + f.write('pass') + self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False) + self.assertEqual(1, len(self.cache_files('options.c*'))) + + os.unlink(hash_c) + self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=True) + self.assertEqual(2, len(self.cache_files('options.c*'))) + + os.unlink(hash_c) + self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=False) + self.assertEqual(2, len(self.cache_files('options.c*'))) + + os.unlink(hash_c) + self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=True) + self.assertEqual(2, len(self.cache_files('options.c*'))) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..d3888117d846d67f91fec9cf14d8f852a5a65d82 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py @@ -0,0 +1,142 @@ +import contextlib +import os.path +import sys +import tempfile +import unittest +from io import open +from os.path import join as pjoin + 
+from ..Dependencies import extended_iglob + + +@contextlib.contextmanager +def writable_file(dir_path, filename): + with open(pjoin(dir_path, filename), "w", encoding="utf8") as f: + yield f + + +class TestGlobbing(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._orig_dir = os.getcwd() + if sys.version_info[0] < 3: + temp_path = cls._tmpdir = tempfile.mkdtemp() + else: + cls._tmpdir = tempfile.TemporaryDirectory() + temp_path = cls._tmpdir.name + os.chdir(temp_path) + + for dir1 in "abcd": + for dir1x in [dir1, dir1 + 'x']: + for dir2 in "xyz": + dir_path = pjoin(dir1x, dir2) + os.makedirs(dir_path) + with writable_file(dir_path, "file2_pyx.pyx") as f: + f.write(u'""" PYX """') + with writable_file(dir_path, "file2_py.py") as f: + f.write(u'""" PY """') + + with writable_file(dir1x, "file1_pyx.pyx") as f: + f.write(u'""" PYX """') + with writable_file(dir1x, "file1_py.py") as f: + f.write(u'""" PY """') + + @classmethod + def tearDownClass(cls): + os.chdir(cls._orig_dir) + if sys.version_info[0] < 3: + import shutil + shutil.rmtree(cls._tmpdir) + else: + cls._tmpdir.cleanup() + + def files_equal(self, pattern, expected_files): + expected_files = sorted(expected_files) + # It's the users's choice whether '/' will appear on Windows. + matched_files = sorted(path.replace('/', os.sep) for path in extended_iglob(pattern)) + self.assertListEqual(matched_files, expected_files) # / + + # Special case for Windows: also support '\' in patterns. 
+ if os.sep == '\\' and '/' in pattern: + matched_files = sorted(extended_iglob(pattern.replace('/', '\\'))) + self.assertListEqual(matched_files, expected_files) # \ + + def test_extended_iglob_simple(self): + ax_files = [pjoin("a", "x", "file2_pyx.pyx"), pjoin("a", "x", "file2_py.py")] + self.files_equal("a/x/*", ax_files) + self.files_equal("a/x/*.c12", []) + self.files_equal("a/x/*.{py,pyx,c12}", ax_files) + self.files_equal("a/x/*.{py,pyx}", ax_files) + self.files_equal("a/x/*.{pyx}", ax_files[:1]) + self.files_equal("a/x/*.pyx", ax_files[:1]) + self.files_equal("a/x/*.{py}", ax_files[1:]) + self.files_equal("a/x/*.py", ax_files[1:]) + + def test_extended_iglob_simple_star(self): + for basedir in "ad": + files = [ + pjoin(basedir, dirname, filename) + for dirname in "xyz" + for filename in ["file2_pyx.pyx", "file2_py.py"] + ] + self.files_equal(basedir + "/*/*", files) + self.files_equal(basedir + "/*/*.c12", []) + self.files_equal(basedir + "/*/*.{py,pyx,c12}", files) + self.files_equal(basedir + "/*/*.{py,pyx}", files) + self.files_equal(basedir + "/*/*.{pyx}", files[::2]) + self.files_equal(basedir + "/*/*.pyx", files[::2]) + self.files_equal(basedir + "/*/*.{py}", files[1::2]) + self.files_equal(basedir + "/*/*.py", files[1::2]) + + for subdir in "xy*": + files = [ + pjoin(basedir, dirname, filename) + for dirname in "xyz" + if subdir in ('*', dirname) + for filename in ["file2_pyx.pyx", "file2_py.py"] + ] + path = basedir + '/' + subdir + '/' + self.files_equal(path + "*", files) + self.files_equal(path + "*.{py,pyx}", files) + self.files_equal(path + "*.{pyx}", files[::2]) + self.files_equal(path + "*.pyx", files[::2]) + self.files_equal(path + "*.{py}", files[1::2]) + self.files_equal(path + "*.py", files[1::2]) + + def test_extended_iglob_double_star(self): + basedirs = os.listdir(".") + files = [ + pjoin(basedir, dirname, filename) + for basedir in basedirs + for dirname in "xyz" + for filename in ["file2_pyx.pyx", "file2_py.py"] + ] + all_files = [ + 
pjoin(basedir, filename) + for basedir in basedirs + for filename in ["file1_pyx.pyx", "file1_py.py"] + ] + files + self.files_equal("*/*/*", files) + self.files_equal("*/*/**/*", files) + self.files_equal("*/**/*.*", all_files) + self.files_equal("**/*.*", all_files) + self.files_equal("*/**/*.c12", []) + self.files_equal("**/*.c12", []) + self.files_equal("*/*/*.{py,pyx,c12}", files) + self.files_equal("*/*/**/*.{py,pyx,c12}", files) + self.files_equal("*/**/*/*.{py,pyx,c12}", files) + self.files_equal("**/*/*/*.{py,pyx,c12}", files) + self.files_equal("**/*.{py,pyx,c12}", all_files) + self.files_equal("*/*/*.{py,pyx}", files) + self.files_equal("**/*/*/*.{py,pyx}", files) + self.files_equal("*/**/*/*.{py,pyx}", files) + self.files_equal("**/*.{py,pyx}", all_files) + self.files_equal("*/*/*.{pyx}", files[::2]) + self.files_equal("**/*.{pyx}", all_files[::2]) + self.files_equal("*/**/*/*.pyx", files[::2]) + self.files_equal("*/*/*.pyx", files[::2]) + self.files_equal("**/*.pyx", all_files[::2]) + self.files_equal("*/*/*.{py}", files[1::2]) + self.files_equal("**/*.{py}", all_files[1::2]) + self.files_equal("*/*/*.py", files[1::2]) + self.files_equal("**/*.py", all_files[1::2]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py new file mode 100644 index 0000000000000000000000000000000000000000..65d801c6b72e883fc766f034836234973d987ffe --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# tag: ipython + +"""Tests for the Cython magics extension.""" + +from __future__ import absolute_import + +import os +import io +import sys +from contextlib import contextmanager +from unittest import skipIf + +from Cython.Build import IpythonMagic +from Cython.TestUtils import CythonTest +from 
Cython.Compiler.Annotate import AnnotationCCodeWriter + +try: + import IPython.testing.globalipapp +except ImportError: + # Disable tests and fake helpers for initialisation below. + def skip_if_not_installed(_): + return None +else: + def skip_if_not_installed(c): + return c + +# not using IPython's decorators here because they depend on "nose" +skip_win32 = skipIf(sys.platform == 'win32', "Skip on Windows") +skip_py27 = skipIf(sys.version_info[:2] == (2,7), "Disabled in Py2.7") + +try: + # disable IPython history thread before it gets started to avoid having to clean it up + from IPython.core.history import HistoryManager + HistoryManager.enabled = False +except ImportError: + pass + + +@contextmanager +def capture_output(): + backup = sys.stdout, sys.stderr + try: + replacement = [ + io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding), + io.TextIOWrapper(io.BytesIO(), encoding=sys.stderr.encoding), + ] + sys.stdout, sys.stderr = replacement + output = [] + yield output + finally: + sys.stdout, sys.stderr = backup + for wrapper in replacement: + wrapper.seek(0) # rewind + output.append(wrapper.read()) + wrapper.close() + + +code = u"""\ +def f(x): + return 2*x +""" + +cython3_code = u"""\ +def f(int x): + return 2 / x + +def call(x): + return f(*(x,)) +""" + +pgo_cython3_code = cython3_code + u"""\ +def main(): + for _ in range(100): call(5) +main() +""" + +compile_error_code = u'''\ +cdef extern from *: + """ + xxx a=1; + """ + int a; +def doit(): + return a +''' + +compile_warning_code = u'''\ +cdef extern from *: + """ + #pragma message ( "CWarning" ) + int a = 42; + """ + int a; +def doit(): + return a +''' + + +@skip_if_not_installed +class TestIPythonMagic(CythonTest): + + @classmethod + def setUpClass(cls): + CythonTest.setUpClass() + cls._ip = IPython.testing.globalipapp.get_ipython() + + def setUp(self): + CythonTest.setUp(self) + self._ip.extension_manager.load_extension('cython') + + def test_cython_inline(self): + ip = self._ip + ip.ex('a=10; 
b=20') + result = ip.run_cell_magic('cython_inline', '', 'return a+b') + self.assertEqual(result, 30) + + @skip_win32 + def test_cython_pyximport(self): + ip = self._ip + module_name = '_test_cython_pyximport' + ip.run_cell_magic('cython_pyximport', module_name, code) + ip.ex('g = f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + ip.run_cell_magic('cython_pyximport', module_name, code) + ip.ex('h = f(-10)') + self.assertEqual(ip.user_ns['h'], -20.0) + try: + os.remove(module_name + '.pyx') + except OSError: + pass + + def test_cython(self): + ip = self._ip + ip.run_cell_magic('cython', '', code) + ip.ex('g = f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + + def test_cython_name(self): + # The Cython module named 'mymodule' defines the function f. + ip = self._ip + ip.run_cell_magic('cython', '--name=mymodule', code) + # This module can now be imported in the interactive namespace. + ip.ex('import mymodule; g = mymodule.f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + + def test_cython_language_level(self): + # The Cython cell defines the functions f() and call(). + ip = self._ip + ip.run_cell_magic('cython', '', cython3_code) + ip.ex('g = f(10); h = call(10)') + if sys.version_info[0] < 3: + self.assertEqual(ip.user_ns['g'], 2 // 10) + self.assertEqual(ip.user_ns['h'], 2 // 10) + else: + self.assertEqual(ip.user_ns['g'], 2.0 / 10.0) + self.assertEqual(ip.user_ns['h'], 2.0 / 10.0) + + def test_cython3(self): + # The Cython cell defines the functions f() and call(). + ip = self._ip + ip.run_cell_magic('cython', '-3', cython3_code) + ip.ex('g = f(10); h = call(10)') + self.assertEqual(ip.user_ns['g'], 2.0 / 10.0) + self.assertEqual(ip.user_ns['h'], 2.0 / 10.0) + + def test_cython2(self): + # The Cython cell defines the functions f() and call(). 
+ ip = self._ip + ip.run_cell_magic('cython', '-2', cython3_code) + ip.ex('g = f(10); h = call(10)') + self.assertEqual(ip.user_ns['g'], 2 // 10) + self.assertEqual(ip.user_ns['h'], 2 // 10) + + def test_cython_compile_error_shown(self): + ip = self._ip + with capture_output() as out: + ip.run_cell_magic('cython', '-3', compile_error_code) + captured_out, captured_err = out + + # it could be that c-level output is captured by distutil-extension + # (and not by us) and is printed to stdout: + captured_all = captured_out + "\n" + captured_err + self.assertTrue("error" in captured_all, msg="error in " + captured_all) + + def test_cython_link_error_shown(self): + ip = self._ip + with capture_output() as out: + ip.run_cell_magic('cython', '-3 -l=xxxxxxxx', code) + captured_out, captured_err = out + + # it could be that c-level output is captured by distutil-extension + # (and not by us) and is printed to stdout: + captured_all = captured_out + "\n!" + captured_err + self.assertTrue("error" in captured_all, msg="error in " + captured_all) + + def test_cython_warning_shown(self): + ip = self._ip + with capture_output() as out: + # force rebuild, otherwise no warning as after the first success + # no build step is performed + ip.run_cell_magic('cython', '-3 -f', compile_warning_code) + captured_out, captured_err = out + + # check that warning was printed to stdout even if build hasn't failed + self.assertTrue("CWarning" in captured_out) + + @skip_py27 # Not strictly broken in Py2.7 but currently fails in CI due to C compiler issues. + @skip_win32 + def test_cython3_pgo(self): + # The Cython cell defines the functions f() and call(). 
+ ip = self._ip + ip.run_cell_magic('cython', '-3 --pgo', pgo_cython3_code) + ip.ex('g = f(10); h = call(10); main()') + self.assertEqual(ip.user_ns['g'], 2.0 / 10.0) + self.assertEqual(ip.user_ns['h'], 2.0 / 10.0) + + @skip_win32 + def test_extlibs(self): + ip = self._ip + code = u""" +from libc.math cimport sin +x = sin(0.0) + """ + ip.user_ns['x'] = 1 + ip.run_cell_magic('cython', '-l m', code) + self.assertEqual(ip.user_ns['x'], 0) + + + def test_cython_verbose(self): + ip = self._ip + ip.run_cell_magic('cython', '--verbose', code) + ip.ex('g = f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + + def test_cython_verbose_thresholds(self): + @contextmanager + def mock_distutils(): + class MockLog: + DEBUG = 1 + INFO = 2 + thresholds = [INFO] + + def set_threshold(self, val): + self.thresholds.append(val) + return self.thresholds[-2] + + + new_log = MockLog() + old_log = IpythonMagic.distutils.log + try: + IpythonMagic.distutils.log = new_log + yield new_log + finally: + IpythonMagic.distutils.log = old_log + + ip = self._ip + with mock_distutils() as verbose_log: + ip.run_cell_magic('cython', '--verbose', code) + ip.ex('g = f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + self.assertEqual([verbose_log.INFO, verbose_log.DEBUG, verbose_log.INFO], + verbose_log.thresholds) + + with mock_distutils() as normal_log: + ip.run_cell_magic('cython', '', code) + ip.ex('g = f(10)') + self.assertEqual(ip.user_ns['g'], 20.0) + self.assertEqual([normal_log.INFO], normal_log.thresholds) + + def test_cython_no_annotate(self): + ip = self._ip + html = ip.run_cell_magic('cython', '', code) + self.assertTrue(html is None) + + def test_cython_annotate(self): + ip = self._ip + html = ip.run_cell_magic('cython', '--annotate', code) + # somewhat brittle way to differentiate between annotated htmls + # with/without complete source code: + self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data) + + def test_cython_annotate_default(self): + ip = self._ip + html = 
ip.run_cell_magic('cython', '-a', code) + # somewhat brittle way to differentiate between annotated htmls + # with/without complete source code: + self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data) + + def test_cython_annotate_complete_c_code(self): + ip = self._ip + html = ip.run_cell_magic('cython', '--annotate-fullc', code) + # somewhat brittle way to differentiate between annotated htmls + # with/without complete source code: + self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE in html.data) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe5c65a906b1759f17a026d63213d0c936ab66e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py @@ -0,0 +1,56 @@ +from Cython.Build.Dependencies import strip_string_literals + +from Cython.TestUtils import CythonTest + +class TestStripLiterals(CythonTest): + + def t(self, before, expected): + actual, literals = strip_string_literals(before, prefix="_L") + self.assertEqual(expected, actual) + for key, value in literals.items(): + actual = actual.replace(key, value) + self.assertEqual(before, actual) + + def test_empty(self): + self.t("", "") + + def test_single_quote(self): + self.t("'x'", "'_L1_'") + + def test_double_quote(self): + self.t('"x"', '"_L1_"') + + def test_nested_quotes(self): + self.t(""" '"' "'" """, """ '_L1_' "_L2_" """) + + def test_triple_quote(self): + self.t(" '''a\n''' ", " '''_L1_''' ") + + def test_backslash(self): + self.t(r"'a\'b'", "'_L1_'") + self.t(r"'a\\'", "'_L1_'") + self.t(r"'a\\\'b'", "'_L1_'") + + def test_unicode(self): + self.t("u'abc'", "u'_L1_'") + + def test_raw(self): + self.t(r"r'abc\\'", "r'_L1_'") + + def test_raw_unicode(self): + 
self.t(r"ru'abc\\'", "ru'_L1_'") + + def test_comment(self): + self.t("abc # foo", "abc #_L1_") + + def test_comment_and_quote(self): + self.t("abc # 'x'", "abc #_L1_") + self.t("'abc#'", "'_L1_'") + + def test_include(self): + self.t("include 'a.pxi' # something here", + "include '_L1_' #_L2_") + + def test_extern(self): + self.t("cdef extern from 'a.h': # comment", + "cdef extern from '_L1_': #_L2_") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d76a6eb48351831a6d7db0290cab203bb5695063 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc407dd4f0f2062a4b8ab346fc9286330272d1ff Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..043da4cbcc35019bd9212005e9f8c921e6b649fc Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4d7ec80f6a989713a0325a976955d1a7f8cf95e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3627b3d4b21440d17e992326fcc7c7c13d91ab94 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78112c17091de300bf0021aa22250d81313b4b0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py new file mode 100644 index 
0000000000000000000000000000000000000000..e86e1e9c24d206d537d682b7c0b67c1b0fabf96c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py @@ -0,0 +1,749 @@ +from __future__ import absolute_import + +from .Visitor import CythonTransform +from .ModuleNode import ModuleNode +from .Errors import CompileError +from .UtilityCode import CythonUtilityCode +from .Code import UtilityCode, TempitaUtilityCode + +from . import Options +from . import Interpreter +from . import PyrexTypes +from . import Naming +from . import Symtab + +def dedent(text, reindent=0): + from textwrap import dedent + text = dedent(text) + if reindent > 0: + indent = " " * reindent + text = '\n'.join([indent + x for x in text.split('\n')]) + return text + +class IntroduceBufferAuxiliaryVars(CythonTransform): + + # + # Entry point + # + + buffers_exists = False + using_memoryview = False + + def __call__(self, node): + assert isinstance(node, ModuleNode) + self.max_ndim = 0 + result = super(IntroduceBufferAuxiliaryVars, self).__call__(node) + if self.buffers_exists: + use_bufstruct_declare_code(node.scope) + use_py2_buffer_functions(node.scope) + + return result + + + # + # Basic operations for transforms + # + def handle_scope(self, node, scope): + # For all buffers, insert extra variables in the scope. 
+ # The variables are also accessible from the buffer_info + # on the buffer entry + scope_items = scope.entries.items() + bufvars = [entry for name, entry in scope_items if entry.type.is_buffer] + if len(bufvars) > 0: + bufvars.sort(key=lambda entry: entry.name) + self.buffers_exists = True + + memviewslicevars = [entry for name, entry in scope_items if entry.type.is_memoryviewslice] + if len(memviewslicevars) > 0: + self.buffers_exists = True + + + for (name, entry) in scope_items: + if name == 'memoryview' and isinstance(entry.utility_code_definition, CythonUtilityCode): + self.using_memoryview = True + break + del scope_items + + if isinstance(node, ModuleNode) and len(bufvars) > 0: + # for now...note that pos is wrong + raise CompileError(node.pos, "Buffer vars not allowed in module scope") + for entry in bufvars: + if entry.type.dtype.is_ptr: + raise CompileError(node.pos, "Buffers with pointer types not yet supported.") + + name = entry.name + buftype = entry.type + if buftype.ndim > Options.buffer_max_dims: + raise CompileError(node.pos, + "Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims) + if buftype.ndim > self.max_ndim: + self.max_ndim = buftype.ndim + + # Declare auxiliary vars + def decvar(type, prefix): + cname = scope.mangle(prefix, name) + aux_var = scope.declare_var(name=None, cname=cname, + type=type, pos=node.pos) + if entry.is_arg: + aux_var.used = True # otherwise, NameNode will mark whether it is used + + return aux_var + + auxvars = ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix), + (PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix)) + pybuffernd, rcbuffer = [decvar(type, prefix) for (type, prefix) in auxvars] + + entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer) + + scope.buffer_entries = bufvars + self.scope = scope + + def visit_ModuleNode(self, node): + self.handle_scope(node, node.scope) + self.visitchildren(node) + return node + + def visit_FuncDefNode(self, node): + 
self.handle_scope(node, node.local_scope) + self.visitchildren(node) + return node + +# +# Analysis +# +buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered! +buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False} +buffer_positional_options_count = 1 # anything beyond this needs keyword argument + +ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option' +ERR_BUF_TOO_MANY = 'Too many buffer options' +ERR_BUF_DUP = '"%s" buffer option already supplied' +ERR_BUF_MISSING = '"%s" missing' +ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)' +ERR_BUF_NDIM = 'ndim must be a non-negative integer' +ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct' +ERR_BUF_BOOL = '"%s" must be a boolean' + +def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True): + """ + Must be called during type analysis, as analyse is called + on the dtype argument. + + posargs and dictargs should consist of a list and a dict + of tuples (value, pos). Defaults should be a dict of values. + + Returns a dict containing all the options a buffer can have and + its value (with the positions stripped). 
+ """ + if defaults is None: + defaults = buffer_defaults + + posargs, dictargs = Interpreter.interpret_compiletime_options( + posargs, dictargs, type_env=env, type_args=(0, 'dtype')) + + if len(posargs) > buffer_positional_options_count: + raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY) + + options = {} + for name, (value, pos) in dictargs.items(): + if name not in buffer_options: + raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name) + options[name] = value + + for name, (value, pos) in zip(buffer_options, posargs): + if name not in buffer_options: + raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name) + if name in options: + raise CompileError(pos, ERR_BUF_DUP % name) + options[name] = value + + # Check that they are all there and copy defaults + for name in buffer_options: + if name not in options: + try: + options[name] = defaults[name] + except KeyError: + if need_complete: + raise CompileError(globalpos, ERR_BUF_MISSING % name) + + dtype = options.get("dtype") + if dtype and dtype.is_extension_type: + raise CompileError(globalpos, ERR_BUF_DTYPE) + + ndim = options.get("ndim") + if ndim and (not isinstance(ndim, int) or ndim < 0): + raise CompileError(globalpos, ERR_BUF_NDIM) + + mode = options.get("mode") + if mode and not (mode in ('full', 'strided', 'c', 'fortran')): + raise CompileError(globalpos, ERR_BUF_MODE) + + def assert_bool(name): + x = options.get(name) + if not isinstance(x, bool): + raise CompileError(globalpos, ERR_BUF_BOOL % name) + + assert_bool('negative_indices') + assert_bool('cast') + + return options + + +# +# Code generation +# + +class BufferEntry(object): + def __init__(self, entry): + self.entry = entry + self.type = entry.type + self.cname = entry.buffer_aux.buflocal_nd_var.cname + self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname + self.buf_ptr_type = entry.type.buffer_ptr_type + self.init_attributes() + + def init_attributes(self): + self.shape = self.get_buf_shapevars() + self.strides = self.get_buf_stridevars() + 
self.suboffsets = self.get_buf_suboffsetvars() + + def get_buf_suboffsetvars(self): + return self._for_all_ndim("%s.diminfo[%d].suboffsets") + + def get_buf_stridevars(self): + return self._for_all_ndim("%s.diminfo[%d].strides") + + def get_buf_shapevars(self): + return self._for_all_ndim("%s.diminfo[%d].shape") + + def _for_all_ndim(self, s): + return [s % (self.cname, i) for i in range(self.type.ndim)] + + def generate_buffer_lookup_code(self, code, index_cnames): + # Create buffer lookup and return it + # This is done via utility macros/inline functions, which vary + # according to the access mode used. + params = [] + nd = self.type.ndim + mode = self.type.mode + if mode == 'full': + for i, s, o in zip(index_cnames, + self.get_buf_stridevars(), + self.get_buf_suboffsetvars()): + params.append(i) + params.append(s) + params.append(o) + funcname = "__Pyx_BufPtrFull%dd" % nd + funcgen = buf_lookup_full_code + else: + if mode == 'strided': + funcname = "__Pyx_BufPtrStrided%dd" % nd + funcgen = buf_lookup_strided_code + elif mode == 'c': + funcname = "__Pyx_BufPtrCContig%dd" % nd + funcgen = buf_lookup_c_code + elif mode == 'fortran': + funcname = "__Pyx_BufPtrFortranContig%dd" % nd + funcgen = buf_lookup_fortran_code + else: + assert False + for i, s in zip(index_cnames, self.get_buf_stridevars()): + params.append(i) + params.append(s) + + # Make sure the utility code is available + if funcname not in code.globalstate.utility_codes: + code.globalstate.utility_codes.add(funcname) + protocode = code.globalstate['utility_code_proto'] + defcode = code.globalstate['utility_code_def'] + funcgen(protocode, defcode, name=funcname, nd=nd) + + buf_ptr_type_code = self.buf_ptr_type.empty_declaration_code() + ptrcode = "%s(%s, %s, %s)" % (funcname, buf_ptr_type_code, self.buf_ptr, + ", ".join(params)) + return ptrcode + + +def get_flags(buffer_aux, buffer_type): + flags = 'PyBUF_FORMAT' + mode = buffer_type.mode + if mode == 'full': + flags += '| PyBUF_INDIRECT' + elif mode == 
def put_unpack_buffer_aux_into_scope(buf_entry, code):
    """
    Emit C assignments that copy the buffer layout arrays into ``diminfo``.

    For every dimension of the buffer type, the ``strides`` and ``shape``
    entries (plus ``suboffsets`` when the access mode is ``'full'``) are
    copied from ``<struct>.rcbuffer->pybuffer`` into
    ``<struct>.diminfo[dim]``.  All assignments are joined with spaces and
    written with a single ``code.putln()`` call.
    """
    nd_struct = buf_entry.buffer_aux.buflocal_nd_var.cname

    fields = ('strides', 'shape')
    if buf_entry.type.mode == 'full':
        fields += ('suboffsets',)

    assignments = [
        "%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];" % (
            nd_struct, dim, field,
            nd_struct, field, dim,
        )
        for dim in range(buf_entry.type.ndim)
        for field in fields
    ]
    code.putln(' '.join(assignments))
def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
    """
    Return the C call expression that acquires and validates a buffer.

    The returned string calls ``__Pyx_GetBufferAndValidate`` on the
    ``rcbuffer->pybuffer`` member of the buffer's local struct, passing
    the object ``obj_cname``, the dtype's type-info struct, the buffer
    flags, the number of dimensions and the cast flag.  The expression
    refers to a C variable named ``__pyx_stack``, so it has to be emitted
    in a scope where the caller has declared one.

    As a side effect, the type information for the dtype is generated and
    the acquisition utility code is registered on ``code.globalstate``.
    """
    # The dict literal is evaluated top to bottom, which keeps the helper
    # calls in the same order as the substitutions are listed.
    substitutions = {
        'obj_cname': obj_cname,
        'ndim': buffer_type.ndim,
        'cast': int(buffer_type.cast),
        'flags': get_flags(buffer_aux, buffer_type),
        'pybuffernd_struct': buffer_aux.buflocal_nd_var.cname,
        'dtype_typeinfo': get_type_information_cname(code, buffer_type.dtype),
    }

    code.globalstate.use_utility_code(acquire_utility_code)

    call_template = (
        "__Pyx_GetBufferAndValidate(&%(pybuffernd_struct)s.rcbuffer->pybuffer, "
        "(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
        "%(cast)d, __pyx_stack)")
    return call_template % substitutions
+ """ + + buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type + pybuffernd_struct = buffer_aux.buflocal_nd_var.cname + flags = get_flags(buffer_aux, buffer_type) + + code.putln("{") # Set up necessary stack for getbuffer + code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth()) + + getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type) # fill in object below + + if is_initialized: + # Release any existing buffer + code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct) + # Acquire + retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False) + code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname)) + code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname))) + # If acquisition failed, attempt to reacquire the old buffer + # before raising the exception. A failure of reacquisition + # will cause the reacquisition exception to be reported, one + # can consider working around this later. + exc_temps = tuple(code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False) + for _ in range(3)) + code.putln('PyErr_Fetch(&%s, &%s, &%s);' % exc_temps) + code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname))) + code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % exc_temps) # Do not refnanny these! + code.globalstate.use_utility_code(raise_buffer_fallback_code) + code.putln('__Pyx_RaiseBufferFallbackError();') + code.putln('} else {') + code.putln('PyErr_Restore(%s, %s, %s);' % exc_temps) + code.putln('}') + code.putln('%s = %s = %s = 0;' % exc_temps) + for t in exc_temps: + code.funcstate.release_temp(t) + code.putln('}') + # Unpack indices + put_unpack_buffer_aux_into_scope(buf_entry, code) + code.putln(code.error_goto_if_neg(retcode_cname, pos)) + code.funcstate.release_temp(retcode_cname) + else: + # Our entry had no previous value, so set to None when acquisition fails. 
def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
                           pos, code, negative_indices, in_nogil_context):
    """
    Emit index processing code and return the buffer item pointer expression.

    Depending on the ``boundscheck`` and ``wraparound`` directives, C code
    is written to ``code`` that wraps negative indices around and/or checks
    every index against the shape of its dimension.  The actual address
    computation is delegated to ``entry.generate_buffer_lookup_code()``,
    whose result (a C expression string) is returned.  Since it is an
    expression, callers should store it in a temporary if they need it
    more than once.

    entry             a BufferEntry
    index_signeds     per index: non-zero if the index type is signed
    index_cnames      per index: C name of the index variable
    directives        compiler directives ('wraparound', 'boundscheck')
    pos               source position used for the generated error goto
    negative_indices  whether negative indices may occur at all
    in_nogil_context  selects the nogil variant of the IndexError helper
    """
    wraparound = directives['wraparound'] and negative_indices

    if not directives['boundscheck']:
        if wraparound:
            # No bounds checking requested: only fix up negative indices.
            for is_signed, index, extent in zip(
                    index_signeds, index_cnames, entry.get_buf_shapevars()):
                if is_signed != 0:
                    code.putln("if (%s < 0) %s += %s;" % (index, index, extent))
        return entry.generate_buffer_lookup_code(code, index_cnames)

    # Bounds checking: a temporary holds the failing dimension, where the
    # initial value -1 means "no error".
    bad_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
    code.putln("%s = -1;" % bad_dim)

    dimensions = zip(index_signeds, index_cnames, entry.get_buf_shapevars())
    for dim, (is_signed, index, extent) in enumerate(dimensions):
        if is_signed != 0:
            # Signed index: handle the negative case before the upper bound.
            upper_cast = ""
            code.putln("if (%s < 0) {" % index)
            if wraparound:
                code.putln("%s += %s;" % (index, extent))
                code.putln("if (%s) %s = %d;" % (
                    code.unlikely("%s < 0" % index),
                    bad_dim, dim))
            else:
                code.putln("%s = %d;" % (bad_dim, dim))
            code.put("} else ")
        else:
            # Unsigned index: compare against the shape as size_t.
            upper_cast = "(size_t)"

        # Upper bound check.
        code.putln("if (%s) %s = %d;" % (
            code.unlikely("%s >= %s%s" % (index, upper_cast, extent)),
            bad_dim, dim))

    if in_nogil_context:
        code.globalstate.use_utility_code(raise_indexerror_nogil)
        raise_func = '__Pyx_RaiseBufferIndexErrorNogil'
    else:
        code.globalstate.use_utility_code(raise_indexerror_code)
        raise_func = '__Pyx_RaiseBufferIndexError'

    code.putln("if (%s) {" % code.unlikely("%s != -1" % bad_dim))
    code.putln('%s(%s);' % (raise_func, bad_dim))
    code.putln(code.error_goto(pos))
    code.putln('}')
    code.funcstate.release_temp(bad_dim)

    return entry.generate_buffer_lookup_code(code, index_cnames)
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit the lookup macro for C contiguous buffers.

    Works like the strided lookup, except that the last index is added as
    an element offset on the typed pointer instead of being multiplied by
    its stride.  The stride arguments are still part of the macro so that
    it keeps the same signature as the strided variant.

    Only ``proto`` is written to; ``defin`` is unused.
    """
    if nd == 1:
        macro = "#define %s(type, buf, i0, s0) ((type)buf + i0)" % name
    else:
        last_dim = nd - 1
        macro_args = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
        # Byte offset contributed by all dimensions except the last one.
        byte_offset = " + ".join("i%d * s%d" % (dim, dim) for dim in range(last_dim))
        macro = "#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
            name, macro_args, byte_offset, last_dim)
    proto.putln(macro)
class GetAndReleaseBufferUtilityCode(object):
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which has the right tp_flags set, but emulation otherwise.
    #
    # All instances compare equal and share one hash value, so a set or
    # dict of utility codes holds at most one of them.

    requires = None
    is_cython_utility = False

    def __init__(self):
        pass

    def __eq__(self, other):
        return isinstance(other, GetAndReleaseBufferUtilityCode)

    def __hash__(self):
        return 24342342

    def get_tree(self, **kwargs):
        # There is no tree for this utility code.
        return None

    def put_code(self, output):
        """
        Write the "GetAndReleaseBuffer" utility code from Buffer.c.

        Collects every extension type that defines ``__getbuffer__``,
        reachable from the module scope through cimported modules, and
        passes them as the ``types`` context to the Tempita template.
        """
        def_writer = output['utility_code_def']
        proto_writer = output['utility_code_proto']
        module_scope = output.module_node.scope
        cython_scope = module_scope.context.cython_scope

        # Search all types for __getbuffer__ overloads
        buffer_types = []
        seen_scopes = set()

        def collect_buffer_types(scope):
            if scope in seen_scopes:
                return
            seen_scopes.add(scope)

            # Visit cimported modules before the scope's own types.
            for cimported in scope.cimported_modules:
                collect_buffer_types(cimported)

            for entry in scope.type_entries:
                if isinstance(entry.utility_code_definition, CythonUtilityCode):
                    continue
                ext_type = entry.type
                if not ext_type.is_extension_type:
                    continue
                if scope is cython_scope and not entry.used:
                    continue

                getbuffer_cname = releasebuffer_cname = None
                for func_entry in ext_type.scope.pyfunc_entries:
                    if func_entry.name == u"__getbuffer__":
                        getbuffer_cname = func_entry.func_cname
                    elif func_entry.name == u"__releasebuffer__":
                        releasebuffer_cname = func_entry.func_cname
                if getbuffer_cname:
                    buffer_types.append(
                        (ext_type.typeptr_cname, getbuffer_cname, releasebuffer_cname))

        collect_buffer_types(module_scope)

        util_code = TempitaUtilityCode.load(
            "GetAndReleaseBuffer", from_file="Buffer.c",
            context=dict(types=buffer_types))

        proto_text = util_code.format_code(util_code.proto)
        injected_impl = util_code.inject_string_constants(util_code.impl, output)[1]
        impl_text = util_code.format_code(injected_impl)

        proto_writer.putln(proto_text)
        def_writer.putln(impl_text)
def load_buffer_utility(util_code_name, context=None, **kwargs):
    """
    Load the utility code ``util_code_name`` from "Buffer.c".

    When a ``context`` is given, the code is loaded as a Tempita template
    with that context; otherwise it is loaded as plain utility code.
    Further keyword arguments are passed on to the loader.
    """
    if context is not None:
        return TempitaUtilityCode.load(util_code_name, "Buffer.c", context=context, **kwargs)
    return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)
load_buffer_utility("BufferFormatStructs") + +# Utility function to set the right exception +# The caller should immediately goto_error +raise_indexerror_code = load_buffer_utility("BufferIndexError") +raise_indexerror_nogil = load_buffer_utility("BufferIndexErrorNogil") +raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError") + +acquire_utility_code = load_buffer_utility("BufferGetAndValidate", context=context) +buffer_format_check_code = load_buffer_utility("BufferFormatCheck", context=context) + +# See utility code BufferFormatFromTypeInfo +_typeinfo_to_format_code = load_buffer_utility("TypeInfoToFormat") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7ca03443d49978e5335602f186756798c3d87f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py @@ -0,0 +1,2745 @@ +# cython: language_level=3str +# cython: auto_pickle=False +# +# Code output module +# + +from __future__ import absolute_import + +import cython +cython.declare(os=object, re=object, operator=object, textwrap=object, + Template=object, Naming=object, Options=object, StringEncoding=object, + Utils=object, SourceDescriptor=object, StringIOTree=object, + DebugFlags=object, basestring=object, defaultdict=object, + closing=object, partial=object) + +import hashlib +import operator +import os +import re +import shutil +import textwrap +from string import Template +from functools import partial +from contextlib import closing, contextmanager +from collections import defaultdict + +from . import Naming +from . import Options +from . import DebugFlags +from . import StringEncoding +from .. 
import Utils +from .Scanning import SourceDescriptor +from ..StringIOTree import StringIOTree + +try: + from __builtin__ import basestring +except ImportError: + from builtins import str as basestring + + +non_portable_builtins_map = { + # builtins that have different names in different Python versions + 'bytes' : ('PY_MAJOR_VERSION < 3', 'str'), + 'unicode' : ('PY_MAJOR_VERSION >= 3', 'str'), + 'basestring' : ('PY_MAJOR_VERSION >= 3', 'str'), + 'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'), + 'raw_input' : ('PY_MAJOR_VERSION >= 3', 'input'), +} + +ctypedef_builtins_map = { + # types of builtins in "ctypedef class" statements which we don't + # import either because the names conflict with C types or because + # the type simply is not exposed. + 'py_int' : '&PyInt_Type', + 'py_long' : '&PyLong_Type', + 'py_float' : '&PyFloat_Type', + 'wrapper_descriptor' : '&PyWrapperDescr_Type', +} + +basicsize_builtins_map = { + # builtins whose type has a different tp_basicsize than sizeof(...) + 'PyTypeObject': 'PyHeapTypeObject', +} + +uncachable_builtins = [ + # Global/builtin names that cannot be cached because they may or may not + # be available at import time, for various reasons: + ## Python 3.13+ + '_IncompleteInputError', + 'PythonFinalizationError', + ## Python 3.11+ + 'BaseExceptionGroup', + 'ExceptionGroup', + ## - Py3.10+ + 'aiter', + 'anext', + 'EncodingWarning', + ## - Py3.7+ + 'breakpoint', # might deserve an implementation in Cython + ## - Py3.4+ + '__loader__', + '__spec__', + ## - Py3+ + 'BlockingIOError', + 'BrokenPipeError', + 'ChildProcessError', + 'ConnectionAbortedError', + 'ConnectionError', + 'ConnectionRefusedError', + 'ConnectionResetError', + 'FileExistsError', + 'FileNotFoundError', + 'InterruptedError', + 'IsADirectoryError', + 'ModuleNotFoundError', + 'NotADirectoryError', + 'PermissionError', + 'ProcessLookupError', + 'RecursionError', + 'ResourceWarning', + #'StopAsyncIteration', # backported + 'TimeoutError', + '__build_class__', + 'ascii', # 
class IncludeCode(object):
    """
    An include file and/or verbatim C code to be included in the
    generated sources.
    """
    # attributes:
    #
    #  pieces    {order: unicode}: pieces of C code to be generated.
    #            The piece for the included file is stored under key zero.
    #            A piece of verbatim include code is stored under the
    #            "order" of the IncludeCode where it was first added,
    #            which prevents duplication if the same include code is
    #            found through multiple cimports.
    #  location  int: where to put this include in the C sources, one
    #            of the constants INITIAL, EARLY, LATE
    #  order     int: sorting order (automatically set by increasing counter)

    # Constants for location. If the same include occurs with different
    # locations, the earliest one takes precedence.
    INITIAL = 0
    EARLY = 1
    LATE = 2

    counter = 1   # Counter for "order"

    def __init__(self, include=None, verbatim=None, late=True, initial=False):
        self.order = self.counter
        type(self).counter += 1

        pieces = self.pieces = {}
        if include:
            is_system_include = include[0] == '<' and include[-1] == '>'
            if is_system_include:
                pieces[0] = u'#include {0}'.format(include)
                late = False  # system include is never late
            else:
                pieces[0] = u'#include "{0}"'.format(include)

        if verbatim:
            pieces[self.order] = verbatim

        if initial:
            self.location = self.INITIAL
        else:
            self.location = self.LATE if late else self.EARLY

    def dict_update(self, d, key):
        """
        Store `self` in dict `d` under `key`.  If `d` already has a value
        for that key, merge `self` into it instead: the existing value
        gets the earlier of the two locations and the pieces of `self`.
        """
        if key not in d:
            d[key] = self
            return
        existing = d[key]
        existing.location = min(self.location, existing.location)
        existing.pieces.update(self.pieces)

    def sortkey(self):
        return self.order

    def mainpiece(self):
        """
        Return the piece of C code that belongs to the include file,
        or None if there was no include file.
        """
        return self.pieces.get(0)

    def write(self, code):
        # Write the pieces in the order of their keys.
        for _, piece in sorted(self.pieces.items()):
            code.putln(piece)
class UtilityCodeBase(object):
    """
    Support for loading utility code from a file.

    Code sections in the file can be specified as follows:

        ##### MyUtility.proto #####

        [proto declarations]

        ##### MyUtility.init #####

        [code run at module initialization]

        ##### MyUtility #####
        #@requires: MyOtherUtility
        #@substitute: naming

        [definitions]

        ##### MyUtility #####
        #@substitute: tempita

        [requires tempita substitution
         - context can't be specified here though so only
           tempita utility that requires no external context
           will benefit from this tag
         - only necessary when @required from non-tempita code]

    for prototypes and implementation respectively.  For non-python or
    -cython files forward slashes should be used instead of '#'.  5 to 30
    comment characters may be used on either side.

    If the @cname decorator is not used and this is a CythonUtilityCode,
    one should pass in the 'name' keyword argument to be used for name
    mangling of such entries.
    """

    is_cython_utility = False
    # Maps the path of a utility file to its parsed sections (shared by all subclasses).
    _utility_cache = {}

    @classmethod
    def _add_utility(cls, utility, type, lines, begin_lineno, tags=None):
        """
        Store one parsed section in the ``[proto, impl, tags]`` list `utility`.

        utility       the section list to fill, or None if no section was
                      started yet (then nothing is done)
        type          section type: 'proto', 'impl' or another name, which
                      is then stored in the tags dict under that name
        lines         the source lines of the section
        begin_lineno  number of newlines to prepend, so that line numbers
                      in the code match those in the file
        tags          {tag name: set of values} collected for the section

        Raises RuntimeError if the "naming" substitution fails.
        """
        if utility is None:
            return

        code = '\n'.join(lines)
        if tags and 'substitute' in tags and 'naming' in tags['substitute']:
            try:
                code = Template(code).substitute(vars(Naming))
            except (KeyError, ValueError) as e:
                raise RuntimeError("Error parsing templated utility code of type '%s' at line %d: %s" % (
                    type, begin_lineno, e))

        # remember correct line numbers at least until after templating
        code = '\n' * begin_lineno + code

        if type == 'proto':
            utility[0] = code
        elif type == 'impl':
            utility[1] = code
        else:
            all_tags = utility[2]
            all_tags[type] = code

        if tags:
            all_tags = utility[2]
            for name, values in tags.items():
                all_tags.setdefault(name, set()).update(values)

    @classmethod
    def load_utilities_from_file(cls, path):
        """
        Parse the utility file at `path` and return its sections.

        The result maps each utility name to a list ``[proto, impl, tags]``
        and is cached per path.  The lines of the file are obtained from
        ``read_utilities_hook(path)``.

        Raises ValueError if the file contains no section header.
        """
        utilities = cls._utility_cache.get(path)
        if utilities:
            return utilities

        _, ext = os.path.splitext(path)
        if ext in ('.pyx', '.py', '.pxd', '.pxi'):
            comment = '#'
            strip_comments = partial(re.compile(r'^\s*#(?!\s*cython\s*:).*').sub, '')
            rstrip = StringEncoding._unicode.rstrip
        else:
            comment = '/'
            strip_comments = partial(re.compile(r'^\s*//.*|/\*[^*]*\*/').sub, '')
            rstrip = partial(re.compile(r'\s+(\\?)$').sub, r'\1')
        # Matches either a section header ("##### Name.type #####") or a tag
        # line ("#@tag: value").  The groups must be named because they are
        # read below with m.group('name'), m.group('tag') and m.group('value').
        match_special = re.compile(
            (r'^%(C)s{5,30}\s*(?P<name>(?:\w|\.)+)\s*%(C)s{5,30}|'
             r'^%(C)s+@(?P<tag>\w+)\s*:\s*(?P<value>(?:\w|[.:])+)') %
            {'C': comment}).match
        match_type = re.compile(r'(.+)[.](proto(?:[.]\S+)?|impl|init|cleanup)$').match

        all_lines = read_utilities_hook(path)

        utilities = defaultdict(lambda: [None, None, {}])
        lines = []
        tags = defaultdict(set)
        utility = type = None
        begin_lineno = 0

        for lineno, line in enumerate(all_lines):
            m = match_special(line)
            if m:
                if m.group('name'):
                    # A new section starts: flush the previous one first.
                    cls._add_utility(utility, type, lines, begin_lineno, tags)

                    begin_lineno = lineno + 1
                    del lines[:]
                    tags.clear()

                    name = m.group('name')
                    mtype = match_type(name)
                    if mtype:
                        name, type = mtype.groups()
                    else:
                        type = 'impl'
                    utility = utilities[name]
                else:
                    tags[m.group('tag')].add(m.group('value'))
                    lines.append('')  # keep line number correct
            else:
                lines.append(rstrip(strip_comments(line)))

        if utility is None:
            raise ValueError("Empty utility code file")

        # Don't forget to add the last utility code
        cls._add_utility(utility, type, lines, begin_lineno, tags)

        utilities = dict(utilities)  # un-defaultdict-ify
        cls._utility_cache[path] = utilities
        return utilities

    @classmethod
    def load(cls, util_code_name, from_file, **kwargs):
        """
        Load utility code from a file specified by from_file (relative to
        Cython/Utility) and name util_code_name.

        The name may also be given as "file::name", which overrides
        from_file.  Tags found in the file are passed to the constructor
        as keyword arguments unless the caller already provided them.
        """

        if '::' in util_code_name:
            from_file, util_code_name = util_code_name.rsplit('::', 1)
        assert from_file
        utilities = cls.load_utilities_from_file(from_file)
        proto, impl, tags = utilities[util_code_name]

        if tags:
            if "substitute" in tags and "tempita" in tags["substitute"]:
                if not issubclass(cls, TempitaUtilityCode):
                    return TempitaUtilityCode.load(util_code_name, from_file, **kwargs)
            orig_kwargs = kwargs.copy()
            for name, values in tags.items():
                if name in kwargs:
                    continue
                # only pass lists when we have to: most argument expect one value or None
                if name == 'requires':
                    if orig_kwargs:
                        values = [cls.load(dep, from_file, **orig_kwargs)
                                  for dep in sorted(values)]
                    else:
                        # dependencies are rarely unique, so use load_cached() when we can
                        values = [cls.load_cached(dep, from_file)
                                  for dep in sorted(values)]
                elif name == 'substitute':
                    # don't want to pass "naming" or "tempita" to the constructor
                    # since these will have been handled
                    values = values - {'naming', 'tempita'}
                    if not values:
                        continue
                elif not values:
                    values = None
                elif len(values) == 1:
                    values = list(values)[0]
                kwargs[name] = values

        if proto is not None:
            kwargs['proto'] = proto
        if impl is not None:
            kwargs['impl'] = impl

        if 'name' not in kwargs:
            kwargs['name'] = util_code_name

        if 'file' not in kwargs and from_file:
            kwargs['file'] = from_file
        return cls(**kwargs)

    @classmethod
    def load_cached(cls, utility_code_name, from_file, __cache={}):
        """
        Calls .load(), but using a per-type cache based on utility name and file name.
        """
        # The mutable default argument is deliberate: it is the cache.
        key = (utility_code_name, from_file, cls)
        try:
            return __cache[key]
        except KeyError:
            pass
        code = __cache[key] = cls.load(utility_code_name, from_file)
        return code

    @classmethod
    def load_as_string(cls, util_code_name, from_file, **kwargs):
        """
        Load a utility code as a string. Returns (proto, implementation)
        """
        util = cls.load(util_code_name, from_file, **kwargs)
        proto, impl = util.proto, util.impl
        return util.format_code(proto), util.format_code(impl)

    def format_code(self, code_string, replace_empty_lines=re.compile(r'\n\n+').sub):
        """
        Format a code section for output.

        Strips the section, removes empty lines and appends two newlines.
        Empty strings and None are returned unchanged.
        """
        if code_string:
            code_string = replace_empty_lines('\n', code_string.strip()) + '\n\n'
        return code_string

    def __repr__(self):
        return "<%s(%s)>" % (type(self).__name__, self.name)

    def get_tree(self, **kwargs):
        return None

    def __deepcopy__(self, memodict=None):
        # No need to deep-copy utility code since it's essentially immutable.
        return self
def specialize(self, pyrex_type=None, **data):
    """
    Return a UtilityCode with 'data' substituted into this one's sections.

    pyrex_type  optional type object; when given, its
                empty_declaration_code() and specialization_name() are
                added to 'data' as 'type' and 'type_name', and the type
                name is appended to the utility code name
    data        substitution values applied through self.none_or_sub()

    Results are cached per set of substitution values, so asking twice for
    the same specialisation returns the same object.  Entries of
    'self.requires' are specialised with the same arguments.
    """
    name = self.name
    if pyrex_type is not None:
        data['type'] = pyrex_type.empty_declaration_code()
        data['type_name'] = pyrex_type.specialization_name()
        name = "%s[%s]" % (name, data['type_name'])
    # Dicts aren't hashable...
    key = tuple(sorted(data.items()))
    try:
        return self._cache[key]
    except KeyError:
        pass

    if self.requires is None:
        requires = None
    else:
        # The substitution values must be forwarded as keyword arguments:
        # passing the dict positionally would bind it to 'pyrex_type' and
        # fail as soon as the dependency tries to use it as a type.
        requires = [r.specialize(pyrex_type, **data) for r in self.requires]

    s = self._cache[key] = UtilityCode(
        self.none_or_sub(self.proto, data),
        self.none_or_sub(self.impl, data),
        self.none_or_sub(self.init, data),
        self.none_or_sub(self.cleanup, data),
        requires,
        self.proto_block,
        name,
    )

    self.specialize_list.append(s)
    return s
def wrap_c_strings(self, impl):
    """
    Replace 'CSTRING(...)' markers by C string literals.

    The marker argument is a block of text enclosed in triple double
    quotes.  Each line of that text becomes one C string literal ending in
    an escaped line break; a line that ends in a single backslash is
    emitted without the backslash and without the line break.  Double
    quotes inside the text are written as the octal escape \\042 so that
    they cannot terminate the generated literal.

    Returns 'impl' unchanged when it contains no marker.
    """
    if 'CSTRING(' not in impl:
        return impl

    def split_string(matchobj):
        # NOTE: the escape must reach the C compiler as the four characters
        # backslash-0-4-2; a Python '\042' would just be the quote itself.
        content = matchobj.group(1).replace('"', '\\042')
        return ''.join(
            '"%s\\n"\n' % line if not line.endswith('\\') or line.endswith('\\\\') else '"%s"\n' % line[:-1]
            for line in content.splitlines())

    impl = re.sub(r'CSTRING\(\s*"""([^"]*(?:"[^"]+)*)"""\s*\)', split_string, impl)
    # Any marker left over means the text did not match the expected form.
    assert 'CSTRING(' not in impl
    return impl
def sub_tempita(s, context, file=None, name=None, __cache={}):
    """
    Render the Tempita template text 's' with the values in 'context'.

    Returns None for an empty or missing template.  When a name (optionally
    prefixed by the file name) is known, it is stored in the context under
    '__name'.  Compiled templates are memoised by their source text in
    '__cache'.
    """
    if not s:
        return None

    if file:
        name = "%s:%s" % (file, name)
    if name:
        context['__name'] = name

    template = __cache.get(s)
    if template is None:
        # Imported lazily, only when a template really has to be compiled.
        from ..Tempita import Template
        template = __cache[s] = Template(s, name=name)

    return template.substitute(context)
class LazyUtilityCode(UtilityCodeBase):
    """
    Utility code whose real content is produced on demand.

    The callback receives the root code writer once it is available and
    returns the utility code to use.  Useful when you only have 'env' but
    not 'code'.
    """
    __name__ = ''
    requires = None

    def __init__(self, callback):
        self.callback = callback

    def put_code(self, globalstate):
        # Resolve the deferred utility code, then register it as usual.
        resolved = self.callback(globalstate.rootwriter)
        globalstate.use_utility_code(resolved)
def validate_exit(self):
    """
    Check that every temporary allocated by this function state was freed.

    Raises RuntimeError listing the leftover temps (sorted, as
    'name [type]') when self.temps_in_use() still reports any.
    """
    # validate that all allocated temps have been freed
    if not self.temps_allocated:
        return
    leftovers = self.temps_in_use()
    if not leftovers:
        return

    # 'scope' is optional (it defaults to None in __init__); do not let a
    # missing scope turn the leak report into an AttributeError.
    scope_name = self.scope.name if self.scope is not None else '(unknown)'
    msg = "TEMPGUARD: Temps left over at end of '%s': %s" % (scope_name, ', '.join([
            '%s [%s]' % (name, ctype)
            for name, ctype, is_pytemp in sorted(leftovers)]),
    )
    raise RuntimeError(msg)
def all_new_labels(self):
    """
    Replace each label that is currently set (continue, break, return,
    error) by a freshly generated one; unset labels are left as they are.

    Returns the labels that were active before the call.
    """
    previous = self.get_all_labels()
    label_kinds = ('continue', 'break', 'return', 'error')
    replaced = [
        self.new_label(kind) if current else current
        for current, kind in zip(previous, label_kinds)
    ]
    self.set_all_labels(replaced)
    return previous
def release_temp(self, name):
    """
    Hand a temporary back so that later requests for a temp of the same
    type and reference management can reuse it.

    Raises RuntimeError if the temp is already marked as free.  Temps in
    self.zombie_temps are marked as free but never offered for reuse.
    """
    ctype, manage_ref = self.temps_used_type[name]
    key = (ctype, manage_ref)
    freelist = self.temps_free.get(key)
    if freelist is None:
        # ordered list for reuse, set for fast membership tests
        freelist = self.temps_free[key] = ([], set())
    reusable_names, freed_names = freelist

    if name in freed_names:
        raise RuntimeError("Temp %s freed twice!" % name)

    is_zombie = name in self.zombie_temps
    if not is_zombie:
        reusable_names.append(name)
    freed_names.add(name)

    if DebugFlags.debug_temp_code_comments:
        self.owner.putln("/* %s released %s*/" % (
            name, " - zombie" if is_zombie else ""))
class NumConst(object):
    """Record describing one Python number constant kept by GlobalState.

    cname       string
    value       string
    py_type     string      int, long, float
    value_code  string      evaluation code; same as value unless given
    """

    def __init__(self, cname, value, py_type, value_code=None):
        self.cname, self.value, self.py_type = cname, value, py_type
        # A missing (or empty) evaluation code means the value is used as is.
        self.value_code = value_code if value_code else value
class PyStringConst(object):
    """Record describing one Python string constant kept by GlobalState.

    Instances order themselves by their cname.
    """
    # cname           string
    # encoding        string
    # is_unicode      boolean
    # is_str          boolean
    # py3str_cstring  string
    # intern          boolean

    def __init__(self, cname, encoding, is_unicode, is_str=False,
                 py3str_cstring=None, intern=False):
        self.cname = cname
        self.encoding = encoding
        self.is_unicode = is_unicode
        self.is_str = is_str
        self.py3str_cstring = py3str_cstring
        self.intern = intern

    def __lt__(self, other):
        # Only the C name takes part in the ordering.
        mine, theirs = self.cname, other.cname
        return mine < theirs
def __init__(self, writer, module_node, code_config, common_utility_include_dir=None):
    """
    Set up empty bookkeeping for one generated file and attach this state
    to the root code writer.
    """
    # Values handed in by the caller.
    self.code_config = code_config
    self.common_utility_include_dir = common_utility_include_dir
    # because some utility code generation needs it
    # (generating backwards-compatible Get/ReleaseBuffer
    self.module_node = module_node

    # File name table and input sources.
    self.filename_table, self.filename_list = {}, []
    self.input_file_contents = {}

    # Utility code, declarations and output sections.
    self.utility_codes = set()
    self.declared_cnames = {}
    self.in_utility_code_generation = False
    self.parts = {}

    # Constant bookkeeping.
    self.const_cnames_used = {}
    self.string_const_index = {}
    self.dedup_const_index = {}
    self.pyunicode_ptr_const_index = {}
    self.num_const_index = {}
    self.py_constants = []
    self.cached_cmethods = {}
    self.initialised_constants = set()

    # The writer learns about this state before it is stored as the root.
    writer.set_global_state(self)
    self.rootwriter = writer
def initialize_main_h_code(self):
    """
    Create one insertion point on the root writer for every section named
    in h_code_layout and register it in self.parts under that name.
    """
    new_insertion_point = self.rootwriter.insertion_point
    parts = self.parts
    for section in self.h_code_layout:
        parts[section] = new_insertion_point()
def close_global_decls(self):
    """
    Emit the constant declarations and close the C functions that were
    opened for the module-level sections.

    This is called when it is known that no more global declarations will
    be added.
    """
    def close_int_function(writer, finish_refnanny=False):
        # Success path, optional error path, closing brace.
        if finish_refnanny:
            writer.put_finish_refcount_context()
        writer.putln("return 0;")
        if writer.label_used(writer.error_label):
            writer.put_label(writer.error_label)
            if finish_refnanny:
                writer.put_finish_refcount_context()
            writer.putln("return -1;")
        writer.putln("}")
        writer.exit_cfunc_scope()

    def close_void_function(writer):
        writer.putln("}")
        writer.exit_cfunc_scope()

    self.generate_const_declarations()

    if Options.cache_builtins:
        close_int_function(self.parts['cached_builtins'])

    close_int_function(self.parts['cached_constants'], finish_refnanny=True)

    for section in ('init_globals', 'init_constants'):
        close_int_function(self.parts[section])

    if Options.generate_cleanup_code:
        close_void_function(self.parts['cleanup_globals'])

    if Options.generate_cleanup_code:
        close_void_function(self.parts['cleanup_module'])
def get_int_const(self, str_value, longness=False):
    """
    Return the number constant registered for the integer literal
    'str_value', creating it through new_num_const() on first use.

    A true 'longness' selects the 'long' flavour instead of 'int'.
    """
    py_type = 'long' if longness else 'int'
    try:
        return self.num_const_index[(str_value, py_type)]
    except KeyError:
        return self.new_num_const(str_value, py_type)
def new_num_const(self, value, py_type, value_code=None):
    """
    Create a NumConst for 'value', register it in num_const_index under
    (value, py_type) and return it.
    """
    const_cname = self.new_num_const_cname(value, py_type)
    const = NumConst(const_cname, value, py_type, value_code)
    index_key = (value, py_type)
    self.num_const_index[index_key] = const
    return const
def unique_const_cname(self, format_str):  # type: (str) -> str
    """
    Derive a constant name from 'format_str' that has not been handed out
    before.

    'format_str' is a str.format() template with the fields 'sep' and
    'counter'.  The first attempt leaves both empty; on a clash, the
    per-name counter in const_cnames_used is advanced and the template is
    filled with sep='_' and that counter until an unused name is found.
    """
    taken = self.const_cnames_used
    base_name = format_str.format(sep='', counter='')
    candidate = base_name
    while candidate in taken:
        taken[base_name] += 1
        candidate = format_str.format(sep='_', counter=taken[base_name])
    taken[candidate] = 1
    return candidate
def add_cached_builtin_decl(self, entry):
    """
    Declare the cached object for a constant builtin and emit the code
    that looks it up in the 'cached_builtins' section.

    Nothing happens for entries that are not constant builtins or that
    should_declare() rejects.  Names listed in non_portable_builtins_map
    get a conditionally compiled replacement lookup in front of the
    regular one.
    """
    if not (entry.is_builtin and entry.is_const):
        return
    if not self.should_declare(entry.cname, entry):
        return

    self.put_pyobject_decl(entry)
    writer = self.parts['cached_builtins']
    condition = None
    if entry.name in non_portable_builtins_map:
        condition, replacement = non_portable_builtins_map[entry.name]
        writer.putln('#if %s' % condition)
        self.put_cached_builtin_init(
            entry.pos, StringEncoding.EncodedString(replacement),
            entry.cname)
        writer.putln('#else')
    self.put_cached_builtin_init(
        entry.pos, StringEncoding.EncodedString(entry.name),
        entry.cname)
    if condition:
        writer.putln('#endif')
def generate_cached_methods_decls(self):
    """
    Emit declarations and initialisation for all cached unbound C methods.

    For every entry of 'self.cached_cmethods' (processed in sorted key
    order) a zero-initialised '__Pyx_CachedCFunction' is declared in the
    'decls' part, and its 'type' and 'method_name' fields are assigned in
    the 'init_constants' part.  When 'Options.generate_cleanup_code' is
    set, a 'Py_CLEAR' of the 'method' field is added to 'cleanup_globals'.
    """
    if not self.cached_cmethods:
        return

    decl = self.parts['decls']
    init = self.parts['init_constants']
    ordered = sorted(self.cached_cmethods.items())

    for key, cname in ordered:
        type_cname, method_name = key
        interned = self.get_interned_identifier(
            StringEncoding.EncodedString(method_name))
        decl.putln('static __Pyx_CachedCFunction %s = {0, 0, 0, 0, 0};' % (
            cname))
        # split type reference storage as it might not be static
        init.putln('%s.type = (PyObject*)&%s;' % (
            cname, type_cname))
        # method name string isn't static in limited api
        init.putln('%s.method_name = &%s;' % (
            cname, interned.cname))

    if Options.generate_cleanup_code:
        cleanup = self.parts['cleanup_globals']
        for _, cname in ordered:
            cleanup.putln("Py_CLEAR(%s.method);" % cname)
+ if conditional: + decls_writer.putln("#endif") + if c.py_strings is not None: + for py_string in c.py_strings.values(): + py_strings.append((c.cname, len(py_string.cname), py_string)) + + for c, cname in sorted(self.pyunicode_ptr_const_index.items()): + utf16_array, utf32_array = StringEncoding.encode_pyunicode_string(c) + if utf16_array: + # Narrow and wide representations differ + decls_writer.putln("#ifdef Py_UNICODE_WIDE") + decls_writer.putln("static Py_UNICODE %s[] = { %s };" % (cname, utf32_array)) + if utf16_array: + decls_writer.putln("#else") + decls_writer.putln("static Py_UNICODE %s[] = { %s };" % (cname, utf16_array)) + decls_writer.putln("#endif") + + init_constants = self.parts['init_constants'] + if py_strings: + self.use_utility_code(UtilityCode.load_cached("InitStrings", "StringTools.c")) + py_strings.sort() + w = self.parts['pystring_table'] + w.putln("") + w.putln("static int __Pyx_CreateStringTabAndInitStrings(void) {") + # the stringtab is a function local rather than a global to + # ensure that it doesn't conflict with module state + w.putln("__Pyx_StringTabEntry %s[] = {" % Naming.stringtab_cname) + for py_string_args in py_strings: + c_cname, _, py_string = py_string_args + if not py_string.is_str or not py_string.encoding or \ + py_string.encoding in ('ASCII', 'USASCII', 'US-ASCII', + 'UTF8', 'UTF-8'): + encoding = '0' + else: + encoding = '"%s"' % py_string.encoding.lower() + + self.parts['module_state'].putln("PyObject *%s;" % py_string.cname) + self.parts['module_state_defines'].putln("#define %s %s->%s" % ( + py_string.cname, + Naming.modulestateglobal_cname, + py_string.cname)) + self.parts['module_state_clear'].putln("Py_CLEAR(clear_module_state->%s);" % + py_string.cname) + self.parts['module_state_traverse'].putln("Py_VISIT(traverse_module_state->%s);" % + py_string.cname) + if py_string.py3str_cstring: + w.putln("#if PY_MAJOR_VERSION >= 3") + w.putln("{&%s, %s, sizeof(%s), %s, %d, %d, %d}," % ( + py_string.cname, + 
def generate_num_constants(self):
    """
    Emit module state slots and initialisation code for numeric constants.

    The constants in 'self.num_const_index' are processed in a stable
    order (by Python type, sign, literal length, literal text).  Each one
    gets a 'PyObject *' slot in the module state, the matching '#define',
    clear and traverse lines, and an assignment in 'init_constants' whose
    constructor call depends on the type and size of the literal.
    """
    ordered = sorted(
        (c.py_type, c.value[0] == '-', len(c.value), c.value, c.value_code, c)
        for c in self.num_const_index.values())
    init_constants = self.parts['init_constants']

    for py_type, _, _, value, value_code, const in ordered:
        cname = const.cname
        self.parts['module_state'].putln("PyObject *%s;" % cname)
        self.parts['module_state_defines'].putln("#define %s %s->%s" % (
            cname, Naming.modulestateglobal_cname, cname))
        self.parts['module_state_clear'].putln(
            "Py_CLEAR(clear_module_state->%s);" % cname)
        self.parts['module_state_traverse'].putln(
            "Py_VISIT(traverse_module_state->%s);" % cname)

        # Pick the C call that builds the Python number object.
        if py_type == 'float':
            constructor = 'PyFloat_FromDouble(%s)'
        elif py_type == 'long':
            constructor = 'PyLong_FromString((char *)"%s", 0, 0)'
        elif Utils.long_literal(value):
            constructor = 'PyInt_FromString((char *)"%s", 0, 0)'
        else:
            # Literals with more than four digits get an explicit 'L' suffix.
            if len(value.lstrip('-')) > 4:
                constructor = "PyInt_FromLong(%sL)"
            else:
                constructor = "PyInt_FromLong(%s)"

        call = constructor % value_code
        init_constants.putln('%s = %s; %s' % (
            cname, call,
            init_constants.error_goto_if_null(cname, self.module_pos)))
def commented_file_contents(self, source_desc):
    """
    Return the lines of 'source_desc' prepared for use inside a C comment.

    Every line is right-stripped, prefixed with ' * ', and has the C comment
    delimiters '*/' and '/*' broken up by an inserted marker text.  The
    result always holds at least one (possibly empty) line and is cached in
    'self.input_file_contents' under 'source_desc'.
    """
    cache = self.input_file_contents
    if source_desc in cache:
        return cache[source_desc]

    source_lines = source_desc.get_lines(encoding='ASCII',
                                         error_handling='ignore')
    commented = []
    try:
        for raw_line in source_lines:
            text = raw_line.rstrip()
            text = text.replace(
                u'*/', u'*[inserted by cython to avoid comment closer]/')
            text = text.replace(
                u'/*', u'/[inserted by cython to avoid comment start]*')
            commented.append(u' * ' + text)
    finally:
        # get_lines() may hand back an open file-like object.
        if hasattr(source_lines, 'close'):
            source_lines.close()

    if not commented:
        commented.append(u'')
    cache[source_desc] = commented
    return commented
def funccontext_property(func):
    """
    Build a property that delegates to 'self.funcstate'.

    Only the name of 'func' is used: reading the property returns the
    attribute of that name from 'self.funcstate', and assigning to it sets
    that attribute on 'self.funcstate'.  The body of 'func' never runs.
    """
    attr_name = func.__name__
    read_attr = operator.attrgetter(attr_name)

    def fget(owner):
        return read_attr(owner.funcstate)

    def fset(owner, new_value):
        setattr(owner.funcstate, attr_name, new_value)

    return property(fget, fset)
def write(self, s):
    """
    Write 's' to the output buffer.

    Text containing a newline goes through '_write_lines()'; anything else
    is handed to '_write_to_buffer()' directly.
    """
    sink = self._write_lines if '\n' in s else self._write_to_buffer
    sink(s)
def label_interceptor(self, new_labels, orig_labels, skip_to_label=None, pos=None, trace=True):
    """
    Generator that emits one interceptor code block per used label.

    For each (new, original) label pair whose new label is in use, the new
    label is emitted, the pair is yielded so that the caller can generate
    the interception code, and afterwards a goto to the original label is
    emitted.  Unused labels are skipped without generating any code.

    @param new_labels: the new labels that should be intercepted
    @param orig_labels: the original labels that we should dispatch to after the interception
    @param skip_to_label: a label to skip to before starting the code blocks
    @param pos: the node position to mark for each interceptor block
    @param trace: add a trace line for the pos marker or not
    """
    pending_skip = skip_to_label
    for intercepted, target in zip(new_labels, orig_labels):
        # Checked lazily: the caller may mark further labels as used
        # while generating the code for an earlier block.
        if self.label_used(intercepted):
            if pending_skip:
                # jump over the whole interception block (emitted only once)
                self.put_goto(pending_skip)
                pending_skip = None

            if pos is not None:
                self.mark_pos(pos, trace=trace)
            self.put_label(intercepted)
            yield (intercepted, target)
            self.put_goto(target)
def putln(self, code="", safe=False):
    """
    Write 'code' followed by a newline.

    A pending position marker is emitted first when the writer is at the
    beginning of a line.  If line number output is enabled and a position
    has been marked, a '#line' directive precedes the code.  With
    'safe=True' the code is written through 'put_safe()' instead of
    'put()'.  Afterwards the writer is at the beginning of a line again.
    """
    if self.last_pos and self.bol:
        self.emit_marker()

    if self.code_config.emit_linenums:
        marked = self.last_marked_pos
        if marked:
            source_desc, line, _ = marked
            directive = '\n#line %s "%s"\n' % (line, source_desc.get_escaped_description())
            self._write_lines(directive)

    if code:
        emit = self.put_safe if safe else self.put
        emit(code)

    self._write_lines("\n")
    self.bol = 1
def put(self, code):
    """
    Write 'code' without a trailing newline and track brace indentation.

    The indentation level follows the balance of '{' and '}' in 'code':
    a net closing balance dedents before writing, a net opening balance
    indents after writing.  Balanced code that starts with '}' (such as
    "} else {") is written one level less indented, and the level is
    restored afterwards.
    """
    closing = code.count("}")
    balance = code.count("{") - closing
    temporary_dedent = False

    if closing:
        if balance < 0:
            self.level += balance
        elif balance == 0 and code.startswith("}"):
            # special cases like "} else {" need a temporary dedent
            temporary_dedent = True
            self.level -= 1

    if self.bol:
        self.indent()
    self.write(code)
    self.bol = 0

    if balance > 0:
        self.level += balance
    elif temporary_dedent:
        self.level += 1
def put_var_declaration(self, entry, storage_class="",
                        dll_linkage=None, definition=True):
    """
    Write the C declaration of the variable described by 'entry'.

    Private entries are skipped when this is not a definition (and the
    entry is not defined in a .pxd), or when the entry is unused.  The
    declaration is optionally prefixed with 'CYTHON_UNUSED' and
    'storage_class', and is initialised from 'entry.init' when given, or
    with NULL for Python object types.  Finally the utility code of the
    entry is registered with the scope of the current function state.
    """
    if entry.visibility == 'private':
        if not (definition or entry.defined_in_pxd):
            return
        if not entry.used:
            return

    if not entry.cf_used:
        self.put('CYTHON_UNUSED ')
    if storage_class:
        self.put("%s " % storage_class)

    var_type = entry.type
    if entry.is_cpp_optional:
        make_declaration = var_type.cpp_optional_declaration_code
    else:
        make_declaration = var_type.declaration_code
    self.put(make_declaration(entry.cname, dll_linkage=dll_linkage))

    if entry.init is not None:
        # Literal initialisers may contain braces that must not affect indentation.
        self.put_safe(" = %s" % var_type.literal_code(entry.init))
    elif var_type.is_pyobject:
        self.put(" = NULL")
    self.putln(";")
    self.funcstate.scope.use_entry_utility_code(entry)
def entry_as_pyobject(self, entry):
    """
    Return a C expression for 'entry.cname' usable as a 'PyObject *'.

    The cname is prefixed with a '(PyObject *)' cast when the type of the
    entry is an extension type, or when the entry is not a self argument
    and its type reports itself as incomplete.  Otherwise the plain cname
    is returned.
    """
    entry_type = entry.type
    # 'and' binds tighter than 'or'; the parentheses spell that out.
    needs_cast = (
        (not entry.is_self_arg and not entry_type.is_complete())
        or entry_type.is_extension_type
    )
    if needs_cast:
        return "(PyObject *)" + entry.cname
    return entry.cname
def put_decref_clear(self, cname, type, clear_before_decref=False, nanny=True, have_gil=True):
    """
    Emit a decref-and-clear of 'cname' by delegating to
    'type.generate_decref_clear()', forwarding all options unchanged.
    """
    options = dict(
        clear_before_decref=clear_before_decref,
        nanny=nanny,
        have_gil=have_gil,
    )
    type.generate_decref_clear(self, cname, **options)
def put_init_var_to_py_none(self, entry, template = "%s", nanny=True):
    """
    Initialise the variable of 'entry' to None.

    'template' is formatted with 'entry.cname' to build the assignment
    target, which is then initialised through 'put_init_to_py_none()'.
    For variables that live in a closure, a giveref for Py_None is emitted
    in addition.
    """
    code = template % entry.cname
    self.put_init_to_py_none(code, entry.type, nanny)
    if entry.in_closure:
        # put_giveref() takes the type of the reference as second argument;
        # calling it with the cname alone raised a TypeError.  Py_None is a
        # plain Python object reference.
        from .PyrexTypes import py_object_type
        self.put_giveref('Py_None', py_object_type)
def put_pymethoddef_wrapper(self, entry):
    """
    Emit a PyCFunction compatible wrapper for 'entry' where one is needed.

    Only special methods whose signature carries the 'method_noargs' flag
    get a wrapper.  The wrapper accepts the unused second argument that
    PyCFunction expects and calls the real function with 'self' only; for
    '__next__' it also sets StopIteration when NULL was returned without
    an exception.  Returns the cname to put into the method table: the
    wrapper cname if one was written, 'entry.func_cname' otherwise.
    """
    if not entry.is_special:
        return entry.func_cname

    method_flags = entry.signature.method_flags() or []
    from .TypeSlots import method_noargs
    if method_noargs not in method_flags:
        return entry.func_cname

    # Special NOARGS methods really take no arguments besides 'self', but PyCFunction expects one.
    wrapper_cname = Naming.method_wrapper_prefix + entry.func_cname
    self.putln("static PyObject *%s(PyObject *self, CYTHON_UNUSED PyObject *arg) {" % wrapper_cname)
    func_call = "%s(self)" % entry.func_cname
    if entry.name == "__next__":
        body = [
            "PyObject *res = %s;" % func_call,
            # tp_iternext can return NULL without an exception
            "if (!res && !PyErr_Occurred()) { PyErr_SetNone(PyExc_StopIteration); }",
            "return res;",
        ]
    else:
        body = ["return %s;" % func_call]
    for line in body:
        self.putln(line)
    self.putln("}")
    return wrapper_cname
def put_release_gil(self, variable=None, unknown_gil_state=True):
    """
    Release the GIL, corresponds to `put_acquire_gil`.

    The generated code is guarded by '#ifdef WITH_THREAD' and saves the
    thread state in a local '_save'.  With 'unknown_gil_state' the release
    is wrapped in a 'PyGILState_Check()' test.  If 'variable' is given,
    the saved thread state is copied into it.
    """
    self.use_fast_gil_utility_code()

    release = ["Py_UNBLOCK_THREADS"]
    if unknown_gil_state:
        # we don't *know* that we don't have the GIL (since we may be inside a nogil function,
        # and Py_UNBLOCK_THREADS is unsafe without the GIL)
        release = ["if (PyGILState_Check()) {"] + release + ["}"]

    lines = ["#ifdef WITH_THREAD", "PyThreadState *_save;", "_save = NULL;"]
    lines.extend(release)
    if variable:
        lines.append('%s = _save;' % variable)
    lines.append("__Pyx_FastGIL_Remember();")
    lines.append("#endif")

    for line in lines:
        self.putln(line)
def put_error_if_unbound(self, pos, entry, in_nogil_context=False, unbound_check_code=None):
    """
    Emit a check that raises an error when the variable of 'entry' is unbound.

    The C helper that raises the error is chosen from the kind of entry
    (closure variable, memoryview slice in a nogil context, C++ global,
    C++ attribute of a C class, or any other local), and the utility code
    of that helper is registered.  'unbound_check_code' overrides the
    default NULL check expression provided by the entry's type.
    """
    if entry.from_closure:
        func, helper = '__Pyx_RaiseClosureNameError', "RaiseClosureNameError"
    elif entry.type.is_memoryviewslice and in_nogil_context:
        func, helper = '__Pyx_RaiseUnboundMemoryviewSliceNogil', "RaiseUnboundMemoryviewSliceNogil"
    elif entry.type.is_cpp_class and entry.is_cglobal:
        func, helper = '__Pyx_RaiseCppGlobalNameError', "RaiseCppGlobalNameError"
    elif entry.type.is_cpp_class and entry.is_variable and not entry.is_member and entry.scope.is_c_class_scope:
        # there doesn't seem to be a good way to detecting an instance-attribute of a C class
        # (is_member is only set for class attributes)
        func, helper = '__Pyx_RaiseCppAttributeError', "RaiseCppAttributeError"
    else:
        func, helper = '__Pyx_RaiseUnboundLocalError', "RaiseUnboundLocalError"

    self.globalstate.use_utility_code(
        UtilityCode.load_cached(helper, "ObjectHandling.c"))

    if not unbound_check_code:
        unbound_check_code = entry.type.check_for_null_code(entry.cname)
    self.putln('if (unlikely(!%s)) { %s("%s"); %s }' % (
        unbound_check_code,
        func,
        entry.name,
        self.error_goto(pos)))
None: + return 'goto %s;' % lbl + self.funcstate.should_declare_error_indicator = True + if used: + self.funcstate.uses_error_indicator = True + return "__PYX_ERR(%s, %s, %s)" % ( + self.lookup_filename(pos[0]), + pos[1], + lbl) + + def error_goto_if(self, cond, pos): + return "if (%s) %s" % (self.unlikely(cond), self.error_goto(pos)) + + def error_goto_if_null(self, cname, pos): + return self.error_goto_if("!%s" % cname, pos) + + def error_goto_if_neg(self, cname, pos): + # Add extra parentheses to silence clang warnings about constant conditions. + return self.error_goto_if("(%s < 0)" % cname, pos) + + def error_goto_if_PyErr(self, pos): + return self.error_goto_if("PyErr_Occurred()", pos) + + def lookup_filename(self, filename): + return self.globalstate.lookup_filename(filename) + + def put_declare_refcount_context(self): + self.putln('__Pyx_RefNannyDeclarations') + + def put_setup_refcount_context(self, name, acquire_gil=False): + name = name.as_c_string_literal() # handle unicode names + if acquire_gil: + self.globalstate.use_utility_code( + UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c")) + self.putln('__Pyx_RefNannySetupContext(%s, %d);' % (name, acquire_gil and 1 or 0)) + + def put_finish_refcount_context(self, nogil=False): + self.putln("__Pyx_RefNannyFinishContextNogil()" if nogil else "__Pyx_RefNannyFinishContext();") + + def put_add_traceback(self, qualified_name, include_cline=True): + """ + Build a Python traceback for propagating exceptions. + + qualified_name should be the qualified name of the function. 
+ """ + qualified_name = qualified_name.as_c_string_literal() # handle unicode names + format_tuple = ( + qualified_name, + Naming.clineno_cname if include_cline else 0, + Naming.lineno_cname, + Naming.filename_cname, + ) + + self.funcstate.uses_error_indicator = True + self.putln('__Pyx_AddTraceback(%s, %s, %s, %s);' % format_tuple) + + def put_unraisable(self, qualified_name, nogil=False): + """ + Generate code to print a Python warning for an unraisable exception. + + qualified_name should be the qualified name of the function. + """ + format_tuple = ( + qualified_name, + Naming.clineno_cname, + Naming.lineno_cname, + Naming.filename_cname, + self.globalstate.directives['unraisable_tracebacks'], + nogil, + ) + self.funcstate.uses_error_indicator = True + self.putln('__Pyx_WriteUnraisable("%s", %s, %s, %s, %d, %d);' % format_tuple) + self.globalstate.use_utility_code( + UtilityCode.load_cached("WriteUnraisableException", "Exceptions.c")) + + def put_trace_declarations(self): + self.putln('__Pyx_TraceDeclarations') + + def put_trace_frame_init(self, codeobj=None): + if codeobj: + self.putln('__Pyx_TraceFrameInit(%s)' % codeobj) + + def put_trace_call(self, name, pos, nogil=False): + self.putln('__Pyx_TraceCall("%s", %s[%s], %s, %d, %s);' % ( + name, Naming.filetable_cname, self.lookup_filename(pos[0]), pos[1], nogil, self.error_goto(pos))) + + def put_trace_exception(self): + self.putln("__Pyx_TraceException();") + + def put_trace_return(self, retvalue_cname, nogil=False): + self.putln("__Pyx_TraceReturn(%s, %d);" % (retvalue_cname, nogil)) + + def putln_openmp(self, string): + self.putln("#ifdef _OPENMP") + self.putln(string) + self.putln("#endif /* _OPENMP */") + + def undef_builtin_expect(self, cond): + """ + Redefine the macros likely() and unlikely to no-ops, depending on + condition 'cond' + """ + self.putln("#if %s" % cond) + self.putln(" #undef likely") + self.putln(" #undef unlikely") + self.putln(" #define likely(x) (x)") + self.putln(" #define 
unlikely(x) (x)") + self.putln("#endif") + + def redef_builtin_expect(self, cond): + self.putln("#if %s" % cond) + self.putln(" #undef likely") + self.putln(" #undef unlikely") + self.putln(" #define likely(x) __builtin_expect(!!(x), 1)") + self.putln(" #define unlikely(x) __builtin_expect(!!(x), 0)") + self.putln("#endif") + + +class PyrexCodeWriter(object): + # f file output file + # level int indentation level + + def __init__(self, outfile_name): + self.f = Utils.open_new_file(outfile_name) + self.level = 0 + + def putln(self, code): + self.f.write("%s%s\n" % (" " * self.level, code)) + + def indent(self): + self.level += 1 + + def dedent(self): + self.level -= 1 + + +class PyxCodeWriter(object): + """ + Can be used for writing out some Cython code. + """ + + def __init__(self, buffer=None, indent_level=0, context=None, encoding='ascii'): + self.buffer = buffer or StringIOTree() + self.level = indent_level + self.original_level = indent_level + self.context = context + self.encoding = encoding + + def indent(self, levels=1): + self.level += levels + return True + + def dedent(self, levels=1): + self.level -= levels + + @contextmanager + def indenter(self, line): + """ + with pyx_code.indenter("for i in range(10):"): + pyx_code.putln("print i") + """ + self.putln(line) + self.indent() + yield + self.dedent() + + def empty(self): + return self.buffer.empty() + + def getvalue(self): + result = self.buffer.getvalue() + if isinstance(result, bytes): + result = result.decode(self.encoding) + return result + + def putln(self, line, context=None): + context = context or self.context + if context: + line = sub_tempita(line, context) + self._putln(line) + + def _putln(self, line): + self.buffer.write(u"%s%s\n" % (self.level * u" ", line)) + + def put_chunk(self, chunk, context=None): + context = context or self.context + if context: + chunk = sub_tempita(chunk, context) + + chunk = textwrap.dedent(chunk) + for line in chunk.splitlines(): + self._putln(line) + + def 
insertion_point(self): + return type(self)(self.buffer.insertion_point(), self.level, self.context) + + def reset(self): + # resets the buffer so that nothing gets written. Most useful + # for abandoning all work in a specific insertion point + self.buffer.reset() + self.level = self.original_level + + def named_insertion_point(self, name): + setattr(self, name, self.insertion_point()) + + +class ClosureTempAllocator(object): + def __init__(self, klass): + self.klass = klass + self.temps_allocated = {} + self.temps_free = {} + self.temps_count = 0 + + def reset(self): + for type, cnames in self.temps_allocated.items(): + self.temps_free[type] = list(cnames) + + def allocate_temp(self, type): + if type not in self.temps_allocated: + self.temps_allocated[type] = [] + self.temps_free[type] = [] + elif self.temps_free[type]: + return self.temps_free[type].pop(0) + cname = '%s%d' % (Naming.codewriter_temp_prefix, self.temps_count) + self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True) + self.temps_allocated[type].append(cname) + self.temps_count += 1 + return cname diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..b214764c9717e29f090654d5ad14412f9a7e8a6f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02498cb7e330a1a7ccae2b142938c3c3c01d80751d06f9ade63bac39f2ab681a +size 517064 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py new file mode 100644 index 
0000000000000000000000000000000000000000..25c0de92cf153896ae59af6cb79d920be98eba58 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py @@ -0,0 +1,4080 @@ +# cython: auto_cpdef=True, infer_types=True, language_level=3, py2_import=True +# +# Parser +# + +from __future__ import absolute_import + +# This should be done automatically +import cython +cython.declare(Nodes=object, ExprNodes=object, EncodedString=object, + bytes_literal=object, StringEncoding=object, + FileSourceDescriptor=object, lookup_unicodechar=object, unicode_category=object, + Future=object, Options=object, error=object, warning=object, + Builtin=object, ModuleNode=object, Utils=object, _unicode=object, _bytes=object, + re=object, sys=object, _parse_escape_sequences=object, _parse_escape_sequences_raw=object, + partial=object, reduce=object, _IS_PY3=cython.bint, _IS_2BYTE_UNICODE=cython.bint, + _CDEF_MODIFIERS=tuple, COMMON_BINOP_MISTAKES=dict) + +from io import StringIO +import re +import sys +from unicodedata import lookup as lookup_unicodechar, category as unicode_category +from functools import partial, reduce + +from .Scanning import PyrexScanner, FileSourceDescriptor, tentatively_scan +from . import Nodes +from . import ExprNodes +from . import Builtin +from . import StringEncoding +from .StringEncoding import EncodedString, bytes_literal, _unicode, _bytes +from .ModuleNode import ModuleNode +from .Errors import error, warning +from .. import Utils +from . import Future +from . 
import Options + +_IS_PY3 = sys.version_info[0] >= 3 +_IS_2BYTE_UNICODE = sys.maxunicode == 0xffff +_CDEF_MODIFIERS = ('inline', 'nogil', 'api') + + +class Ctx(object): + # Parsing context + level = 'other' + visibility = 'private' + cdef_flag = 0 + typedef_flag = 0 + api = 0 + overridable = 0 + nogil = 0 + namespace = None + templates = None + allow_struct_enum_decorator = False + + def __init__(self, **kwds): + self.__dict__.update(kwds) + + def __call__(self, **kwds): + ctx = Ctx() + d = ctx.__dict__ + d.update(self.__dict__) + d.update(kwds) + return ctx + + +def p_ident(s, message="Expected an identifier"): + if s.sy == 'IDENT': + name = s.context.intern_ustring(s.systring) + s.next() + return name + else: + s.error(message) + +def p_ident_list(s): + names = [] + while s.sy == 'IDENT': + names.append(s.context.intern_ustring(s.systring)) + s.next() + if s.sy != ',': + break + s.next() + return names + +#------------------------------------------ +# +# Expressions +# +#------------------------------------------ + +def p_binop_operator(s): + pos = s.position() + op = s.sy + s.next() + return op, pos + +def p_binop_expr(s, ops, p_sub_expr): + n1 = p_sub_expr(s) + while s.sy in ops: + op, pos = p_binop_operator(s) + n2 = p_sub_expr(s) + n1 = ExprNodes.binop_node(pos, op, n1, n2) + if op == '/': + if Future.division in s.context.future_directives: + n1.truedivision = True + else: + n1.truedivision = None # unknown + return n1 + +#lambdef: 'lambda' [varargslist] ':' test + +def p_lambdef(s): + # s.sy == 'lambda' + pos = s.position() + s.next() + if s.sy == ':': + args = [] + star_arg = starstar_arg = None + else: + args, star_arg, starstar_arg = p_varargslist( + s, terminator=':', annotated=False) + s.expect(':') + expr = p_test(s) + return ExprNodes.LambdaNode( + pos, args = args, + star_arg = star_arg, starstar_arg = starstar_arg, + result_expr = expr) + +#test: or_test ['if' or_test 'else' test] | lambdef + +def p_test(s): + # The check for a following ':=' is 
only for error reporting purposes. + # It simply changes a + # expected ')', found ':=' + # message into something a bit more descriptive. + # It is close to what the PEG parser does in CPython, where an expression has + # a lookahead assertion that it isn't followed by ':=' + expr = p_test_allow_walrus_after(s) + if s.sy == ':=': + s.error("invalid syntax: assignment expression not allowed in this context") + return expr + +def p_test_allow_walrus_after(s): + if s.sy == 'lambda': + return p_lambdef(s) + pos = s.position() + expr = p_or_test(s) + if s.sy == 'if': + s.next() + test = p_or_test(s) + s.expect('else') + other = p_test(s) + return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other) + else: + return expr + +def p_namedexpr_test(s): + # defined in the LL parser as + # namedexpr_test: test [':=' test] + # The requirement that the LHS is a name is not enforced in the grammar. + # For comparison the PEG parser does: + # 1. look for "name :=", if found it's definitely a named expression + # so look for expression + # 2. Otherwise, look for expression + lhs = p_test_allow_walrus_after(s) + if s.sy == ':=': + position = s.position() + if not lhs.is_name: + s.error("Left-hand side of assignment expression must be an identifier", fatal=False) + s.next() + rhs = p_test(s) + return ExprNodes.AssignmentExpressionNode(position, lhs=lhs, rhs=rhs) + return lhs + + +#or_test: and_test ('or' and_test)* + +COMMON_BINOP_MISTAKES = {'||': 'or', '&&': 'and'} + +def p_or_test(s): + return p_rassoc_binop_expr(s, u'or', p_and_test) + +def p_rassoc_binop_expr(s, op, p_subexpr): + n1 = p_subexpr(s) + if s.sy == op: + pos = s.position() + op = s.sy + s.next() + n2 = p_rassoc_binop_expr(s, op, p_subexpr) + n1 = ExprNodes.binop_node(pos, op, n1, n2) + elif s.sy in COMMON_BINOP_MISTAKES and COMMON_BINOP_MISTAKES[s.sy] == op: + # Only report this for the current operator since we pass through here twice for 'and' and 'or'. 
+ warning(s.position(), + "Found the C operator '%s', did you mean the Python operator '%s'?" % (s.sy, op), + level=1) + return n1 + +#and_test: not_test ('and' not_test)* + +def p_and_test(s): + #return p_binop_expr(s, ('and',), p_not_test) + return p_rassoc_binop_expr(s, u'and', p_not_test) + +#not_test: 'not' not_test | comparison + +def p_not_test(s): + if s.sy == 'not': + pos = s.position() + s.next() + return ExprNodes.NotNode(pos, operand = p_not_test(s)) + else: + return p_comparison(s) + +#comparison: expr (comp_op expr)* +#comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not' + +def p_comparison(s): + n1 = p_starred_expr(s) + if s.sy in comparison_ops: + pos = s.position() + op = p_cmp_op(s) + n2 = p_starred_expr(s) + n1 = ExprNodes.PrimaryCmpNode(pos, + operator = op, operand1 = n1, operand2 = n2) + if s.sy in comparison_ops: + n1.cascade = p_cascaded_cmp(s) + return n1 + +def p_test_or_starred_expr(s): + if s.sy == '*': + return p_starred_expr(s) + else: + return p_test(s) + +def p_namedexpr_test_or_starred_expr(s): + if s.sy == '*': + return p_starred_expr(s) + else: + return p_namedexpr_test(s) + +def p_starred_expr(s): + pos = s.position() + if s.sy == '*': + starred = True + s.next() + else: + starred = False + expr = p_bit_expr(s) + if starred: + expr = ExprNodes.StarredUnpackingNode(pos, expr) + return expr + +def p_cascaded_cmp(s): + pos = s.position() + op = p_cmp_op(s) + n2 = p_starred_expr(s) + result = ExprNodes.CascadedCmpNode(pos, + operator = op, operand2 = n2) + if s.sy in comparison_ops: + result.cascade = p_cascaded_cmp(s) + return result + +def p_cmp_op(s): + if s.sy == 'not': + s.next() + s.expect('in') + op = 'not_in' + elif s.sy == 'is': + s.next() + if s.sy == 'not': + s.next() + op = 'is_not' + else: + op = 'is' + else: + op = s.sy + s.next() + if op == '<>': + op = '!=' + return op + +comparison_ops = cython.declare(frozenset, frozenset(( + '<', '>', '==', '>=', '<=', '<>', '!=', + 'in', 'is', 'not' +))) + 
+#expr: xor_expr ('|' xor_expr)* + +def p_bit_expr(s): + return p_binop_expr(s, ('|',), p_xor_expr) + +#xor_expr: and_expr ('^' and_expr)* + +def p_xor_expr(s): + return p_binop_expr(s, ('^',), p_and_expr) + +#and_expr: shift_expr ('&' shift_expr)* + +def p_and_expr(s): + return p_binop_expr(s, ('&',), p_shift_expr) + +#shift_expr: arith_expr (('<<'|'>>') arith_expr)* + +def p_shift_expr(s): + return p_binop_expr(s, ('<<', '>>'), p_arith_expr) + +#arith_expr: term (('+'|'-') term)* + +def p_arith_expr(s): + return p_binop_expr(s, ('+', '-'), p_term) + +#term: factor (('*'|'@'|'/'|'%'|'//') factor)* + +def p_term(s): + return p_binop_expr(s, ('*', '@', '/', '%', '//'), p_factor) + +#factor: ('+'|'-'|'~'|'&'|typecast|sizeof) factor | power + +def p_factor(s): + # little indirection for C-ification purposes + return _p_factor(s) + +def _p_factor(s): + sy = s.sy + if sy in ('+', '-', '~'): + op = s.sy + pos = s.position() + s.next() + return ExprNodes.unop_node(pos, op, p_factor(s)) + elif not s.in_python_file: + if sy == '&': + pos = s.position() + s.next() + arg = p_factor(s) + return ExprNodes.AmpersandNode(pos, operand = arg) + elif sy == "<": + return p_typecast(s) + elif sy == 'IDENT' and s.systring == "sizeof": + return p_sizeof(s) + return p_power(s) + +def p_typecast(s): + # s.sy == "<" + pos = s.position() + s.next() + base_type = p_c_base_type(s) + is_memslice = isinstance(base_type, Nodes.MemoryViewSliceTypeNode) + is_other_unnamed_type = isinstance(base_type, ( + Nodes.TemplatedTypeNode, + Nodes.CConstOrVolatileTypeNode, + Nodes.CTupleBaseTypeNode, + )) + if not (is_memslice or is_other_unnamed_type) and base_type.name is None: + s.error("Unknown type") + declarator = p_c_declarator(s, empty = 1) + if s.sy == '?': + s.next() + typecheck = 1 + else: + typecheck = 0 + s.expect(">") + operand = p_factor(s) + if is_memslice: + return ExprNodes.CythonArrayNode(pos, base_type_node=base_type, operand=operand) + + return ExprNodes.TypecastNode(pos, + base_type = 
base_type, + declarator = declarator, + operand = operand, + typecheck = typecheck) + +def p_sizeof(s): + # s.sy == ident "sizeof" + pos = s.position() + s.next() + s.expect('(') + # Here we decide if we are looking at an expression or type + # If it is actually a type, but parsable as an expression, + # we treat it as an expression here. + if looking_at_expr(s): + operand = p_test(s) + node = ExprNodes.SizeofVarNode(pos, operand = operand) + else: + base_type = p_c_base_type(s) + declarator = p_c_declarator(s, empty = 1) + node = ExprNodes.SizeofTypeNode(pos, + base_type = base_type, declarator = declarator) + s.expect(')') + return node + + +def p_yield_expression(s): + # s.sy == "yield" + pos = s.position() + s.next() + is_yield_from = False + if s.sy == 'from': + is_yield_from = True + s.next() + if s.sy != ')' and s.sy not in statement_terminators: + # "yield from" does not support implicit tuples, but "yield" does ("yield 1,2") + arg = p_test(s) if is_yield_from else p_testlist(s) + else: + if is_yield_from: + s.error("'yield from' requires a source argument", + pos=pos, fatal=False) + arg = None + if is_yield_from: + return ExprNodes.YieldFromExprNode(pos, arg=arg) + else: + return ExprNodes.YieldExprNode(pos, arg=arg) + + +def p_yield_statement(s): + # s.sy == "yield" + yield_expr = p_yield_expression(s) + return Nodes.ExprStatNode(yield_expr.pos, expr=yield_expr) + + +def p_async_statement(s, ctx, decorators): + # s.sy >> 'async' ... 
+ if s.sy == 'def': + # 'async def' statements aren't allowed in pxd files + if 'pxd' in ctx.level: + s.error('def statement not allowed here') + s.level = ctx.level + return p_def_statement(s, decorators, is_async_def=True) + elif decorators: + s.error("Decorators can only be followed by functions or classes") + elif s.sy == 'for': + return p_for_statement(s, is_async=True) + elif s.sy == 'with': + s.next() + return p_with_items(s, is_async=True) + else: + s.error("expected one of 'def', 'for', 'with' after 'async'") + + +#power: atom_expr ('**' factor)* +#atom_expr: ['await'] atom trailer* + +def p_power(s): + if s.systring == 'new' and s.peek()[0] == 'IDENT': + return p_new_expr(s) + await_pos = None + if s.sy == 'await': + await_pos = s.position() + s.next() + n1 = p_atom(s) + while s.sy in ('(', '[', '.'): + n1 = p_trailer(s, n1) + if await_pos: + n1 = ExprNodes.AwaitExprNode(await_pos, arg=n1) + if s.sy == '**': + pos = s.position() + s.next() + n2 = p_factor(s) + n1 = ExprNodes.binop_node(pos, '**', n1, n2) + return n1 + + +def p_new_expr(s): + # s.systring == 'new'. + pos = s.position() + s.next() + cppclass = p_c_base_type(s) + return p_call(s, ExprNodes.NewExprNode(pos, cppclass = cppclass)) + +#trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME + +def p_trailer(s, node1): + pos = s.position() + if s.sy == '(': + return p_call(s, node1) + elif s.sy == '[': + return p_index(s, node1) + else: # s.sy == '.' 
+ s.next() + name = p_ident(s) + return ExprNodes.AttributeNode(pos, + obj=node1, attribute=name) + + +# arglist: argument (',' argument)* [','] +# argument: [test '='] test # Really [keyword '='] test + +# since PEP 448: +# argument: ( test [comp_for] | +# test '=' test | +# '**' expr | +# star_expr ) + +def p_call_parse_args(s, allow_genexp=True): + # s.sy == '(' + pos = s.position() + s.next() + positional_args = [] + keyword_args = [] + starstar_seen = False + last_was_tuple_unpack = False + while s.sy != ')': + if s.sy == '*': + if starstar_seen: + s.error("Non-keyword arg following keyword arg", pos=s.position()) + s.next() + positional_args.append(p_test(s)) + last_was_tuple_unpack = True + elif s.sy == '**': + s.next() + keyword_args.append(p_test(s)) + starstar_seen = True + else: + arg = p_namedexpr_test(s) + if s.sy == '=': + s.next() + if not arg.is_name: + s.error("Expected an identifier before '='", + pos=arg.pos) + encoded_name = s.context.intern_ustring(arg.name) + keyword = ExprNodes.IdentifierStringNode( + arg.pos, value=encoded_name) + arg = p_test(s) + keyword_args.append((keyword, arg)) + else: + if keyword_args: + s.error("Non-keyword arg following keyword arg", pos=arg.pos) + if positional_args and not last_was_tuple_unpack: + positional_args[-1].append(arg) + else: + positional_args.append([arg]) + last_was_tuple_unpack = False + if s.sy != ',': + break + s.next() + + if s.sy in ('for', 'async'): + if not keyword_args and not last_was_tuple_unpack: + if len(positional_args) == 1 and len(positional_args[0]) == 1: + positional_args = [[p_genexp(s, positional_args[0][0])]] + s.expect(')') + return positional_args or [[]], keyword_args + + +def p_call_build_packed_args(pos, positional_args, keyword_args): + keyword_dict = None + + subtuples = [ + ExprNodes.TupleNode(pos, args=arg) if isinstance(arg, list) else ExprNodes.AsTupleNode(pos, arg=arg) + for arg in positional_args + ] + # TODO: implement a faster way to join tuples than creating each 
one and adding them + arg_tuple = reduce(partial(ExprNodes.binop_node, pos, '+'), subtuples) + + if keyword_args: + kwargs = [] + dict_items = [] + for item in keyword_args: + if isinstance(item, tuple): + key, value = item + dict_items.append(ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)) + elif item.is_dict_literal: + # unpack "**{a:b}" directly + dict_items.extend(item.key_value_pairs) + else: + if dict_items: + kwargs.append(ExprNodes.DictNode( + dict_items[0].pos, key_value_pairs=dict_items, reject_duplicates=True)) + dict_items = [] + kwargs.append(item) + + if dict_items: + kwargs.append(ExprNodes.DictNode( + dict_items[0].pos, key_value_pairs=dict_items, reject_duplicates=True)) + + if kwargs: + if len(kwargs) == 1 and kwargs[0].is_dict_literal: + # only simple keyword arguments found -> one dict + keyword_dict = kwargs[0] + else: + # at least one **kwargs + keyword_dict = ExprNodes.MergedDictNode(pos, keyword_args=kwargs) + + return arg_tuple, keyword_dict + + +def p_call(s, function): + # s.sy == '(' + pos = s.position() + positional_args, keyword_args = p_call_parse_args(s) + + if not keyword_args and len(positional_args) == 1 and isinstance(positional_args[0], list): + return ExprNodes.SimpleCallNode(pos, function=function, args=positional_args[0]) + else: + arg_tuple, keyword_dict = p_call_build_packed_args(pos, positional_args, keyword_args) + return ExprNodes.GeneralCallNode( + pos, function=function, positional_args=arg_tuple, keyword_args=keyword_dict) + + +#lambdef: 'lambda' [varargslist] ':' test + +#subscriptlist: subscript (',' subscript)* [','] + +def p_index(s, base): + # s.sy == '[' + pos = s.position() + s.next() + subscripts, is_single_value = p_subscript_list(s) + if is_single_value and len(subscripts[0]) == 2: + start, stop = subscripts[0] + result = ExprNodes.SliceIndexNode(pos, + base = base, start = start, stop = stop) + else: + indexes = make_slice_nodes(pos, subscripts) + if is_single_value: + index = indexes[0] + else: 
+ index = ExprNodes.TupleNode(pos, args = indexes) + result = ExprNodes.IndexNode(pos, + base = base, index = index) + s.expect(']') + return result + +def p_subscript_list(s): + is_single_value = True + items = [p_subscript(s)] + while s.sy == ',': + is_single_value = False + s.next() + if s.sy == ']': + break + items.append(p_subscript(s)) + return items, is_single_value + +#subscript: '.' '.' '.' | test | [test] ':' [test] [':' [test]] + +def p_subscript(s): + # Parse a subscript and return a list of + # 1, 2 or 3 ExprNodes, depending on how + # many slice elements were encountered. + pos = s.position() + start = p_slice_element(s, (':',)) + if s.sy != ':': + return [start] + s.next() + stop = p_slice_element(s, (':', ',', ']')) + if s.sy != ':': + return [start, stop] + s.next() + step = p_slice_element(s, (':', ',', ']')) + return [start, stop, step] + +def p_slice_element(s, follow_set): + # Simple expression which may be missing iff + # it is followed by something in follow_set. + if s.sy not in follow_set: + return p_test(s) + else: + return None + +def expect_ellipsis(s): + s.expect('...') + +def make_slice_nodes(pos, subscripts): + # Convert a list of subscripts as returned + # by p_subscript_list into a list of ExprNodes, + # creating SliceNodes for elements with 2 or + # more components. 
+ result = [] + for subscript in subscripts: + if len(subscript) == 1: + result.append(subscript[0]) + else: + result.append(make_slice_node(pos, *subscript)) + return result + +def make_slice_node(pos, start, stop = None, step = None): + if not start: + start = ExprNodes.NoneNode(pos) + if not stop: + stop = ExprNodes.NoneNode(pos) + if not step: + step = ExprNodes.NoneNode(pos) + return ExprNodes.SliceNode(pos, + start = start, stop = stop, step = step) + +#atom: '(' [yield_expr|testlist_comp] ')' | '[' [listmaker] ']' | '{' [dict_or_set_maker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+ + +def p_atom(s): + pos = s.position() + sy = s.sy + if sy == '(': + s.next() + if s.sy == ')': + result = ExprNodes.TupleNode(pos, args = []) + elif s.sy == 'yield': + result = p_yield_expression(s) + else: + result = p_testlist_comp(s) + s.expect(')') + return result + elif sy == '[': + return p_list_maker(s) + elif sy == '{': + return p_dict_or_set_maker(s) + elif sy == '`': + return p_backquote_expr(s) + elif sy == '...': + expect_ellipsis(s) + return ExprNodes.EllipsisNode(pos) + elif sy == 'INT': + return p_int_literal(s) + elif sy == 'FLOAT': + value = s.systring + s.next() + return ExprNodes.FloatNode(pos, value = value) + elif sy == 'IMAG': + value = s.systring[:-1] + s.next() + return ExprNodes.ImagNode(pos, value = value) + elif sy == 'BEGIN_STRING': + kind, bytes_value, unicode_value = p_cat_string_literal(s) + if kind == 'c': + return ExprNodes.CharNode(pos, value = bytes_value) + elif kind == 'u': + return ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value) + elif kind == 'b': + return ExprNodes.BytesNode(pos, value = bytes_value) + elif kind == 'f': + return ExprNodes.JoinedStrNode(pos, values = unicode_value) + elif kind == '': + return ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value) + else: + s.error("invalid string kind '%s'" % kind) + elif sy == 'IDENT': + name = s.systring + if name == "None": + 
result = ExprNodes.NoneNode(pos) + elif name == "True": + result = ExprNodes.BoolNode(pos, value=True) + elif name == "False": + result = ExprNodes.BoolNode(pos, value=False) + elif name == "NULL" and not s.in_python_file: + result = ExprNodes.NullNode(pos) + else: + result = p_name(s, name) + s.next() + return result + else: + s.error("Expected an identifier or literal") + +def p_int_literal(s): + pos = s.position() + value = s.systring + s.next() + unsigned = "" + longness = "" + while value[-1] in u"UuLl": + if value[-1] in u"Ll": + longness += "L" + else: + unsigned += "U" + value = value[:-1] + # '3L' is ambiguous in Py2 but not in Py3. '3U' and '3LL' are + # illegal in Py2 Python files. All suffixes are illegal in Py3 + # Python files. + is_c_literal = None + if unsigned: + is_c_literal = True + elif longness: + if longness == 'LL' or s.context.language_level >= 3: + is_c_literal = True + if s.in_python_file: + if is_c_literal: + error(pos, "illegal integer literal syntax in Python source file") + is_c_literal = False + return ExprNodes.IntNode(pos, + is_c_literal = is_c_literal, + value = value, + unsigned = unsigned, + longness = longness) + + +def p_name(s, name): + pos = s.position() + if not s.compile_time_expr and name in s.compile_time_env: + value = s.compile_time_env.lookup_here(name) + node = wrap_compile_time_constant(pos, value) + if node is not None: + return node + return ExprNodes.NameNode(pos, name=name) + + +def wrap_compile_time_constant(pos, value): + rep = repr(value) + if value is None: + return ExprNodes.NoneNode(pos) + elif value is Ellipsis: + return ExprNodes.EllipsisNode(pos) + elif isinstance(value, bool): + return ExprNodes.BoolNode(pos, value=value) + elif isinstance(value, int): + return ExprNodes.IntNode(pos, value=rep, constant_result=value) + elif isinstance(value, float): + return ExprNodes.FloatNode(pos, value=rep, constant_result=value) + elif isinstance(value, complex): + node = ExprNodes.ImagNode(pos, 
value=repr(value.imag), constant_result=complex(0.0, value.imag)) + if value.real: + # FIXME: should we care about -0.0 ? + # probably not worth using the '-' operator for negative imag values + node = ExprNodes.binop_node( + pos, '+', ExprNodes.FloatNode(pos, value=repr(value.real), constant_result=value.real), node, + constant_result=value) + return node + elif isinstance(value, _unicode): + return ExprNodes.UnicodeNode(pos, value=EncodedString(value)) + elif isinstance(value, _bytes): + bvalue = bytes_literal(value, 'ascii') # actually: unknown encoding, but BytesLiteral requires one + return ExprNodes.BytesNode(pos, value=bvalue, constant_result=value) + elif isinstance(value, tuple): + args = [wrap_compile_time_constant(pos, arg) + for arg in value] + if None not in args: + return ExprNodes.TupleNode(pos, args=args) + else: + # error already reported + return None + elif not _IS_PY3 and isinstance(value, long): + return ExprNodes.IntNode(pos, value=rep.rstrip('L'), constant_result=value) + error(pos, "Invalid type for compile-time constant: %r (type %s)" + % (value, value.__class__.__name__)) + return None + + +def p_cat_string_literal(s): + # A sequence of one or more adjacent string literals. 
+ # Returns (kind, bytes_value, unicode_value) + # where kind in ('b', 'c', 'u', 'f', '') + pos = s.position() + kind, bytes_value, unicode_value = p_string_literal(s) + if kind == 'c' or s.sy != 'BEGIN_STRING': + return kind, bytes_value, unicode_value + bstrings, ustrings, positions = [bytes_value], [unicode_value], [pos] + bytes_value = unicode_value = None + while s.sy == 'BEGIN_STRING': + pos = s.position() + next_kind, next_bytes_value, next_unicode_value = p_string_literal(s) + if next_kind == 'c': + error(pos, "Cannot concatenate char literal with another string or char literal") + continue + elif next_kind != kind: + # concatenating f strings and normal strings is allowed and leads to an f string + if {kind, next_kind} in ({'f', 'u'}, {'f', ''}): + kind = 'f' + else: + error(pos, "Cannot mix string literals of different types, expected %s'', got %s''" % ( + kind, next_kind)) + continue + bstrings.append(next_bytes_value) + ustrings.append(next_unicode_value) + positions.append(pos) + # join and rewrap the partial literals + if kind in ('b', 'c', '') or kind == 'u' and None not in bstrings: + # Py3 enforced unicode literals are parsed as bytes/unicode combination + bytes_value = bytes_literal(StringEncoding.join_bytes(bstrings), s.source_encoding) + if kind in ('u', ''): + unicode_value = EncodedString(u''.join([u for u in ustrings if u is not None])) + if kind == 'f': + unicode_value = [] + for u, pos in zip(ustrings, positions): + if isinstance(u, list): + unicode_value += u + else: + # non-f-string concatenated into the f-string + unicode_value.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u))) + return kind, bytes_value, unicode_value + + +def p_opt_string_literal(s, required_type='u'): + if s.sy != 'BEGIN_STRING': + return None + pos = s.position() + kind, bytes_value, unicode_value = p_string_literal(s, required_type) + if required_type == 'u': + if kind == 'f': + s.error("f-string not allowed here", pos) + return unicode_value + elif 
def check_for_non_ascii_characters(string):
    # Returns True if any character of 'string' lies outside the ASCII range
    # (i.e. compares >= U+0080), and False otherwise (also for an empty string).
    return any(char >= u'\x80' for char in string)
s.context.future_directives: + chars = StringEncoding.StrLiteralBuilder(s.source_encoding) + kind = 'u' + else: + if kind_override is not None and kind_override in 'ub': + kind = kind_override + if kind in ('u', 'f'): # f-strings are scanned exactly like Unicode literals, but are parsed further later + chars = StringEncoding.UnicodeLiteralBuilder() + elif kind == '': + chars = StringEncoding.StrLiteralBuilder(s.source_encoding) + else: + chars = StringEncoding.BytesLiteralBuilder(s.source_encoding) + + while 1: + s.next() + sy = s.sy + systr = s.systring + # print "p_string_literal: sy =", sy, repr(s.systring) ### + if sy == 'CHARS': + chars.append(systr) + if is_python3_source and not has_non_ascii_literal_characters and check_for_non_ascii_characters(systr): + has_non_ascii_literal_characters = True + elif sy == 'ESCAPE': + # in Py2, 'ur' raw unicode strings resolve unicode escapes but nothing else + if is_raw and (is_python3_source or kind != 'u' or systr[1] not in u'Uu'): + chars.append(systr) + if is_python3_source and not has_non_ascii_literal_characters and check_for_non_ascii_characters(systr): + has_non_ascii_literal_characters = True + else: + _append_escape_sequence(kind, chars, systr, s) + elif sy == 'NEWLINE': + chars.append(u'\n') + elif sy == 'END_STRING': + break + elif sy == 'EOF': + s.error("Unclosed string literal", pos=pos) + else: + s.error("Unexpected token %r:%r in string literal" % ( + sy, s.systring)) + + if kind == 'c': + unicode_value = None + bytes_value = chars.getchar() + if len(bytes_value) != 1: + error(pos, u"invalid character literal: %r" % bytes_value) + else: + bytes_value, unicode_value = chars.getstrings() + if (has_non_ascii_literal_characters + and is_python3_source and Future.unicode_literals in s.context.future_directives): + # Python 3 forbids literal non-ASCII characters in byte strings + if kind == 'b': + s.error("bytes can only contain ASCII literal characters.", pos=pos) + bytes_value = None + if kind == 'f': + 
def _append_escape_sequence(kind, builder, escape_sequence, s):
    # Resolves a single backslash escape sequence and appends the result to 'builder'.
    #
    # kind:            string kind of the literal being built ('b', 'c', 'u', 'f' or '').
    # builder:         literal builder; its append(), append_charval() and
    #                  append_uescape() methods are used here.
    # escape_sequence: the complete escape text, starting with the backslash.
    # s:               the scanner, only used for reporting errors.
    #
    # Unicode escapes (\N{...}, \uXXXX, \UXXXXXXXX) are only resolved for kinds
    # 'u', 'f' and ''.  Any escape that is not recognised is appended verbatim.
    c = escape_sequence[1]
    if c in u"01234567":
        # octal escape
        builder.append_charval(int(escape_sequence[1:], 8))
    elif c in u"'\"\\":
        builder.append(c)
    elif c in u"abfnrtv":
        builder.append(StringEncoding.char_from_escape_sequence(escape_sequence))
    elif c == u'\n':
        pass  # line continuation
    elif c == u'x':  # \xXX
        if len(escape_sequence) == 4:
            builder.append_charval(int(escape_sequence[2:], 16))
        else:
            s.error("Invalid hex escape '%s'" % escape_sequence, fatal=False)
    elif c in u'NUu' and kind in ('u', 'f', ''):  # \uxxxx, \Uxxxxxxxx, \N{...}
        chrval = -1  # stays negative if the escape turns out to be invalid
        if c == u'N':
            uchar = None
            try:
                uchar = lookup_unicodechar(escape_sequence[3:-1])
                chrval = ord(uchar)
            except KeyError:
                s.error("Unknown Unicode character name %s" %
                        repr(escape_sequence[3:-1]).lstrip('u'), fatal=False)
            except TypeError:
                # 2-byte unicode build of CPython?
                if (uchar is not None and _IS_2BYTE_UNICODE and len(uchar) == 2 and
                        unicode_category(uchar[0]) == 'Cs' and unicode_category(uchar[1]) == 'Cs'):
                    # Surrogate pair instead of single character: recombine it.
                    # The high surrogate carries the upper 10 bits and must be
                    # shifted left; the parentheses are required because shift
                    # operators bind less tightly than '+'.
                    chrval = 0x10000 + ((ord(uchar[0]) - 0xd800) << 10) + (ord(uchar[1]) - 0xdc00)
                else:
                    raise
        elif len(escape_sequence) in (6, 10):
            chrval = int(escape_sequence[2:], 16)
            if chrval > 1114111:  # sys.maxunicode:
                s.error("Invalid unicode escape '%s'" % escape_sequence)
                chrval = -1
        else:
            s.error("Invalid unicode escape '%s'" % escape_sequence, fatal=False)
        if chrval >= 0:
            builder.append_uescape(chrval, escape_sequence)
    else:
        # unknown escape (or a Unicode escape in a bytes literal): keep as is
        builder.append(escape_sequence)
def p_f_string(s, unicode_value, pos, is_raw):
    # Parses a PEP 498 f-string literal into a list of nodes. Nodes are either UnicodeNodes
    # (literal text) or the nodes returned by p_f_string_expr() for each {...} expression.
    #
    # 'unicode_value' is the text of the f-string, 'pos' the position used for
    # the created nodes and error reports, and 'is_raw' selects the matcher
    # that leaves backslash sequences unresolved.
    nodes = []
    literal = StringEncoding.UnicodeLiteralBuilder()
    match_part = _parse_escape_sequences_raw if is_raw else _parse_escape_sequences
    length = len(unicode_value)
    index = 0

    while index < length:
        part_start = index
        match = match_part(unicode_value, index)
        if match is None:
            error(_f_string_error_pos(pos, unicode_value, index), "Invalid escape sequence")

        index = match.end()
        part = match.group()
        first = part[0]
        if first == '{' and part != '{{':
            # start of an expression: flush the pending literal text first
            if literal.chars:
                nodes.append(ExprNodes.UnicodeNode(pos, value=literal.getstring()))
                literal = StringEncoding.UnicodeLiteralBuilder()
            index, expr_nodes = p_f_string_expr(s, unicode_value, pos, index, is_raw)
            nodes.extend(expr_nodes)
        elif first == '}' and part != '}}':
            error(_f_string_error_pos(pos, unicode_value, part_start),
                  "f-string: single '}' is not allowed")
        elif first == '\\' and not is_raw and len(part) > 1:
            _append_escape_sequence('f', literal, part, s)
        elif part in ('{{', '}}'):
            # doubled brace: emit a single literal brace
            literal.append(first)
        else:
            # plain text, or a backslash sequence that is kept verbatim
            literal.append(part)

    if literal.chars:
        nodes.append(ExprNodes.UnicodeNode(pos, value=literal.getstring()))
    return nodes
expression inside an f-string. Returns a list of nodes + # [UnicodeNode?, FormattedValueNode] and the index in the string that follows + # the expression. + # + # ? = Optional + i = starting_index + size = len(unicode_value) + conversion_char = terminal_char = format_spec = None + format_spec_str = None + expr_text = None + NO_CHAR = 2**30 + + nested_depth = 0 + quote_char = NO_CHAR + in_triple_quotes = False + backslash_reported = False + + while True: + if i >= size: + break # error will be reported below + c = unicode_value[i] + + if quote_char != NO_CHAR: + if c == '\\': + # avoid redundant error reports along '\' sequences + if not backslash_reported: + error(_f_string_error_pos(pos, unicode_value, i), + "backslashes not allowed in f-strings") + backslash_reported = True + elif c == quote_char: + if in_triple_quotes: + if i + 2 < size and unicode_value[i + 1] == c and unicode_value[i + 2] == c: + in_triple_quotes = False + quote_char = NO_CHAR + i += 2 + else: + quote_char = NO_CHAR + elif c in '\'"': + quote_char = c + if i + 2 < size and unicode_value[i + 1] == c and unicode_value[i + 2] == c: + in_triple_quotes = True + i += 2 + elif c in '{[(': + nested_depth += 1 + elif nested_depth != 0 and c in '}])': + nested_depth -= 1 + elif c == '#': + error(_f_string_error_pos(pos, unicode_value, i), + "format string cannot include #") + elif nested_depth == 0 and c in '><=!:}': + # allow special cases with '!' and '=' + if i + 1 < size and c in '!=><': + if unicode_value[i + 1] == '=': + i += 2 # we checked 2, so we can skip 2: '!=', '==', '>=', '<=' + continue + elif c in '><': # allow single '<' and '>' + i += 1 + continue + terminal_char = c + break + i += 1 + + # normalise line endings as the parser expects that + expr_str = unicode_value[starting_index:i].replace('\r\n', '\n').replace('\r', '\n') + expr_pos = (pos[0], pos[1], pos[2] + starting_index + 2) # TODO: find exact code position (concat, multi-line, ...) 
+ + if not expr_str.strip(): + error(_f_string_error_pos(pos, unicode_value, starting_index), + "empty expression not allowed in f-string") + + if terminal_char == '=': + i += 1 + while i < size and unicode_value[i].isspace(): + i += 1 + + if i < size: + terminal_char = unicode_value[i] + expr_text = unicode_value[starting_index:i] + # otherwise: error will be reported below + + if terminal_char == '!': + i += 1 + if i + 2 > size: + pass # error will be reported below + else: + conversion_char = unicode_value[i] + i += 1 + terminal_char = unicode_value[i] + + if terminal_char == ':': + in_triple_quotes = False + in_string = False + nested_depth = 0 + start_format_spec = i + 1 + while True: + if i >= size: + break # error will be reported below + c = unicode_value[i] + if not in_triple_quotes and not in_string: + if c == '{': + nested_depth += 1 + elif c == '}': + if nested_depth > 0: + nested_depth -= 1 + else: + terminal_char = c + break + if c in '\'"': + if not in_string and i + 2 < size and unicode_value[i + 1] == c and unicode_value[i + 2] == c: + in_triple_quotes = not in_triple_quotes + i += 2 + elif not in_triple_quotes: + in_string = not in_string + i += 1 + + format_spec_str = unicode_value[start_format_spec:i] + + if expr_text and conversion_char is None and format_spec_str is None: + conversion_char = 'r' + + if terminal_char != '}': + error(_f_string_error_pos(pos, unicode_value, i), + "missing '}' in format string expression" + ( + ", found '%s'" % terminal_char if terminal_char else "")) + + # parse the expression as if it was surrounded by parentheses + buf = StringIO('(%s)' % expr_str) + scanner = PyrexScanner(buf, expr_pos[0], parent_scanner=s, source_encoding=s.source_encoding, initial_pos=expr_pos) + expr = p_testlist(scanner) # TODO is testlist right here? 
def p_list_maker(s):
    # s.sy == '['
    # Parses a list display: an empty list, a list comprehension or a
    # (possibly merged) list literal, and consumes the closing ']'.
    start_pos = s.position()
    s.next()
    if s.sy == ']':
        s.expect(']')
        return ExprNodes.ListNode(start_pos, args=[])

    first_item = p_namedexpr_test_or_starred_expr(s)
    if s.sy not in ('for', 'async'):
        # (merged) list literal
        if s.sy == ',':
            s.next()
            items = p_namedexpr_test_or_starred_expr_list(s, first_item)
        else:
            items = [first_item]
        s.expect(']')
        return ExprNodes.ListNode(start_pos, args=items)

    # list comprehension
    if first_item.is_starred:
        s.error("iterable unpacking cannot be used in comprehension")
    append_node = ExprNodes.ComprehensionAppendNode(start_pos, expr=first_item)
    loop_node = p_comp_for(s, append_node)
    s.expect(']')
    # list comprehensions leak their loop variable in Py2
    has_local_scope = s.context.language_level >= 3
    return ExprNodes.ComprehensionNode(
        start_pos, loop=loop_node, append=append_node, type=Builtin.list_type,
        has_local_scope=has_local_scope)
def p_comp_if(s, body):
    # s.sy == 'if'
    # Parses the 'if' clause of a comprehension and wraps the remaining
    # comprehension (as built by p_comp_iter) into an IfStatNode without 'else'.
    if_pos = s.position()
    s.next()
    # Note that Python 3.9+ is actually more restrictive here and Cython now follows
    # the Python 3.9+ behaviour: https://github.com/python/cpython/issues/86014
    # On Python <3.9 `[i for i in range(10) if lambda: i if True else 1]` was disallowed
    # but `[i for i in range(10) if lambda: i]` was allowed.
    # On Python >=3.9 they're both disallowed.
    condition = p_or_test(s)
    inner_body = p_comp_iter(s, body)
    clause = Nodes.IfClauseNode(if_pos, condition=condition, body=inner_body)
    return Nodes.IfStatNode(if_pos, if_clauses=[clause], else_clause=None)
set + if target_type == 2: + # dict literal + s.expect(':') + key = item + value = p_test(s) + item = ExprNodes.DictItemNode(key.pos, key=key, value=value) + if last_was_simple_item: + parts[-1].append(item) + else: + parts.append([item]) + last_was_simple_item = True + + if s.sy == ',': + s.next() + if s.sy == '}': + break + else: + break + + if s.sy in ('for', 'async'): + # dict/set comprehension + if len(parts) == 1 and isinstance(parts[0], list) and len(parts[0]) == 1: + item = parts[0][0] + if target_type == 2: + assert isinstance(item, ExprNodes.DictItemNode), type(item) + comprehension_type = Builtin.dict_type + append = ExprNodes.DictComprehensionAppendNode( + item.pos, key_expr=item.key, value_expr=item.value) + else: + comprehension_type = Builtin.set_type + append = ExprNodes.ComprehensionAppendNode(item.pos, expr=item) + loop = p_comp_for(s, append) + s.expect('}') + return ExprNodes.ComprehensionNode(pos, loop=loop, append=append, type=comprehension_type) + else: + # syntax error, try to find a good error message + if len(parts) == 1 and not isinstance(parts[0], list): + s.error("iterable unpacking cannot be used in comprehension") + else: + # e.g. "{1,2,3 for ..." 
def p_simple_expr_list(s, expr=None):
    # Parses a comma separated sequence of tests up to the next expression
    # terminator and returns them as a list.  If 'expr' is given, it becomes
    # the first item of the returned list.
    items = [] if expr is None else [expr]
    while s.sy not in expr_terminators:
        items.append(p_test(s))
        if s.sy == ',':
            s.next()
        else:
            break
    return items
def p_testlist_comp(s):
    # Parses either a single expression, a comma separated tuple of
    # expressions, or a generator expression ("expr for ...").
    start_pos = s.position()
    first = p_namedexpr_test_or_starred_expr(s)
    if s.sy in ('for', 'async'):
        return p_genexp(s, first)
    if s.sy != ',':
        return first
    # tuple: the first item was already parsed and is passed along
    s.next()
    items = p_namedexpr_test_or_starred_expr_list(s, first)
    return ExprNodes.TupleNode(start_pos, args=items)
(expr.is_name or expr.is_subscript or expr.is_attribute): + has_annotation = True + s.next() + expr.annotation = p_annotation(s) + + if s.sy == '=' and expr.is_starred: + # This is a common enough error to make when learning Cython to let + # it fail as early as possible and give a very clear error message. + s.error("a starred assignment target must be in a list or tuple" + " - maybe you meant to use an index assignment: var[0] = ...", + pos=expr.pos) + + expr_list = [expr] + while s.sy == '=': + s.next() + if s.sy == 'yield': + expr = p_yield_expression(s) + else: + expr = p_testlist_star_expr(s) + expr_list.append(expr) + if len(expr_list) == 1: + if re.match(r"([-+*/%^&|]|<<|>>|\*\*|//|@)=", s.sy): + lhs = expr_list[0] + if isinstance(lhs, ExprNodes.SliceIndexNode): + # implementation requires IndexNode + lhs = ExprNodes.IndexNode( + lhs.pos, + base=lhs.base, + index=make_slice_node(lhs.pos, lhs.start, lhs.stop)) + elif not isinstance(lhs, (ExprNodes.AttributeNode, ExprNodes.IndexNode, ExprNodes.NameNode)): + error(lhs.pos, "Illegal operand for inplace operation.") + operator = s.sy[:-1] + s.next() + if s.sy == 'yield': + rhs = p_yield_expression(s) + else: + rhs = p_testlist(s) + return Nodes.InPlaceAssignmentNode(lhs.pos, operator=operator, lhs=lhs, rhs=rhs) + expr = expr_list[0] + return Nodes.ExprStatNode(expr.pos, expr=expr) + + rhs = expr_list[-1] + if len(expr_list) == 2: + return Nodes.SingleAssignmentNode(rhs.pos, lhs=expr_list[0], rhs=rhs, first=has_annotation) + else: + return Nodes.CascadedAssignmentNode(rhs.pos, lhs_list=expr_list[:-1], rhs=rhs) + + +def p_print_statement(s): + # s.sy == 'print' + pos = s.position() + ends_with_comma = 0 + s.next() + if s.sy == '>>': + s.next() + stream = p_test(s) + if s.sy == ',': + s.next() + ends_with_comma = s.sy in ('NEWLINE', 'EOF') + else: + stream = None + args = [] + if s.sy not in ('NEWLINE', 'EOF'): + args.append(p_test(s)) + while s.sy == ',': + s.next() + if s.sy in ('NEWLINE', 'EOF'): + 
def p_exec_statement(s):
    # s.sy == 'exec'
    # Parses "exec code [in globals [, locals]]" as well as the tuple form
    # "exec(code, globals[, locals])" and returns an ExecStatNode.
    pos = s.position()
    s.next()
    code = p_bit_expr(s)

    # Py3 compatibility syntax: the arguments are passed as a 2- or 3-tuple
    tuple_variant = isinstance(code, ExprNodes.TupleNode)
    args = [code]
    if tuple_variant:
        if len(code.args) in (2, 3):
            args = code.args
        else:
            s.error("expected tuple of length 2 or 3, got length %d" % len(code.args),
                    pos=pos, fatal=False)

    if s.sy == 'in':
        if tuple_variant:
            s.error("tuple variant of exec does not support additional 'in' arguments",
                    fatal=False)
        s.next()
        args.append(p_test(s))
        if s.sy == ',':
            s.next()
            args.append(p_test(s))
    return Nodes.ExecStatNode(pos, args=args)
def p_import_statement(s):
    # s.sy in ('import', 'cimport')
    # Parses "import a.b [as c], ..." / "cimport a.b [as c], ..." and returns
    # a StatListNode with one statement per imported module.
    stat_list_pos = s.position()
    kind = s.sy
    s.next()
    items = [p_dotted_name(s, as_allowed=1)]
    while s.sy == ',':
        s.next()
        items.append(p_dotted_name(s, as_allowed=1))

    is_cimport = kind == 'cimport'
    is_absolute = Future.absolute_import in s.context.future_directives
    stats = []
    for item_pos, target_name, dotted_name, as_name in items:
        if is_cimport:
            stat = Nodes.CImportStatNode(
                item_pos,
                module_name=dotted_name,
                as_name=as_name,
                is_absolute=is_absolute)
        else:
            # a plain import is an assignment of the imported module to a name
            target = ExprNodes.NameNode(item_pos, name=as_name or target_name)
            module_name = ExprNodes.IdentifierStringNode(item_pos, value=dotted_name)
            import_node = ExprNodes.ImportNode(
                item_pos,
                module_name=module_name,
                level=0 if is_absolute else None,
                get_top_level_module='.' in dotted_name and as_name is None,
                name_list=None)
            stat = Nodes.SingleAssignmentNode(item_pos, lhs=target, rhs=import_node)
        stats.append(stat)
        # the statement list is positioned at the last imported name
        stat_list_pos = item_pos
    return Nodes.StatListNode(stat_list_pos, stats=stats)
import foo, bar" + dotted_name_pos, dotted_name = s.position(), s.context.intern_ustring('') + else: + if level is None and Future.absolute_import in s.context.future_directives: + level = 0 + (dotted_name_pos, _, dotted_name, _) = p_dotted_name(s, as_allowed=False) + if s.sy not in ('import', 'cimport'): + s.error("Expected 'import' or 'cimport'") + kind = s.sy + s.next() + + is_cimport = kind == 'cimport' + is_parenthesized = False + if s.sy == '*': + imported_names = [(s.position(), s.context.intern_ustring("*"), None)] + s.next() + else: + if s.sy == '(': + is_parenthesized = True + s.next() + imported_names = [p_imported_name(s)] + while s.sy == ',': + s.next() + if is_parenthesized and s.sy == ')': + break + imported_names.append(p_imported_name(s)) + if is_parenthesized: + s.expect(')') + if dotted_name == '__future__': + if not first_statement: + s.error("from __future__ imports must occur at the beginning of the file") + elif level: + s.error("invalid syntax") + else: + for (name_pos, name, as_name) in imported_names: + if name == "braces": + s.error("not a chance", name_pos) + break + try: + directive = getattr(Future, name) + except AttributeError: + s.error("future feature %s is not defined" % name, name_pos) + break + s.context.future_directives.add(directive) + return Nodes.PassStatNode(pos) + elif is_cimport: + return Nodes.FromCImportStatNode( + pos, module_name=dotted_name, + relative_level=level, + imported_names=imported_names) + else: + imported_name_strings = [] + items = [] + for (name_pos, name, as_name) in imported_names: + imported_name_strings.append( + ExprNodes.IdentifierStringNode(name_pos, value=name)) + items.append( + (name, ExprNodes.NameNode(name_pos, name=as_name or name))) + import_list = ExprNodes.ListNode( + imported_names[0][0], args=imported_name_strings) + return Nodes.FromImportStatNode(pos, + module = ExprNodes.ImportNode(dotted_name_pos, + module_name = ExprNodes.IdentifierStringNode(pos, value = dotted_name), + level = 
level, + name_list = import_list), + items = items) + + +def p_imported_name(s): + pos = s.position() + name = p_ident(s) + as_name = p_as_name(s) + return (pos, name, as_name) + + +def p_dotted_name(s, as_allowed): + pos = s.position() + target_name = p_ident(s) + as_name = None + names = [target_name] + while s.sy == '.': + s.next() + names.append(p_ident(s)) + if as_allowed: + as_name = p_as_name(s) + return (pos, target_name, s.context.intern_ustring(u'.'.join(names)), as_name) + + +def p_as_name(s): + if s.sy == 'IDENT' and s.systring == 'as': + s.next() + return p_ident(s) + else: + return None + + +def p_assert_statement(s): + # s.sy == 'assert' + pos = s.position() + s.next() + cond = p_test(s) + if s.sy == ',': + s.next() + value = p_test(s) + else: + value = None + return Nodes.AssertStatNode(pos, condition=cond, value=value) + + +statement_terminators = cython.declare(frozenset, frozenset(( + ';', 'NEWLINE', 'EOF'))) + +def p_if_statement(s): + # s.sy == 'if' + pos = s.position() + s.next() + if_clauses = [p_if_clause(s)] + while s.sy == 'elif': + s.next() + if_clauses.append(p_if_clause(s)) + else_clause = p_else_clause(s) + return Nodes.IfStatNode(pos, + if_clauses = if_clauses, else_clause = else_clause) + +def p_if_clause(s): + pos = s.position() + test = p_namedexpr_test(s) + body = p_suite(s) + return Nodes.IfClauseNode(pos, + condition = test, body = body) + +def p_else_clause(s): + if s.sy == 'else': + s.next() + return p_suite(s) + else: + return None + +def p_while_statement(s): + # s.sy == 'while' + pos = s.position() + s.next() + test = p_namedexpr_test(s) + body = p_suite(s) + else_clause = p_else_clause(s) + return Nodes.WhileStatNode(pos, + condition = test, body = body, + else_clause = else_clause) + + +def p_for_statement(s, is_async=False): + # s.sy == 'for' + pos = s.position() + s.next() + kw = p_for_bounds(s, allow_testlist=True, is_async=is_async) + body = p_suite(s) + else_clause = p_else_clause(s) + kw.update(body=body, 
else_clause=else_clause, is_async=is_async) + return Nodes.ForStatNode(pos, **kw) + + +def p_for_bounds(s, allow_testlist=True, is_async=False): + target = p_for_target(s) + if s.sy == 'in': + s.next() + iterator = p_for_iterator(s, allow_testlist, is_async=is_async) + return dict(target=target, iterator=iterator) + elif not s.in_python_file and not is_async: + if s.sy == 'from': + s.next() + bound1 = p_bit_expr(s) + else: + # Support shorter "for a <= x < b" syntax + bound1, target = target, None + rel1 = p_for_from_relation(s) + name2_pos = s.position() + name2 = p_ident(s) + rel2_pos = s.position() + rel2 = p_for_from_relation(s) + bound2 = p_bit_expr(s) + step = p_for_from_step(s) + if target is None: + target = ExprNodes.NameNode(name2_pos, name = name2) + else: + if not target.is_name: + error(target.pos, + "Target of for-from statement must be a variable name") + elif name2 != target.name: + error(name2_pos, + "Variable name in for-from range does not match target") + if rel1[0] != rel2[0]: + error(rel2_pos, + "Relation directions in for-from do not match") + return dict(target = target, + bound1 = bound1, + relation1 = rel1, + relation2 = rel2, + bound2 = bound2, + step = step, + ) + else: + s.expect('in') + return {} + +def p_for_from_relation(s): + if s.sy in inequality_relations: + op = s.sy + s.next() + return op + else: + s.error("Expected one of '<', '<=', '>' '>='") + +def p_for_from_step(s): + if s.sy == 'IDENT' and s.systring == 'by': + s.next() + step = p_bit_expr(s) + return step + else: + return None + +inequality_relations = cython.declare(frozenset, frozenset(( + '<', '<=', '>', '>='))) + +def p_target(s, terminator): + pos = s.position() + expr = p_starred_expr(s) + if s.sy == ',': + s.next() + exprs = [expr] + while s.sy != terminator: + exprs.append(p_starred_expr(s)) + if s.sy != ',': + break + s.next() + return ExprNodes.TupleNode(pos, args = exprs) + else: + return expr + + +def p_for_target(s): + return p_target(s, 'in') + + +def 
p_for_iterator(s, allow_testlist=True, is_async=False): + pos = s.position() + if allow_testlist: + expr = p_testlist(s) + else: + expr = p_or_test(s) + return (ExprNodes.AsyncIteratorNode if is_async else ExprNodes.IteratorNode)(pos, sequence=expr) + + +def p_try_statement(s): + # s.sy == 'try' + pos = s.position() + s.next() + body = p_suite(s) + except_clauses = [] + else_clause = None + if s.sy in ('except', 'else'): + while s.sy == 'except': + except_clauses.append(p_except_clause(s)) + if s.sy == 'else': + s.next() + else_clause = p_suite(s) + body = Nodes.TryExceptStatNode(pos, + body = body, except_clauses = except_clauses, + else_clause = else_clause) + if s.sy != 'finally': + return body + # try-except-finally is equivalent to nested try-except/try-finally + if s.sy == 'finally': + s.next() + finally_clause = p_suite(s) + return Nodes.TryFinallyStatNode(pos, + body = body, finally_clause = finally_clause) + else: + s.error("Expected 'except' or 'finally'") + +def p_except_clause(s): + # s.sy == 'except' + pos = s.position() + s.next() + exc_type = None + exc_value = None + is_except_as = False + if s.sy != ':': + exc_type = p_test(s) + # normalise into list of single exception tests + if isinstance(exc_type, ExprNodes.TupleNode): + exc_type = exc_type.args + else: + exc_type = [exc_type] + if s.sy == ',' or (s.sy == 'IDENT' and s.systring == 'as' + and s.context.language_level == 2): + s.next() + exc_value = p_test(s) + elif s.sy == 'IDENT' and s.systring == 'as': + # Py3 syntax requires a name here + s.next() + pos2 = s.position() + name = p_ident(s) + exc_value = ExprNodes.NameNode(pos2, name = name) + is_except_as = True + body = p_suite(s) + return Nodes.ExceptClauseNode(pos, + pattern = exc_type, target = exc_value, + body = body, is_except_as=is_except_as) + +def p_include_statement(s, ctx): + pos = s.position() + s.next() # 'include' + unicode_include_file_name = p_string_literal(s, 'u')[2] + s.expect_newline("Syntax error in include statement") + 
if s.compile_time_eval: + include_file_name = unicode_include_file_name + include_file_path = s.context.find_include_file(include_file_name, pos) + if include_file_path: + s.included_files.append(include_file_name) + with Utils.open_source_file(include_file_path) as f: + source_desc = FileSourceDescriptor(include_file_path) + s2 = PyrexScanner(f, source_desc, s, source_encoding=f.encoding, parse_comments=s.parse_comments) + tree = p_statement_list(s2, ctx) + return tree + else: + return None + else: + return Nodes.PassStatNode(pos) + + +def p_with_statement(s): + s.next() # 'with' + if s.systring == 'template' and not s.in_python_file: + node = p_with_template(s) + else: + node = p_with_items(s) + return node + + +def p_with_items(s, is_async=False): + """ + Copied from CPython: + | 'with' '(' a[asdl_withitem_seq*]=','.with_item+ ','? ')' ':' b=block { + _PyAST_With(a, b, NULL, EXTRA) } + | 'with' a[asdl_withitem_seq*]=','.with_item+ ':' tc=[TYPE_COMMENT] b=block { + _PyAST_With(a, b, NEW_TYPE_COMMENT(p, tc), EXTRA) } + Therefore the first thing to try is the bracket-enclosed + version and if that fails try the regular version + """ + brackets_succeeded = False + items = () # unused, but static analysis fails to track that below + if s.sy == '(': + with tentatively_scan(s) as errors: + s.next() + items = p_with_items_list(s, is_async) + s.expect(")") + if s.sy != ":": + # Fail - the message doesn't matter because we'll try the + # non-bracket version so it'll never be shown + s.error("") + brackets_succeeded = not errors + if not brackets_succeeded: + # try the non-bracket version + items = p_with_items_list(s, is_async) + body = p_suite(s) + for cls, pos, kwds in reversed(items): + # construct the actual nodes now that we know what the body is + body = cls(pos, body=body, **kwds) + return body + + +def p_with_items_list(s, is_async): + items = [] + while True: + items.append(p_with_item(s, is_async)) + if s.sy != ",": + break + s.next() + if s.sy == ")": + # 
trailing commas allowed + break + return items + + +def p_with_item(s, is_async): + # In contrast to most parsing functions, this returns a tuple of + # class, pos, kwd_dict + # This is because GILStatNode does a reasonable amount of initialization in its + # constructor, and requires "body" to be set, which we don't currently have + pos = s.position() + if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'): + if is_async: + s.error("with gil/nogil cannot be async") + state = s.systring + s.next() + + # support conditional gil/nogil + condition = None + if s.sy == '(': + s.next() + condition = p_test(s) + s.expect(')') + + return Nodes.GILStatNode, pos, {"state": state, "condition": condition} + else: + manager = p_test(s) + target = None + if s.sy == 'IDENT' and s.systring == 'as': + s.next() + target = p_starred_expr(s) + return Nodes.WithStatNode, pos, {"manager": manager, "target": target, "is_async": is_async} + + +def p_with_template(s): + pos = s.position() + templates = [] + s.next() + s.expect('[') + templates.append(s.systring) + s.next() + while s.systring == ',': + s.next() + templates.append(s.systring) + s.next() + s.expect(']') + if s.sy == ':': + s.next() + s.expect_newline("Syntax error in template function declaration") + s.expect_indent() + body_ctx = Ctx() + body_ctx.templates = templates + func_or_var = p_c_func_or_var_declaration(s, pos, body_ctx) + s.expect_dedent() + return func_or_var + else: + error(pos, "Syntax error in template function declaration") + +def p_simple_statement(s, first_statement = 0): + #print "p_simple_statement:", s.sy, s.systring ### + if s.sy == 'global': + node = p_global_statement(s) + elif s.sy == 'nonlocal': + node = p_nonlocal_statement(s) + elif s.sy == 'print': + node = p_print_statement(s) + elif s.sy == 'exec': + node = p_exec_statement(s) + elif s.sy == 'del': + node = p_del_statement(s) + elif s.sy == 'break': + node = p_break_statement(s) + elif s.sy == 'continue': + node = 
p_continue_statement(s) + elif s.sy == 'return': + node = p_return_statement(s) + elif s.sy == 'raise': + node = p_raise_statement(s) + elif s.sy in ('import', 'cimport'): + node = p_import_statement(s) + elif s.sy == 'from': + node = p_from_import_statement(s, first_statement = first_statement) + elif s.sy == 'yield': + node = p_yield_statement(s) + elif s.sy == 'assert': + node = p_assert_statement(s) + elif s.sy == 'pass': + node = p_pass_statement(s) + else: + node = p_expression_or_assignment(s) + return node + +def p_simple_statement_list(s, ctx, first_statement = 0): + # Parse a series of simple statements on one line + # separated by semicolons. + stat = p_simple_statement(s, first_statement = first_statement) + pos = stat.pos + stats = [] + if not isinstance(stat, Nodes.PassStatNode): + stats.append(stat) + while s.sy == ';': + #print "p_simple_statement_list: maybe more to follow" ### + s.next() + if s.sy in ('NEWLINE', 'EOF'): + break + stat = p_simple_statement(s, first_statement = first_statement) + if isinstance(stat, Nodes.PassStatNode): + continue + stats.append(stat) + first_statement = False + + if not stats: + stat = Nodes.PassStatNode(pos) + elif len(stats) == 1: + stat = stats[0] + else: + stat = Nodes.StatListNode(pos, stats = stats) + + if s.sy not in ('NEWLINE', 'EOF'): + # provide a better error message for users who accidentally write Cython code in .py files + if isinstance(stat, Nodes.ExprStatNode): + if stat.expr.is_name and stat.expr.name == 'cdef': + s.error("The 'cdef' keyword is only allowed in Cython files (pyx/pxi/pxd)", pos) + s.expect_newline("Syntax error in simple statement list") + + return stat + +def p_compile_time_expr(s): + old = s.compile_time_expr + s.compile_time_expr = 1 + expr = p_testlist(s) + s.compile_time_expr = old + return expr + +def p_DEF_statement(s): + pos = s.position() + denv = s.compile_time_env + s.next() # 'DEF' + name = p_ident(s) + s.expect('=') + expr = p_compile_time_expr(s) + if 
s.compile_time_eval: + value = expr.compile_time_value(denv) + #print "p_DEF_statement: %s = %r" % (name, value) ### + denv.declare(name, value) + s.expect_newline("Expected a newline", ignore_semicolon=True) + return Nodes.PassStatNode(pos) + +def p_IF_statement(s, ctx): + pos = s.position() + saved_eval = s.compile_time_eval + current_eval = saved_eval + denv = s.compile_time_env + result = None + while 1: + s.next() # 'IF' or 'ELIF' + expr = p_compile_time_expr(s) + s.compile_time_eval = current_eval and bool(expr.compile_time_value(denv)) + body = p_suite(s, ctx) + if s.compile_time_eval: + result = body + current_eval = 0 + if s.sy != 'ELIF': + break + if s.sy == 'ELSE': + s.next() + s.compile_time_eval = current_eval + body = p_suite(s, ctx) + if current_eval: + result = body + if not result: + result = Nodes.PassStatNode(pos) + s.compile_time_eval = saved_eval + return result + +def p_statement(s, ctx, first_statement = 0): + cdef_flag = ctx.cdef_flag + decorators = None + if s.sy == 'ctypedef': + if ctx.level not in ('module', 'module_pxd'): + s.error("ctypedef statement not allowed here") + #if ctx.api: + # error(s.position(), "'api' not allowed with 'ctypedef'") + return p_ctypedef_statement(s, ctx) + elif s.sy == 'DEF': + # We used to dep-warn about this but removed the warning again since + # we don't have a good answer yet for all use cases. + # warning(s.position(), + # "The 'DEF' statement is deprecated and will be removed in a future Cython version. " + # "Consider using global variables, constants, and in-place literals instead. " + # "See https://github.com/cython/cython/issues/4310", level=1) + return p_DEF_statement(s) + elif s.sy == 'IF': + warning(s.position(), + "The 'IF' statement is deprecated and will be removed in a future Cython version. " + "Consider using runtime conditions or C macros instead. 
" + "See https://github.com/cython/cython/issues/4310", level=1) + return p_IF_statement(s, ctx) + elif s.sy == '@': + if ctx.level not in ('module', 'class', 'c_class', 'function', 'property', 'module_pxd', 'c_class_pxd', 'other'): + s.error('decorator not allowed here') + s.level = ctx.level + decorators = p_decorators(s) + if not ctx.allow_struct_enum_decorator and s.sy not in ('def', 'cdef', 'cpdef', 'class', 'async'): + if s.sy == 'IDENT' and s.systring == 'async': + pass # handled below + else: + s.error("Decorators can only be followed by functions or classes") + elif s.sy == 'pass' and cdef_flag: + # empty cdef block + return p_pass_statement(s, with_newline=1) + + overridable = 0 + if s.sy == 'cdef': + cdef_flag = 1 + s.next() + elif s.sy == 'cpdef': + cdef_flag = 1 + overridable = 1 + s.next() + if cdef_flag: + if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'): + s.error('cdef statement not allowed here') + s.level = ctx.level + node = p_cdef_statement(s, ctx(overridable=overridable)) + if decorators is not None: + tup = (Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode) + if ctx.allow_struct_enum_decorator: + tup += (Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode) + if not isinstance(node, tup): + s.error("Decorators can only be followed by functions or classes") + node.decorators = decorators + return node + else: + if ctx.api: + s.error("'api' not allowed with this statement", fatal=False) + elif s.sy == 'def': + # def statements aren't allowed in pxd files, except + # as part of a cdef class + if ('pxd' in ctx.level) and (ctx.level != 'c_class_pxd'): + s.error('def statement not allowed here') + s.level = ctx.level + return p_def_statement(s, decorators) + elif s.sy == 'class': + if ctx.level not in ('module', 'function', 'class', 'other'): + s.error("class definition not allowed here") + return p_class_statement(s, decorators) + elif s.sy == 'include': + if ctx.level not in ('module', 'module_pxd'): + 
s.error("include statement not allowed here") + return p_include_statement(s, ctx) + elif ctx.level == 'c_class' and s.sy == 'IDENT' and s.systring == 'property': + return p_property_decl(s) + elif s.sy == 'pass' and ctx.level != 'property': + return p_pass_statement(s, with_newline=True) + else: + if ctx.level in ('c_class_pxd', 'property'): + node = p_ignorable_statement(s) + if node is not None: + return node + s.error("Executable statement not allowed here") + if s.sy == 'if': + return p_if_statement(s) + elif s.sy == 'while': + return p_while_statement(s) + elif s.sy == 'for': + return p_for_statement(s) + elif s.sy == 'try': + return p_try_statement(s) + elif s.sy == 'with': + return p_with_statement(s) + elif s.sy == 'async': + s.next() + return p_async_statement(s, ctx, decorators) + else: + if s.sy == 'IDENT' and s.systring == 'async': + ident_name = s.systring + ident_pos = s.position() + # PEP 492 enables the async/await keywords when it spots "async def ..." + s.next() + if s.sy == 'def': + return p_async_statement(s, ctx, decorators) + elif decorators: + s.error("Decorators can only be followed by functions or classes") + s.put_back(u'IDENT', ident_name, ident_pos) # re-insert original token + return p_simple_statement_list(s, ctx, first_statement=first_statement) + + +def p_statement_list(s, ctx, first_statement = 0): + # Parse a series of statements separated by newlines. 
+ pos = s.position() + stats = [] + while s.sy not in ('DEDENT', 'EOF'): + stat = p_statement(s, ctx, first_statement = first_statement) + if isinstance(stat, Nodes.PassStatNode): + continue + stats.append(stat) + first_statement = False + if not stats: + return Nodes.PassStatNode(pos) + elif len(stats) == 1: + return stats[0] + else: + return Nodes.StatListNode(pos, stats = stats) + + +def p_suite(s, ctx=Ctx()): + return p_suite_with_docstring(s, ctx, with_doc_only=False)[1] + + +def p_suite_with_docstring(s, ctx, with_doc_only=False): + s.expect(':') + doc = None + if s.sy == 'NEWLINE': + s.next() + s.expect_indent() + if with_doc_only: + doc = p_doc_string(s) + body = p_statement_list(s, ctx) + s.expect_dedent() + else: + if ctx.api: + s.error("'api' not allowed with this statement", fatal=False) + if ctx.level in ('module', 'class', 'function', 'other'): + body = p_simple_statement_list(s, ctx) + else: + body = p_pass_statement(s) + s.expect_newline("Syntax error in declarations", ignore_semicolon=True) + if not with_doc_only: + doc, body = _extract_docstring(body) + return doc, body + + +def p_positional_and_keyword_args(s, end_sy_set, templates = None): + """ + Parses positional and keyword arguments. end_sy_set + should contain any s.sy that terminate the argument list. + Argument expansion (* and **) are not allowed. 
+ + Returns: (positional_args, keyword_args) + """ + positional_args = [] + keyword_args = [] + pos_idx = 0 + + while s.sy not in end_sy_set: + if s.sy == '*' or s.sy == '**': + s.error('Argument expansion not allowed here.', fatal=False) + + parsed_type = False + if s.sy == 'IDENT' and s.peek()[0] == '=': + ident = s.systring + s.next() # s.sy is '=' + s.next() + if looking_at_expr(s): + arg = p_test(s) + else: + base_type = p_c_base_type(s, templates = templates) + declarator = p_c_declarator(s, empty = 1) + arg = Nodes.CComplexBaseTypeNode(base_type.pos, + base_type = base_type, declarator = declarator) + parsed_type = True + keyword_node = ExprNodes.IdentifierStringNode(arg.pos, value=ident) + keyword_args.append((keyword_node, arg)) + was_keyword = True + + else: + if looking_at_expr(s): + arg = p_test(s) + else: + base_type = p_c_base_type(s, templates = templates) + declarator = p_c_declarator(s, empty = 1) + arg = Nodes.CComplexBaseTypeNode(base_type.pos, + base_type = base_type, declarator = declarator) + parsed_type = True + positional_args.append(arg) + pos_idx += 1 + if len(keyword_args) > 0: + s.error("Non-keyword arg following keyword arg", + pos=arg.pos) + + if s.sy != ',': + if s.sy not in end_sy_set: + if parsed_type: + s.error("Unmatched %s" % " or ".join(end_sy_set)) + break + s.next() + return positional_args, keyword_args + +def p_c_base_type(s, nonempty=False, templates=None): + if s.sy == '(': + return p_c_complex_base_type(s, templates = templates) + else: + return p_c_simple_base_type(s, nonempty=nonempty, templates=templates) + +def p_calling_convention(s): + if s.sy == 'IDENT' and s.systring in calling_convention_words: + result = s.systring + s.next() + return result + else: + return "" + + +calling_convention_words = cython.declare(frozenset, frozenset(( + "__stdcall", "__cdecl", "__fastcall"))) + + +def p_c_complex_base_type(s, templates = None): + # s.sy == '(' + pos = s.position() + s.next() + base_type = p_c_base_type(s, 
templates=templates) + declarator = p_c_declarator(s, empty=True) + type_node = Nodes.CComplexBaseTypeNode( + pos, base_type=base_type, declarator=declarator) + if s.sy == ',': + components = [type_node] + while s.sy == ',': + s.next() + if s.sy == ')': + break + base_type = p_c_base_type(s, templates=templates) + declarator = p_c_declarator(s, empty=True) + components.append(Nodes.CComplexBaseTypeNode( + pos, base_type=base_type, declarator=declarator)) + type_node = Nodes.CTupleBaseTypeNode(pos, components = components) + + s.expect(')') + if s.sy == '[': + if is_memoryviewslice_access(s): + type_node = p_memoryviewslice_access(s, type_node) + else: + type_node = p_buffer_or_template(s, type_node, templates) + return type_node + + +def p_c_simple_base_type(s, nonempty, templates=None): + is_basic = 0 + signed = 1 + longness = 0 + complex = 0 + module_path = [] + pos = s.position() + + # Handle const/volatile + is_const = is_volatile = 0 + while s.sy == 'IDENT': + if s.systring == 'const': + if is_const: error(pos, "Duplicate 'const'") + is_const = 1 + elif s.systring == 'volatile': + if is_volatile: error(pos, "Duplicate 'volatile'") + is_volatile = 1 + else: + break + s.next() + if is_const or is_volatile: + base_type = p_c_base_type(s, nonempty=nonempty, templates=templates) + if isinstance(base_type, Nodes.MemoryViewSliceTypeNode): + # reverse order to avoid having to write "(const int)[:]" + base_type.base_type_node = Nodes.CConstOrVolatileTypeNode(pos, + base_type=base_type.base_type_node, is_const=is_const, is_volatile=is_volatile) + return base_type + return Nodes.CConstOrVolatileTypeNode(pos, + base_type=base_type, is_const=is_const, is_volatile=is_volatile) + + if s.sy != 'IDENT': + error(pos, "Expected an identifier, found '%s'" % s.sy) + if looking_at_base_type(s): + #print "p_c_simple_base_type: looking_at_base_type at", s.position() + is_basic = 1 + if s.sy == 'IDENT' and s.systring in special_basic_c_types: + signed, longness = 
special_basic_c_types[s.systring] + name = s.systring + s.next() + else: + signed, longness = p_sign_and_longness(s) + if s.sy == 'IDENT' and s.systring in basic_c_type_names: + name = s.systring + s.next() + else: + name = 'int' # long [int], short [int], long [int] complex, etc. + if s.sy == 'IDENT' and s.systring == 'complex': + complex = 1 + s.next() + elif looking_at_dotted_name(s): + #print "p_c_simple_base_type: looking_at_type_name at", s.position() + name = s.systring + s.next() + while s.sy == '.': + module_path.append(name) + s.next() + name = p_ident(s) + else: + name = s.systring + name_pos = s.position() + s.next() + if nonempty and s.sy != 'IDENT': + # Make sure this is not a declaration of a variable or function. + if s.sy == '(': + old_pos = s.position() + s.next() + if (s.sy == '*' or s.sy == '**' or s.sy == '&' + or (s.sy == 'IDENT' and s.systring in calling_convention_words)): + s.put_back(u'(', u'(', old_pos) + else: + s.put_back(u'(', u'(', old_pos) + s.put_back(u'IDENT', name, name_pos) + name = None + elif s.sy not in ('*', '**', '[', '&'): + s.put_back(u'IDENT', name, name_pos) + name = None + + type_node = Nodes.CSimpleBaseTypeNode(pos, + name = name, module_path = module_path, + is_basic_c_type = is_basic, signed = signed, + complex = complex, longness = longness, + templates = templates) + + # declarations here. + if s.sy == '[': + if is_memoryviewslice_access(s): + type_node = p_memoryviewslice_access(s, type_node) + else: + type_node = p_buffer_or_template(s, type_node, templates) + + if s.sy == '.': + s.next() + name = p_ident(s) + type_node = Nodes.CNestedBaseTypeNode(pos, base_type = type_node, name = name) + + return type_node + +def p_buffer_or_template(s, base_type_node, templates): + # s.sy == '[' + pos = s.position() + s.next() + # Note that buffer_positional_options_count=1, so the only positional argument is dtype. + # For templated types, all parameters are types. 
+ positional_args, keyword_args = ( + p_positional_and_keyword_args(s, (']',), templates) + ) + s.expect(']') + + if s.sy == '[': + base_type_node = p_buffer_or_template(s, base_type_node, templates) + + keyword_dict = ExprNodes.DictNode(pos, + key_value_pairs = [ + ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) + for key, value in keyword_args + ]) + result = Nodes.TemplatedTypeNode(pos, + positional_args = positional_args, + keyword_args = keyword_dict, + base_type_node = base_type_node) + return result + +def p_bracketed_base_type(s, base_type_node, nonempty, empty): + # s.sy == '[' + if empty and not nonempty: + # sizeof-like thing. Only anonymous C arrays allowed (int[SIZE]). + return base_type_node + elif not empty and nonempty: + # declaration of either memoryview slice or buffer. + if is_memoryviewslice_access(s): + return p_memoryviewslice_access(s, base_type_node) + else: + return p_buffer_or_template(s, base_type_node, None) + # return p_buffer_access(s, base_type_node) + elif not empty and not nonempty: + # only anonymous C arrays and memoryview slice arrays here. We + # disallow buffer declarations for now, due to ambiguity with anonymous + # C arrays. + if is_memoryviewslice_access(s): + return p_memoryviewslice_access(s, base_type_node) + else: + return base_type_node + +def is_memoryviewslice_access(s): + # s.sy == '[' + # a memoryview slice declaration is distinguishable from a buffer access + # declaration by the first entry in the bracketed list. The buffer will + # not have an unnested colon in the first entry; the memoryview slice will. 
+ saved = [(s.sy, s.systring, s.position())] + s.next() + retval = False + if s.systring == ':': + retval = True + elif s.sy == 'INT': + saved.append((s.sy, s.systring, s.position())) + s.next() + if s.sy == ':': + retval = True + + for sv in saved[::-1]: + s.put_back(*sv) + + return retval + +def p_memoryviewslice_access(s, base_type_node): + # s.sy == '[' + pos = s.position() + s.next() + subscripts, _ = p_subscript_list(s) + # make sure each entry in subscripts is a slice + for subscript in subscripts: + if len(subscript) < 2: + s.error("An axis specification in memoryview declaration does not have a ':'.") + s.expect(']') + indexes = make_slice_nodes(pos, subscripts) + result = Nodes.MemoryViewSliceTypeNode(pos, + base_type_node = base_type_node, + axes = indexes) + return result + +def looking_at_name(s): + return s.sy == 'IDENT' and s.systring not in calling_convention_words + +def looking_at_expr(s): + if s.systring in base_type_start_words: + return False + elif s.sy == 'IDENT': + is_type = False + name = s.systring + name_pos = s.position() + dotted_path = [] + s.next() + + while s.sy == '.': + s.next() + dotted_path.append((s.systring, s.position())) + s.expect('IDENT') + + saved = s.sy, s.systring, s.position() + if s.sy == 'IDENT': + is_type = True + elif s.sy == '*' or s.sy == '**': + s.next() + is_type = s.sy in (')', ']') + s.put_back(*saved) + elif s.sy == '(': + s.next() + is_type = s.sy == '*' + s.put_back(*saved) + elif s.sy == '[': + s.next() + is_type = s.sy == ']' or not looking_at_expr(s) # could be a nested template type + s.put_back(*saved) + + dotted_path.reverse() + for p in dotted_path: + s.put_back(u'IDENT', *p) + s.put_back(u'.', u'.', p[1]) # gets the position slightly wrong + + s.put_back(u'IDENT', name, name_pos) + return not is_type and saved[0] + else: + return True + +def looking_at_base_type(s): + #print "looking_at_base_type?", s.sy, s.systring, s.position() + return s.sy == 'IDENT' and s.systring in base_type_start_words + 
+def looking_at_dotted_name(s): + if s.sy == 'IDENT': + name = s.systring + name_pos = s.position() + s.next() + result = s.sy == '.' + s.put_back(u'IDENT', name, name_pos) + return result + else: + return 0 + + +basic_c_type_names = cython.declare(frozenset, frozenset(( + "void", "char", "int", "float", "double", "bint"))) + +special_basic_c_types = cython.declare(dict, { + # name : (signed, longness) + "Py_UNICODE" : (0, 0), + "Py_UCS4" : (0, 0), + "Py_hash_t" : (2, 0), + "Py_ssize_t" : (2, 0), + "ssize_t" : (2, 0), + "size_t" : (0, 0), + "ptrdiff_t" : (2, 0), + "Py_tss_t" : (1, 0), +}) + +sign_and_longness_words = cython.declare(frozenset, frozenset(( + "short", "long", "signed", "unsigned"))) + +base_type_start_words = cython.declare( + frozenset, + basic_c_type_names + | sign_and_longness_words + | frozenset(special_basic_c_types)) + +struct_enum_union = cython.declare(frozenset, frozenset(( + "struct", "union", "enum", "packed"))) + +def p_sign_and_longness(s): + signed = 1 + longness = 0 + while s.sy == 'IDENT' and s.systring in sign_and_longness_words: + if s.systring == 'unsigned': + signed = 0 + elif s.systring == 'signed': + signed = 2 + elif s.systring == 'short': + longness = -1 + elif s.systring == 'long': + longness += 1 + s.next() + return signed, longness + +def p_opt_cname(s): + literal = p_opt_string_literal(s, 'u') + if literal is not None: + cname = EncodedString(literal) + cname.encoding = s.source_encoding + else: + cname = None + return cname + +def p_c_declarator(s, ctx = Ctx(), empty = 0, is_type = 0, cmethod_flag = 0, + assignable = 0, nonempty = 0, + calling_convention_allowed = 0): + # If empty is true, the declarator must be empty. If nonempty is true, + # the declarator must be nonempty. Otherwise we don't care. + # If cmethod_flag is true, then if this declarator declares + # a function, it's a C method of an extension type. 
+ pos = s.position() + if s.sy == '(': + s.next() + if s.sy == ')' or looking_at_name(s): + base = Nodes.CNameDeclaratorNode(pos, name=s.context.intern_ustring(u""), cname=None) + result = p_c_func_declarator(s, pos, ctx, base, cmethod_flag) + else: + result = p_c_declarator(s, ctx, empty = empty, is_type = is_type, + cmethod_flag = cmethod_flag, + nonempty = nonempty, + calling_convention_allowed = 1) + s.expect(')') + else: + result = p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, + assignable, nonempty) + if not calling_convention_allowed and result.calling_convention and s.sy != '(': + error(s.position(), "%s on something that is not a function" + % result.calling_convention) + while s.sy in ('[', '('): + pos = s.position() + if s.sy == '[': + result = p_c_array_declarator(s, result) + else: # sy == '(' + s.next() + result = p_c_func_declarator(s, pos, ctx, result, cmethod_flag) + cmethod_flag = 0 + return result + +def p_c_array_declarator(s, base): + pos = s.position() + s.next() # '[' + if s.sy != ']': + dim = p_testlist(s) + else: + dim = None + s.expect(']') + return Nodes.CArrayDeclaratorNode(pos, base = base, dimension = dim) + +def p_c_func_declarator(s, pos, ctx, base, cmethod_flag): + # Opening paren has already been skipped + args = p_c_arg_list(s, ctx, cmethod_flag = cmethod_flag, + nonempty_declarators = 0) + ellipsis = p_optional_ellipsis(s) + s.expect(')') + nogil = p_nogil(s) + exc_val, exc_check, exc_clause = p_exception_value_clause(s, ctx.visibility == 'extern') + if nogil and exc_clause: + warning( + s.position(), + "The keyword 'nogil' should appear at the end of the " + "function signature line. 
Placing it before 'except' " + "or 'noexcept' will be disallowed in a future version " + "of Cython.", + level=2 + ) + nogil = nogil or p_nogil(s) + with_gil = p_with_gil(s) + return Nodes.CFuncDeclaratorNode(pos, + base = base, args = args, has_varargs = ellipsis, + exception_value = exc_val, exception_check = exc_check, + nogil = nogil or ctx.nogil or with_gil, with_gil = with_gil, has_explicit_exc_clause=exc_clause) + +supported_overloaded_operators = cython.declare(frozenset, frozenset(( + '+', '-', '*', '/', '%', + '++', '--', '~', '|', '&', '^', '<<', '>>', ',', + '==', '!=', '>=', '>', '<=', '<', + '[]', '()', '!', '=', + 'bool', +))) + +def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, + assignable, nonempty): + pos = s.position() + calling_convention = p_calling_convention(s) + if s.sy in ('*', '**'): + # scanner returns '**' as a single token + is_ptrptr = s.sy == '**' + s.next() + + const_pos = s.position() + is_const = s.systring == 'const' and s.sy == 'IDENT' + if is_const: + s.next() + + base = p_c_declarator(s, ctx, empty=empty, is_type=is_type, + cmethod_flag=cmethod_flag, + assignable=assignable, nonempty=nonempty) + if is_const: + base = Nodes.CConstDeclaratorNode(const_pos, base=base) + if is_ptrptr: + base = Nodes.CPtrDeclaratorNode(pos, base=base) + result = Nodes.CPtrDeclaratorNode(pos, base=base) + elif s.sy == '&' or (s.sy == '&&' and s.context.cpp): + node_class = Nodes.CppRvalueReferenceDeclaratorNode if s.sy == '&&' else Nodes.CReferenceDeclaratorNode + s.next() + base = p_c_declarator(s, ctx, empty=empty, is_type=is_type, + cmethod_flag=cmethod_flag, + assignable=assignable, nonempty=nonempty) + result = node_class(pos, base=base) + else: + rhs = None + if s.sy == 'IDENT': + name = s.systring + if empty: + error(s.position(), "Declarator should be empty") + s.next() + cname = p_opt_cname(s) + if name != 'operator' and s.sy == '=' and assignable: + s.next() + rhs = p_test(s) + else: + if nonempty: + error(s.position(), 
"Empty declarator") + name = "" + cname = None + if cname is None and ctx.namespace is not None and nonempty: + cname = ctx.namespace + "::" + name + if name == 'operator' and ctx.visibility == 'extern' and nonempty: + op = s.sy + if [1 for c in op if c in '+-*/<=>!%&|([^~,']: + s.next() + # Handle diphthong operators. + if op == '(': + s.expect(')') + op = '()' + elif op == '[': + s.expect(']') + op = '[]' + elif op in ('-', '+', '|', '&') and s.sy == op: + op *= 2 # ++, --, ... + s.next() + elif s.sy == '=': + op += s.sy # +=, -=, ... + s.next() + if op not in supported_overloaded_operators: + s.error("Overloading operator '%s' not yet supported." % op, + fatal=False) + name += op + elif op == 'IDENT': + op = s.systring + if op not in supported_overloaded_operators: + s.error("Overloading operator '%s' not yet supported." % op, + fatal=False) + name = name + ' ' + op + s.next() + result = Nodes.CNameDeclaratorNode(pos, + name = name, cname = cname, default = rhs) + result.calling_convention = calling_convention + return result + +def p_nogil(s): + if s.sy == 'IDENT' and s.systring == 'nogil': + s.next() + return 1 + else: + return 0 + +def p_with_gil(s): + if s.sy == 'with': + s.next() + s.expect_keyword('gil') + return 1 + else: + return 0 + +def p_exception_value_clause(s, is_extern): + """ + Parse exception value clause. + + Maps clauses to exc_check / exc_value / exc_clause as follows: + ______________________________________________________________________ + | | | | | + | Clause | exc_check | exc_value | exc_clause | + | ___________________________ | ___________ | ___________ | __________ | + | | | | | + | (default func.) | True | None | False | + | (cdef extern) | False | None | False | + | noexcept | False | None | True | + | except | False | | True | + | except? 
| True | | True | + | except * | True | None | True | + | except + | '+' | None | True | + | except +* | '+' | '*' | True | + | except + | '+' | | True | + | ___________________________ | ___________ | ___________ | __________ | + + Note that the only reason we need `exc_clause` is to raise a + warning when `'except'` or `'noexcept'` is placed after the + `'nogil'` keyword. + """ + exc_clause = False + exc_val = None + exc_check = False if is_extern else True + + if s.sy == 'IDENT' and s.systring == 'noexcept': + exc_clause = True + s.next() + exc_check = False + elif s.sy == 'except': + exc_clause = True + s.next() + if s.sy == '*': + exc_check = True + s.next() + elif s.sy == '+': + exc_check = '+' + plus_char_pos = s.position()[2] + s.next() + if s.sy == 'IDENT': + name = s.systring + if name == 'nogil': + if s.position()[2] == plus_char_pos + 1: + error(s.position(), + "'except +nogil' defines an exception handling function. Use 'except + nogil' for the 'nogil' modifier.") + # 'except + nogil' is parsed outside + else: + exc_val = p_name(s, name) + s.next() + elif s.sy == '*': + exc_val = ExprNodes.CharNode(s.position(), value=u'*') + s.next() + else: + if s.sy == '?': + exc_check = True + s.next() + else: + exc_check = False + # exc_val can be non-None even if exc_check is False, c.f. "except -1" + exc_val = p_test(s) + + return exc_val, exc_check, exc_clause + +c_arg_list_terminators = cython.declare(frozenset, frozenset(( + '*', '**', '...', ')', ':', '/'))) + +def p_c_arg_list(s, ctx = Ctx(), in_pyfunc = 0, cmethod_flag = 0, + nonempty_declarators = 0, kw_only = 0, annotated = 1): + # Comma-separated list of C argument declarations, possibly empty. + # May have a trailing comma. 
+ args = [] + is_self_arg = cmethod_flag + while s.sy not in c_arg_list_terminators: + args.append(p_c_arg_decl(s, ctx, in_pyfunc, is_self_arg, + nonempty = nonempty_declarators, kw_only = kw_only, + annotated = annotated)) + if s.sy != ',': + break + s.next() + is_self_arg = 0 + return args + +def p_optional_ellipsis(s): + if s.sy == '...': + expect_ellipsis(s) + return 1 + else: + return 0 + +def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0, + kw_only = 0, annotated = 1): + pos = s.position() + not_none = or_none = 0 + default = None + annotation = None + if s.in_python_file: + # empty type declaration + base_type = Nodes.CSimpleBaseTypeNode(pos, + name = None, module_path = [], + is_basic_c_type = 0, signed = 0, + complex = 0, longness = 0, + is_self_arg = cmethod_flag, templates = None) + else: + base_type = p_c_base_type(s, nonempty=nonempty) + declarator = p_c_declarator(s, ctx, nonempty = nonempty) + if s.sy in ('not', 'or') and not s.in_python_file: + kind = s.sy + s.next() + if s.sy == 'IDENT' and s.systring == 'None': + s.next() + else: + s.error("Expected 'None'") + if not in_pyfunc: + error(pos, "'%s None' only allowed in Python functions" % kind) + or_none = kind == 'or' + not_none = kind == 'not' + if annotated and s.sy == ':': + s.next() + annotation = p_annotation(s) + if s.sy == '=': + s.next() + if 'pxd' in ctx.level: + if s.sy in ['*', '?']: + # TODO(github/1736): Make this an error for inline declarations. + default = ExprNodes.NoneNode(pos) + s.next() + elif 'inline' in ctx.modifiers: + default = p_test(s) + else: + error(pos, "default values cannot be specified in pxd files, use ? 
or *") + else: + default = p_test(s) + return Nodes.CArgDeclNode(pos, + base_type = base_type, + declarator = declarator, + not_none = not_none, + or_none = or_none, + default = default, + annotation = annotation, + kw_only = kw_only) + +def p_api(s): + if s.sy == 'IDENT' and s.systring == 'api': + s.next() + return 1 + else: + return 0 + +def p_cdef_statement(s, ctx): + pos = s.position() + ctx.visibility = p_visibility(s, ctx.visibility) + ctx.api = ctx.api or p_api(s) + if ctx.api: + if ctx.visibility not in ('private', 'public'): + error(pos, "Cannot combine 'api' with '%s'" % ctx.visibility) + if (ctx.visibility == 'extern') and s.sy == 'from': + return p_cdef_extern_block(s, pos, ctx) + elif s.sy == 'import': + s.next() + return p_cdef_extern_block(s, pos, ctx) + elif p_nogil(s): + ctx.nogil = 1 + if ctx.overridable: + error(pos, "cdef blocks cannot be declared cpdef") + return p_cdef_block(s, ctx) + elif s.sy == ':': + if ctx.overridable: + error(pos, "cdef blocks cannot be declared cpdef") + return p_cdef_block(s, ctx) + elif s.sy == 'class': + if ctx.level not in ('module', 'module_pxd'): + error(pos, "Extension type definition not allowed here") + if ctx.overridable: + error(pos, "Extension types cannot be declared cpdef") + return p_c_class_definition(s, pos, ctx) + elif s.sy == 'IDENT' and s.systring == 'cppclass': + return p_cpp_class_definition(s, pos, ctx) + elif s.sy == 'IDENT' and s.systring in struct_enum_union: + if ctx.level not in ('module', 'module_pxd'): + error(pos, "C struct/union/enum definition not allowed here") + if ctx.overridable: + if s.systring != 'enum': + error(pos, "C struct/union cannot be declared cpdef") + return p_struct_enum(s, pos, ctx) + elif s.sy == 'IDENT' and s.systring == 'fused': + return p_fused_definition(s, pos, ctx) + else: + return p_c_func_or_var_declaration(s, pos, ctx) + +def p_cdef_block(s, ctx): + return p_suite(s, ctx(cdef_flag = 1)) + +def p_cdef_extern_block(s, pos, ctx): + if ctx.overridable: + 
error(pos, "cdef extern blocks cannot be declared cpdef") + include_file = None + s.expect('from') + if s.sy == '*': + s.next() + else: + include_file = p_string_literal(s, 'u')[2] + ctx = ctx(cdef_flag = 1, visibility = 'extern') + if s.systring == "namespace": + s.next() + ctx.namespace = p_string_literal(s, 'u')[2] + if p_nogil(s): + ctx.nogil = 1 + + # Use "docstring" as verbatim string to include + verbatim_include, body = p_suite_with_docstring(s, ctx, True) + + return Nodes.CDefExternNode(pos, + include_file = include_file, + verbatim_include = verbatim_include, + body = body, + namespace = ctx.namespace) + +def p_c_enum_definition(s, pos, ctx): + # s.sy == ident 'enum' + s.next() + + scoped = False + if s.context.cpp and (s.sy == 'class' or (s.sy == 'IDENT' and s.systring == 'struct')): + scoped = True + s.next() + + if s.sy == 'IDENT': + name = s.systring + s.next() + cname = p_opt_cname(s) + if cname is None and ctx.namespace is not None: + cname = ctx.namespace + "::" + name + else: + name = cname = None + if scoped: + s.error("Unnamed scoped enum not allowed") + + if scoped and s.sy == '(': + s.next() + underlying_type = p_c_base_type(s) + s.expect(')') + else: + underlying_type = Nodes.CSimpleBaseTypeNode( + pos, + name="int", + module_path = [], + is_basic_c_type = True, + signed = 1, + complex = 0, + longness = 0 + ) + + s.expect(':') + items = [] + + doc = None + if s.sy != 'NEWLINE': + p_c_enum_line(s, ctx, items) + else: + s.next() # 'NEWLINE' + s.expect_indent() + doc = p_doc_string(s) + + while s.sy not in ('DEDENT', 'EOF'): + p_c_enum_line(s, ctx, items) + + s.expect_dedent() + + if not items and ctx.visibility != "extern": + error(pos, "Empty enum definition not allowed outside a 'cdef extern from' block") + + return Nodes.CEnumDefNode( + pos, name=name, cname=cname, + scoped=scoped, items=items, + underlying_type=underlying_type, + typedef_flag=ctx.typedef_flag, visibility=ctx.visibility, + create_wrapper=ctx.overridable, + api=ctx.api, 
in_pxd=ctx.level == 'module_pxd', doc=doc) + +def p_c_enum_line(s, ctx, items): + if s.sy != 'pass': + p_c_enum_item(s, ctx, items) + while s.sy == ',': + s.next() + if s.sy in ('NEWLINE', 'EOF'): + break + p_c_enum_item(s, ctx, items) + else: + s.next() + s.expect_newline("Syntax error in enum item list") + +def p_c_enum_item(s, ctx, items): + pos = s.position() + name = p_ident(s) + cname = p_opt_cname(s) + if cname is None and ctx.namespace is not None: + cname = ctx.namespace + "::" + name + value = None + if s.sy == '=': + s.next() + value = p_test(s) + items.append(Nodes.CEnumDefItemNode(pos, + name = name, cname = cname, value = value)) + +def p_c_struct_or_union_definition(s, pos, ctx): + packed = False + if s.systring == 'packed': + packed = True + s.next() + if s.sy != 'IDENT' or s.systring != 'struct': + s.expected('struct') + # s.sy == ident 'struct' or 'union' + kind = s.systring + s.next() + name = p_ident(s) + cname = p_opt_cname(s) + if cname is None and ctx.namespace is not None: + cname = ctx.namespace + "::" + name + attributes = None + if s.sy == ':': + s.next() + attributes = [] + if s.sy == 'pass': + s.next() + s.expect_newline("Expected a newline", ignore_semicolon=True) + else: + s.expect('NEWLINE') + s.expect_indent() + body_ctx = Ctx(visibility=ctx.visibility) + while s.sy != 'DEDENT': + if s.sy != 'pass': + attributes.append( + p_c_func_or_var_declaration(s, s.position(), body_ctx)) + else: + s.next() + s.expect_newline("Expected a newline") + s.expect_dedent() + + if not attributes and ctx.visibility != "extern": + error(pos, "Empty struct or union definition not allowed outside a 'cdef extern from' block") + else: + s.expect_newline("Syntax error in struct or union definition") + + return Nodes.CStructOrUnionDefNode(pos, + name = name, cname = cname, kind = kind, attributes = attributes, + typedef_flag = ctx.typedef_flag, visibility = ctx.visibility, + api = ctx.api, in_pxd = ctx.level == 'module_pxd', packed = packed) + +def 
p_fused_definition(s, pos, ctx): + """ + c(type)def fused my_fused_type: + ... + """ + # s.systring == 'fused' + + if ctx.level not in ('module', 'module_pxd'): + error(pos, "Fused type definition not allowed here") + + s.next() + name = p_ident(s) + + s.expect(":") + s.expect_newline() + s.expect_indent() + + types = [] + while s.sy != 'DEDENT': + if s.sy != 'pass': + #types.append(p_c_declarator(s)) + types.append(p_c_base_type(s)) #, nonempty=1)) + else: + s.next() + + s.expect_newline() + + s.expect_dedent() + + if not types: + error(pos, "Need at least one type") + + return Nodes.FusedTypeNode(pos, name=name, types=types) + +def p_struct_enum(s, pos, ctx): + if s.systring == 'enum': + return p_c_enum_definition(s, pos, ctx) + else: + return p_c_struct_or_union_definition(s, pos, ctx) + +def p_visibility(s, prev_visibility): + pos = s.position() + visibility = prev_visibility + if s.sy == 'IDENT' and s.systring in ('extern', 'public', 'readonly'): + visibility = s.systring + if prev_visibility != 'private' and visibility != prev_visibility: + s.error("Conflicting visibility options '%s' and '%s'" + % (prev_visibility, visibility), fatal=False) + s.next() + return visibility + +def p_c_modifiers(s): + if s.sy == 'IDENT' and s.systring in ('inline',): + modifier = s.systring + s.next() + return [modifier] + p_c_modifiers(s) + return [] + +def p_c_func_or_var_declaration(s, pos, ctx): + cmethod_flag = ctx.level in ('c_class', 'c_class_pxd') + modifiers = p_c_modifiers(s) + base_type = p_c_base_type(s, nonempty = 1, templates = ctx.templates) + declarator = p_c_declarator(s, ctx(modifiers=modifiers), cmethod_flag = cmethod_flag, + assignable = 1, nonempty = 1) + declarator.overridable = ctx.overridable + if s.sy == 'IDENT' and s.systring == 'const' and ctx.level == 'cpp_class': + s.next() + is_const_method = 1 + else: + is_const_method = 0 + if s.sy == '->': + # Special enough to give a better error message and keep going. 
+ s.error( + "Return type annotation is not allowed in cdef/cpdef signatures. " + "Please define it before the function name, as in C signatures.", + fatal=False) + s.next() + p_test(s) # Keep going, but ignore result. + if s.sy == ':': + if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd', 'cpp_class') and not ctx.templates: + s.error("C function definition not allowed here") + doc, suite = p_suite_with_docstring(s, Ctx(level='function')) + result = Nodes.CFuncDefNode(pos, + visibility = ctx.visibility, + base_type = base_type, + declarator = declarator, + body = suite, + doc = doc, + modifiers = modifiers, + api = ctx.api, + overridable = ctx.overridable, + is_const_method = is_const_method) + else: + #if api: + # s.error("'api' not allowed with variable declaration") + if is_const_method: + declarator.is_const_method = is_const_method + declarators = [declarator] + while s.sy == ',': + s.next() + if s.sy == 'NEWLINE': + break + declarator = p_c_declarator(s, ctx, cmethod_flag = cmethod_flag, + assignable = 1, nonempty = 1) + declarators.append(declarator) + doc_line = s.start_line + 1 + s.expect_newline("Syntax error in C variable declaration", ignore_semicolon=True) + if ctx.level in ('c_class', 'c_class_pxd') and s.start_line == doc_line: + doc = p_doc_string(s) + else: + doc = None + result = Nodes.CVarDefNode(pos, + visibility = ctx.visibility, + base_type = base_type, + declarators = declarators, + in_pxd = ctx.level in ('module_pxd', 'c_class_pxd'), + doc = doc, + api = ctx.api, + modifiers = modifiers, + overridable = ctx.overridable) + return result + +def p_ctypedef_statement(s, ctx): + # s.sy == 'ctypedef' + pos = s.position() + s.next() + visibility = p_visibility(s, ctx.visibility) + api = p_api(s) + ctx = ctx(typedef_flag = 1, visibility = visibility) + if api: + ctx.api = 1 + if s.sy == 'class': + return p_c_class_definition(s, pos, ctx) + elif s.sy == 'IDENT' and s.systring in struct_enum_union: + return p_struct_enum(s, pos, 
ctx) + elif s.sy == 'IDENT' and s.systring == 'fused': + return p_fused_definition(s, pos, ctx) + else: + base_type = p_c_base_type(s, nonempty = 1) + declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1) + s.expect_newline("Syntax error in ctypedef statement", ignore_semicolon=True) + return Nodes.CTypeDefNode( + pos, base_type = base_type, + declarator = declarator, + visibility = visibility, api = api, + in_pxd = ctx.level == 'module_pxd') + +def p_decorators(s): + decorators = [] + while s.sy == '@': + pos = s.position() + s.next() + decorator = p_namedexpr_test(s) + decorators.append(Nodes.DecoratorNode(pos, decorator=decorator)) + s.expect_newline("Expected a newline after decorator") + return decorators + + +def _reject_cdef_modifier_in_py(s, name): + """Step over incorrectly placed cdef modifiers (@see _CDEF_MODIFIERS) to provide a good error message for them. + """ + if s.sy == 'IDENT' and name in _CDEF_MODIFIERS: + # Special enough to provide a good error message. + s.error("Cannot use cdef modifier '%s' in Python function signature. Use a decorator instead." % name, fatal=False) + return p_ident(s) # Keep going, in case there are other errors. + return name + + +def p_def_statement(s, decorators=None, is_async_def=False): + # s.sy == 'def' + pos = decorators[0].pos if decorators else s.position() + # PEP 492 switches the async/await keywords on in "async def" functions + if is_async_def: + s.enter_async() + s.next() + name = _reject_cdef_modifier_in_py(s, p_ident(s)) + s.expect( + '(', + "Expected '(', found '%s'. Did you use cdef syntax in a Python declaration? " + "Use decorators and Python type annotations instead." 
% ( + s.systring if s.sy == 'IDENT' else s.sy)) + args, star_arg, starstar_arg = p_varargslist(s, terminator=')') + s.expect(')') + _reject_cdef_modifier_in_py(s, s.systring) + return_type_annotation = None + if s.sy == '->': + s.next() + return_type_annotation = p_annotation(s) + _reject_cdef_modifier_in_py(s, s.systring) + + doc, body = p_suite_with_docstring(s, Ctx(level='function')) + if is_async_def: + s.exit_async() + + return Nodes.DefNode( + pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg, + doc=doc, body=body, decorators=decorators, is_async_def=is_async_def, + return_type_annotation=return_type_annotation) + + +def p_varargslist(s, terminator=')', annotated=1): + args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1, + annotated = annotated) + star_arg = None + starstar_arg = None + if s.sy == '/': + if len(args) == 0: + s.error("Got zero positional-only arguments despite presence of " + "positional-only specifier '/'") + s.next() + # Mark all args to the left as pos only + for arg in args: + arg.pos_only = 1 + if s.sy == ',': + s.next() + args.extend(p_c_arg_list(s, in_pyfunc = 1, + nonempty_declarators = 1, annotated = annotated)) + elif s.sy != terminator: + s.error("Syntax error in Python function argument list") + if s.sy == '*': + s.next() + if s.sy == 'IDENT': + star_arg = p_py_arg_decl(s, annotated=annotated) + if s.sy == ',': + s.next() + args.extend(p_c_arg_list(s, in_pyfunc = 1, + nonempty_declarators = 1, kw_only = 1, annotated = annotated)) + elif s.sy != terminator: + s.error("Syntax error in Python function argument list") + if s.sy == '**': + s.next() + starstar_arg = p_py_arg_decl(s, annotated=annotated) + if s.sy == ',': + s.next() + return (args, star_arg, starstar_arg) + +def p_py_arg_decl(s, annotated = 1): + pos = s.position() + name = p_ident(s) + annotation = None + if annotated and s.sy == ':': + s.next() + annotation = p_annotation(s) + return Nodes.PyArgDeclNode(pos, name = name, annotation = 
annotation) + + +def p_class_statement(s, decorators): + # s.sy == 'class' + pos = s.position() + s.next() + class_name = EncodedString(p_ident(s)) + class_name.encoding = s.source_encoding # FIXME: why is this needed? + arg_tuple = None + keyword_dict = None + if s.sy == '(': + positional_args, keyword_args = p_call_parse_args(s, allow_genexp=False) + arg_tuple, keyword_dict = p_call_build_packed_args(pos, positional_args, keyword_args) + if arg_tuple is None: + # XXX: empty arg_tuple + arg_tuple = ExprNodes.TupleNode(pos, args=[]) + doc, body = p_suite_with_docstring(s, Ctx(level='class')) + return Nodes.PyClassDefNode( + pos, name=class_name, + bases=arg_tuple, + keyword_args=keyword_dict, + doc=doc, body=body, decorators=decorators, + force_py3_semantics=s.context.language_level >= 3) + + +def p_c_class_definition(s, pos, ctx): + # s.sy == 'class' + s.next() + module_path = [] + class_name = p_ident(s) + while s.sy == '.': + s.next() + module_path.append(class_name) + class_name = p_ident(s) + if module_path and ctx.visibility != 'extern': + error(pos, "Qualified class name only allowed for 'extern' C class") + if module_path and s.sy == 'IDENT' and s.systring == 'as': + s.next() + as_name = p_ident(s) + else: + as_name = class_name + objstruct_name = None + typeobj_name = None + bases = None + check_size = None + if s.sy == '(': + positional_args, keyword_args = p_call_parse_args(s, allow_genexp=False) + if keyword_args: + s.error("C classes cannot take keyword bases.") + bases, _ = p_call_build_packed_args(pos, positional_args, keyword_args) + if bases is None: + bases = ExprNodes.TupleNode(pos, args=[]) + + if s.sy == '[': + if ctx.visibility not in ('public', 'extern') and not ctx.api: + error(s.position(), "Name options only allowed for 'public', 'api', or 'extern' C class") + objstruct_name, typeobj_name, check_size = p_c_class_options(s) + if s.sy == ':': + if ctx.level == 'module_pxd': + body_level = 'c_class_pxd' + else: + body_level = 'c_class' + doc, 
body = p_suite_with_docstring(s, Ctx(level=body_level)) + else: + s.expect_newline("Syntax error in C class definition") + doc = None + body = None + if ctx.visibility == 'extern': + if not module_path: + error(pos, "Module name required for 'extern' C class") + if typeobj_name: + error(pos, "Type object name specification not allowed for 'extern' C class") + elif ctx.visibility == 'public': + if not objstruct_name: + error(pos, "Object struct name specification required for 'public' C class") + if not typeobj_name: + error(pos, "Type object name specification required for 'public' C class") + elif ctx.visibility == 'private': + if ctx.api: + if not objstruct_name: + error(pos, "Object struct name specification required for 'api' C class") + if not typeobj_name: + error(pos, "Type object name specification required for 'api' C class") + else: + error(pos, "Invalid class visibility '%s'" % ctx.visibility) + return Nodes.CClassDefNode(pos, + visibility = ctx.visibility, + typedef_flag = ctx.typedef_flag, + api = ctx.api, + module_name = ".".join(module_path), + class_name = class_name, + as_name = as_name, + bases = bases, + objstruct_name = objstruct_name, + typeobj_name = typeobj_name, + check_size = check_size, + in_pxd = ctx.level == 'module_pxd', + doc = doc, + body = body) + + +def p_c_class_options(s): + objstruct_name = None + typeobj_name = None + check_size = None + s.expect('[') + while 1: + if s.sy != 'IDENT': + break + if s.systring == 'object': + s.next() + objstruct_name = p_ident(s) + elif s.systring == 'type': + s.next() + typeobj_name = p_ident(s) + elif s.systring == 'check_size': + s.next() + check_size = p_ident(s) + if check_size not in ('ignore', 'warn', 'error'): + s.error("Expected one of ignore, warn or error, found %r" % check_size) + if s.sy != ',': + break + s.next() + s.expect(']', "Expected 'object', 'type' or 'check_size'") + return objstruct_name, typeobj_name, check_size + + +def p_property_decl(s): + pos = s.position() + s.next() # 
'property' + name = p_ident(s) + doc, body = p_suite_with_docstring( + s, Ctx(level='property'), with_doc_only=True) + return Nodes.PropertyNode(pos, name=name, doc=doc, body=body) + + +def p_ignorable_statement(s): + """ + Parses any kind of ignorable statement that is allowed in .pxd files. + """ + if s.sy == 'BEGIN_STRING': + pos = s.position() + string_node = p_atom(s) + s.expect_newline("Syntax error in string", ignore_semicolon=True) + return Nodes.ExprStatNode(pos, expr=string_node) + return None + + +def p_doc_string(s): + if s.sy == 'BEGIN_STRING': + pos = s.position() + kind, bytes_result, unicode_result = p_cat_string_literal(s) + s.expect_newline("Syntax error in doc string", ignore_semicolon=True) + if kind in ('u', ''): + return unicode_result + warning(pos, "Python 3 requires docstrings to be unicode strings") + return bytes_result + else: + return None + + +def _extract_docstring(node): + """ + Extract a docstring from a statement or from the first statement + in a list. Remove the statement if found. Return a tuple + (plain-docstring or None, node). 
+ """ + doc_node = None + if node is None: + pass + elif isinstance(node, Nodes.ExprStatNode): + if node.expr.is_string_literal: + doc_node = node.expr + node = Nodes.StatListNode(node.pos, stats=[]) + elif isinstance(node, Nodes.StatListNode) and node.stats: + stats = node.stats + if isinstance(stats[0], Nodes.ExprStatNode): + if stats[0].expr.is_string_literal: + doc_node = stats[0].expr + del stats[0] + + if doc_node is None: + doc = None + elif isinstance(doc_node, ExprNodes.BytesNode): + warning(node.pos, + "Python 3 requires docstrings to be unicode strings") + doc = doc_node.value + elif isinstance(doc_node, ExprNodes.StringNode): + doc = doc_node.unicode_value + if doc is None: + doc = doc_node.value + else: + doc = doc_node.value + return doc, node + + +def p_code(s, level=None, ctx=Ctx): + body = p_statement_list(s, ctx(level = level), first_statement = 1) + if s.sy != 'EOF': + s.error("Syntax error in statement [%s,%s]" % ( + repr(s.sy), repr(s.systring))) + return body + + +_match_compiler_directive_comment = cython.declare(object, re.compile( + r"^#\s*cython\s*:\s*((\w|[.])+\s*=.*)$").match) + + +def p_compiler_directive_comments(s): + result = {} + while s.sy == 'commentline': + pos = s.position() + m = _match_compiler_directive_comment(s.systring) + if m: + directives_string = m.group(1).strip() + try: + new_directives = Options.parse_directive_list(directives_string, ignore_unknown=True) + except ValueError as e: + s.error(e.args[0], fatal=False) + s.next() + continue + + for name in new_directives: + if name not in result: + pass + elif Options.directive_types.get(name) is list: + result[name] += new_directives[name] + new_directives[name] = result[name] + elif new_directives[name] == result[name]: + warning(pos, "Duplicate directive found: %s" % (name,)) + else: + s.error("Conflicting settings found for top-level directive %s: %r and %r" % ( + name, result[name], new_directives[name]), pos=pos) + + if 'language_level' in new_directives: + # Make 
sure we apply the language level already to the first token that follows the comments. + s.context.set_language_level(new_directives['language_level']) + if 'legacy_implicit_noexcept' in new_directives: + s.context.legacy_implicit_noexcept = new_directives['legacy_implicit_noexcept'] + + + result.update(new_directives) + + s.next() + return result + + +def p_module(s, pxd, full_module_name, ctx=Ctx): + pos = s.position() + + directive_comments = p_compiler_directive_comments(s) + s.parse_comments = False + + if s.context.language_level is None: + s.context.set_language_level('3str') + if pos[0].filename: + import warnings + warnings.warn( + "Cython directive 'language_level' not set, using '3str' for now (Py3). " + "This has changed from earlier releases! File: %s" % pos[0].filename, + FutureWarning, + stacklevel=1 if cython.compiled else 2, + ) + + level = 'module_pxd' if pxd else 'module' + doc = p_doc_string(s) + body = p_statement_list(s, ctx(level=level), first_statement = 1) + if s.sy != 'EOF': + s.error("Syntax error in statement [%s,%s]" % ( + repr(s.sy), repr(s.systring))) + return ModuleNode(pos, doc = doc, body = body, + full_module_name = full_module_name, + directive_comments = directive_comments) + +def p_template_definition(s): + name = p_ident(s) + if s.sy == '=': + s.expect('=') + s.expect('*') + required = False + else: + required = True + return name, required + +def p_cpp_class_definition(s, pos, ctx): + # s.sy == 'cppclass' + s.next() + class_name = p_ident(s) + cname = p_opt_cname(s) + if cname is None and ctx.namespace is not None: + cname = ctx.namespace + "::" + class_name + if s.sy == '.': + error(pos, "Qualified class name not allowed C++ class") + if s.sy == '[': + s.next() + templates = [p_template_definition(s)] + while s.sy == ',': + s.next() + templates.append(p_template_definition(s)) + s.expect(']') + template_names = [name for name, required in templates] + else: + templates = None + template_names = None + if s.sy == '(': + 
s.next() + base_classes = [p_c_base_type(s, templates = template_names)] + while s.sy == ',': + s.next() + base_classes.append(p_c_base_type(s, templates = template_names)) + s.expect(')') + else: + base_classes = [] + if s.sy == '[': + error(s.position(), "Name options not allowed for C++ class") + nogil = p_nogil(s) + if s.sy == ':': + s.next() + s.expect('NEWLINE') + s.expect_indent() + # Allow a cppclass to have docstrings. It will be discarded as comment. + # The goal of this is consistency: we can make docstrings inside cppclass methods, + # so why not on the cppclass itself ? + p_doc_string(s) + attributes = [] + body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil) + body_ctx.templates = template_names + while s.sy != 'DEDENT': + if s.sy != 'pass': + attributes.append(p_cpp_class_attribute(s, body_ctx)) + else: + s.next() + s.expect_newline("Expected a newline") + s.expect_dedent() + else: + attributes = None + s.expect_newline("Syntax error in C++ class definition") + return Nodes.CppClassNode(pos, + name = class_name, + cname = cname, + base_classes = base_classes, + visibility = ctx.visibility, + in_pxd = ctx.level == 'module_pxd', + attributes = attributes, + templates = templates) + +def p_cpp_class_attribute(s, ctx): + decorators = None + if s.sy == '@': + decorators = p_decorators(s) + if s.systring == 'cppclass': + return p_cpp_class_definition(s, s.position(), ctx) + elif s.systring == 'ctypedef': + return p_ctypedef_statement(s, ctx) + elif s.sy == 'IDENT' and s.systring in struct_enum_union: + if s.systring != 'enum': + return p_cpp_class_definition(s, s.position(), ctx) + else: + return p_struct_enum(s, s.position(), ctx) + else: + node = p_c_func_or_var_declaration(s, s.position(), ctx) + if decorators is not None: + tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode + if ctx.allow_struct_enum_decorator: + tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode + if not isinstance(node, tup): + 
s.error("Decorators can only be followed by functions or classes") + node.decorators = decorators + return node + + +#---------------------------------------------- +# +# Debugging +# +#---------------------------------------------- + +def print_parse_tree(f, node, level, key = None): + ind = " " * level + if node: + f.write(ind) + if key: + f.write("%s: " % key) + t = type(node) + if t is tuple: + f.write("(%s @ %s\n" % (node[0], node[1])) + for i in range(2, len(node)): + print_parse_tree(f, node[i], level+1) + f.write("%s)\n" % ind) + return + elif isinstance(node, Nodes.Node): + try: + tag = node.tag + except AttributeError: + tag = node.__class__.__name__ + f.write("%s @ %s\n" % (tag, node.pos)) + for name, value in node.__dict__.items(): + if name != 'tag' and name != 'pos': + print_parse_tree(f, value, level+1, name) + return + elif t is list: + f.write("[\n") + for i in range(len(node)): + print_parse_tree(f, node[i], level+1) + f.write("%s]\n" % ind) + return + f.write("%s%s\n" % (ind, node)) + +def p_annotation(s): + """An annotation just has the "test" syntax, but also stores the string it came from + + Note that the string is *allowed* to be changed/processed (although isn't here) + so may not exactly match the string generated by Python, and if it doesn't + then it is not a bug. 
+ """ + pos = s.position() + expr = p_test(s) + return ExprNodes.AnnotationNode(pos, expr=expr) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa81adaff68e06d8e915a6afa375f62f7e5a8fad --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py @@ -0,0 +1 @@ +# empty file diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1c6061ad06428b05559670dd8568f60a7e7def Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75502300556ac10b07930c479c861d7e6cfbec2e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b051ac498688f4614b4ab85dc5ad14fb6a8bbdb Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83611fe420a21a7c021b48549f14ca4f94508543 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f359751f327c8bc5fed253236a3e3ae6446fc65b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0afde4b72f14dfad777febce85eb96860ea5f0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz new file mode 100644 index 0000000000000000000000000000000000000000..2cb5d1308c1168df6218cee0a97552c5787a29b8 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73fc416df0164923607751cb759f4ae81deb5f6550bf25be59c86de3b747e41d +size 8887 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfcf7f72f31658d75c8128de0732fbbf0e12b15 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py @@ -0,0 +1,23 @@ +"""Wrappers to call pyproject.toml-based build backend hooks. +""" + +from ._impl import ( + BackendInvalid, + BackendUnavailable, + BuildBackendHookCaller, + HookMissing, + UnsupportedOperation, + default_subprocess_runner, + quiet_subprocess_runner, +) + +__version__ = '1.0.0' +__all__ = [ + 'BackendUnavailable', + 'BackendInvalid', + 'HookMissing', + 'UnsupportedOperation', + 'default_subprocess_runner', + 'quiet_subprocess_runner', + 'BuildBackendHookCaller', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edc7327ae20f1ccfe93039245705e0d9899f5f2e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py new file mode 100644 
index 0000000000000000000000000000000000000000..917fa065b3c7feccdef5bc666a5109c855217260 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py @@ -0,0 +1,18 @@ +"""This is a subpackage because the directory is on sys.path for _in_process.py + +The subpackage should stay as empty as possible to avoid shadowing modules that +the backend might import. +""" + +import importlib.resources as resources + +try: + resources.files +except AttributeError: + # Python 3.8 compatibility + def _in_proc_script_path(): + return resources.path(__package__, '_in_process.py') +else: + def _in_proc_script_path(): + return resources.as_file( + resources.files(__package__).joinpath('_in_process.py')) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..296dcd39ebf538f45820b7e8c3beea7414e8ba6f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca039512128c33af66fda9d817a9504a0787129 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47aceda4086d7b3f9e5b4b26043a6f7083def958 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0897a5819d57b13541ae09e69dc123e6797579f1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c2803d3b6a5a83f580f6472ed88b405ef4b49a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b294d8ce6d38dac4b8877cfb03567e38b42c4e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52abc3c12b37f3d84823e8f87d6f11cea7bbab58 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5768c8f281fc385d46bfd3b3c65b70b1ed2a3777 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f20caea52b38f223369eae83bf2a74b12596af Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a664ab5ee4cb9218478ed6cb69eeb0500421955 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..063067e75baabe0f4f925a91fc3d656a668604ff Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aae42232d40df67d38cba73541daf15b078d4ad Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c684fc1915cb90e8e5458d1d8d3fad2b5ee13af --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py @@ -0,0 +1,19 @@ +from .lazy_ir import ( + generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, + GenLazyIR as GenLazyIR, + GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition, + GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition, +) +from .native_functions import ( + compute_native_function_declaration as compute_native_function_declaration, +) +from .register_dispatch_key import ( + 
gen_registration_headers as gen_registration_headers, + gen_registration_helpers as gen_registration_helpers, + RegisterDispatchKey as RegisterDispatchKey, +) +from .ufunc import ( + compute_ufunc_cpu as compute_ufunc_cpu, + compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, + compute_ufunc_cuda as compute_ufunc_cuda, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c68161696fc49ff3bcb5f63708345ddb74a97f9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6059400437235c1390e11d89e2344d509a94c764 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..43cde1e04043afdd843b21a4113b358afb6db691 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py @@ -0,0 +1,707 @@ +import itertools +from abc import ABC +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.lazy import ( + 
getValueT, + isValueType, + LazyArgument, + LazyIrProperties, + LazyIrSchema, + tensorListValueT, +) +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + deviceT, + DispatcherSignature, + kernel_signature, + NativeSignature, + OptionalCType, + VectorCType, +) +from torchgen.context import method_with_native_function +from torchgen.dest.lazy_ts_lowering import ts_lowering_body +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsGroup, +) + + +def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: + """ + Given a LazyArgument, + generate a c++ string for materializing an rvalue of that arg for passing into + a lazy Node constructor. + """ + + # TODO: Matching on CType seems wrong; should be matching on Type + if isValueType(arg.lazy_type): + if isinstance(arg.lazy_type, BaseCType): + if arg.is_wrapped_scalar: + return f"node_{arg.name}" + elif arg.lazy_type.type is tensorListValueT: + return f"lazy_{arg.name}_tensorlist" + elif arg.is_symint_or_list: + return f"GetSymIntValue({arg.name})" + return f"lazy_{arg.name}->GetIrValue()" + elif isinstance(arg.lazy_type, OptionalCType): + if arg.is_symint_or_list: + # TODO: I don't understand when you should put lazy_ in the name + # or not + return f"{arg.name} ? c10::make_optional(GetSymIntValue(*{arg.name})) : c10::nullopt" + elif arg.is_wrapped_scalar: + return f"node_{arg.name}" + return ( + f"lazy_{arg.name} ? 
" + f"c10::make_optional(lazy_{arg.name}->GetIrValue()) : " + "c10::nullopt" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + else: + # NB: this is here because right now we aren't treating SymInt[] as a + # value type; when we do this needs to move above + # NB: we cannot test arg.lazy_type as we've already specified it is an + # int64_t and so we cannot distinguish between SymInt and int64_t + if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType( + BaseTy.SymInt + ): + if arg.symint: + return f"GetSymIntArrayRefValue({arg.name})" + else: + return f"std::vector({arg.name}.begin(), {arg.name}.end())" + elif isinstance(arg.lazy_type, VectorCType) and isinstance( + arg.lazy_type.elem, BaseCType + ): + return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())" + elif ( + isinstance(arg.lazy_type, OptionalCType) + and isinstance(arg.lazy_type.elem, VectorCType) + and isinstance(arg.lazy_type.elem.elem, BaseCType) + ): + return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})" + else: + return f"{arg.name}" + + +def node_ctor_inputs(schema: LazyIrSchema) -> str: + """ + Produce a formatted string with the arguments as passed into the constructor of a node class. 
+ """ + node_ctor_values = [ + node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args() + ] + return ", ".join(node_ctor_values) + + +def gen_fallback_code( + schema: LazyIrSchema, + sig: Union[DispatcherSignature, NativeSignature], + overload_name: str, +) -> str: + """ + Generate code that falls back to eager conditioned on a predicate + """ + dispatcher_sig = DispatcherSignature.from_schema(schema.func) + exprs = translate(sig.arguments(), dispatcher_sig.arguments()) + fallback_args = ",\n ".join([a.expr for a in exprs]) + if len(overload_name): + aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})" + else: + aten_op_str = f"ATEN_OP({schema.aten_name})" + return f""" + if (force_eager_fallback({aten_symbol(schema)})) {{ + return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call( + {fallback_args} + ); + }} +""" + + +def aten_symbol(schema: LazyIrSchema) -> str: + missing_interned_strings = { + "sigmoid_backward", + } + if schema.aten_name in missing_interned_strings: + return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")' + + if not schema.aten_name.startswith("at::"): + return f"at::aten::{schema.aten_name}" + else: + return schema.aten_name + + +# converts all tensor-like arguments to meta tensors. Returns: +# (1) a string containing all of the logic that does the conversions. +# (2) a context, to be used by translate(), with all of the relevant bindings. 
+def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: + context: List[Binding] = [] + unwrapped_tensor_args: List[str] = [] + for arg in sig.arguments(): + if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): + unwrapped_name = f"{arg.name}_meta" + unwrapped_tensor_args.append( + f"auto {unwrapped_name} = to_meta({arg.name});" + ) + context.append(arg.with_name(unwrapped_name)) + else: + context.append(arg) + unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args) + return unwrap_tensor_args_str, context + + +@dataclass(frozen=True) +class GenLazyIR(ABC): + backend_index: BackendIndex + backend_name: str + node_base: str + use_lazy_shape: bool + + @method_with_native_function + def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: + func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func + metadata = self.backend_index.get_kernel( + f.functional if isinstance(f, NativeFunctionsGroup) else f + ) + schema = LazyIrSchema( + func, symint=metadata is not None and metadata.supports_symint() + ) + return self.gen(schema) + + # there is no lowering functionality generated unless this IR base class is subclassed and + # implemented as a backend-specific node + def lowering_function(self, schema: LazyIrSchema) -> str: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return "" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return f"""bool CanBeReused({node_ctor_args}) const {{ + return false; + }}""" + + def node_base_ctor_call(self, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # backends can customize the way the node base class constructor is called, + # as long as all of its arguments can be generated from information available from the schema + base_ctor_value_args_list = [] + for arg in value_args: + if isinstance(arg.lazy_type, 
(BaseCType, VectorCType)): + base_ctor_value_args_list.append(f"{arg.name}") + elif isinstance(arg.lazy_type, OptionalCType): + base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)") + else: + raise AssertionError( + f"Unsupported type ({arg.lazy_type}) - add support if necessary" + ) + base_ctor_value_args = ", ".join(base_ctor_value_args_list) + + scalar_args = schema.filtered_args(values=False, scalars=True) + + # Shape construction. + # Conditionally build shape depending on specified shape property + if schema.properties.ShapePrecompute: + shape_ctor_arg = "std::move(shapes)," + elif schema.properties.ShapeCompute: + shape_args = [a.name for a in value_args] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)})," + elif schema.properties.ShapeCache: + shape_args = [f"operand({i})" for i in range(len(value_args))] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }}," + else: + shape_ctor_arg = "" + + scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args) + + return f"""{self.node_base}( + {schema.node_name}::ClassOpKind(), + OpList{{{base_ctor_value_args}}}, + {shape_ctor_arg} + /* num_outputs */ {len(schema.returns)}, + torch::lazy::MHash({scalar_hashes}))""" + + def gen(self, schema: LazyIrSchema) -> List[str]: + opkind = schema.opkind or aten_symbol(schema) + + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. 
+ all_args = schema.filtered_args() + value_args = schema.filtered_args(values=True, scalars=False) + scalar_args = schema.filtered_args(values=False, scalars=True) + + ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args] + reuse_ctor_args = ", ".join(ctor_args) + if self.use_lazy_shape and schema.properties.ShapePrecompute: + ctor_args.append("std::vector&& shapes") + node_ctor_args = ", ".join(ctor_args) + + scalar_initializers = ",\n ".join( + [ + # This code is just special casing the mapping from string_view -> strings + f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)" + if a.lazy_type.cpp_type() == "c10::optional" + else f"{a.name}({a.name})" + for a in scalar_args + ] + ) + if len(scalar_initializers): + scalar_initializers = f",\n {scalar_initializers}" + scalar_decls = "\n ".join( + [ + f"std::string {a.name};" + if a.lazy_type.cpp_type() == "c10::string_view" + else f"c10::optional {a.name};" + if a.lazy_type.cpp_type() == "c10::optional" + else f"{a.lazy_type.cpp_type()} {a.name};" + for a in scalar_args + ] + ) + optional_values = [ + arg.name + for arg in schema.filtered_args(values=True, scalars=False) + if isinstance(arg.lazy_type, OptionalCType) + ] + has_optional_decls = "\n ".join( + [f"bool has_{value}: 1;" for value in optional_values] + ) + has_optional_defs = "\n ".join( + [f"has_{value} = !!{value};" for value in optional_values] + ) + members_to_string = [] + for arg in scalar_args: + if isinstance(arg.lazy_type, OptionalCType): + value = f"{arg.name}.value()" + if arg.is_generator: + value = '"torch.Generator()"' + members_to_string.append( + f"""if ({arg.name}.has_value()) {{ + ss << ", {arg.name}=" << {value}; + }} else {{ + ss << ", {arg.name}=null"; + }}""" + ) + else: + members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};') + members_to_string_str = "\n ".join(members_to_string) + + return [ + f"""\ +class {schema.node_name} : public {self.node_base} {{ + 
public: + static torch::lazy::OpKind ClassOpKind() {{ + return torch::lazy::OpKind({opkind}); + }} + + {schema.node_name}({node_ctor_args}) + : {self.node_base_ctor_call(schema)}{scalar_initializers} + {{ + {has_optional_defs} + }} + + std::string ToString() const override {{ + std::stringstream ss; + ss << {self.node_base}::ToString(); + {members_to_string_str} + return ss.str(); + }} + + {self.create_function(schema, reuse_ctor_args)} + + {self.can_be_reused_function(schema, reuse_ctor_args)} + + {self.lowering_function(schema)} + + {scalar_decls} + {has_optional_decls} + +}}; + +""", + ] + + +@dataclass(frozen=True) +class GenTSLazyIR(GenLazyIR): + def lowering_function(self, schema: LazyIrSchema) -> str: + signature = """ + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override""" + + if schema.properties.LowerDeclOnly: + return f"{signature};" + elif schema.properties.Lower: + return f"""{signature} {{ + {ts_lowering_body(schema)} + }} + """ + else: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"static NodePtr Create({node_ctor_args})" + if schema.properties.CreateFnDeclOnly: + return f"{signature};" + elif not schema.properties.CreateFn: + return "" + return f"""{signature} {{ + return ReuseOrMakeNode<{schema.node_name}>(data); + }}""" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"bool CanBeReused({node_ctor_args}) const" + if schema.properties.CanBeReusedDeclOnly: + return f"{signature};" + elif not schema.properties.CanBeReused: + return "" + value_comparison = [] + for arg in itertools.chain(schema.positional_values, schema.keyword_values): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)" + ) + else: + value_comparison.append(f"operand(i++) == {arg.name}") + for arg in 
itertools.chain(schema.positional_scalars, schema.keyword_scalars): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))" + ) + else: + value_comparison.append(f"this->{arg.name} == {arg.name}") + value_comparison_str = " &&\n ".join(value_comparison) + + return f"""{signature} {{ + size_t i = 0; + return ({value_comparison_str}); + }}""" + + +@dataclass(frozen=True) +class GenLazyNativeFuncDefinition: + class_method_name: str + backend_index: BackendIndex + tensor_class: str + gen_forced_fallback_code: bool + backend_namespace: str + get_tensorlist: str + get_tensor_or_wrap_number: str + try_get_tensor: str + metrics_counter: str + create_tensor: str + create_from_first_tensor: bool + create_aten_from_ltc_tensor: str + tuple_aten_from_ltc_tensors: str + lazy_tensor_ptr: str + get_device_fn: str + + def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # Generates lazy_{name} variables for LazyTensors wrapping input tensors + lazy_tensor_decls: List[str] = [] + for arg in value_args: + if arg.is_wrapped_scalar: + if isinstance(arg.lazy_type, OptionalCType): + lazy_tensor_decls.append( + f"""auto node_{arg.name} = {arg.name} ? 
+ c10::make_optional(torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)): + c10::nullopt;""" + ) + else: + lazy_tensor_decls.append( + f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen({arg.name}, *common_device);""" + ) + elif arg.is_symint_or_list: + continue # values are extracted in isValueType + elif isinstance(arg.lazy_type, BaseCType): + if arg.lazy_type.type is tensorListValueT: + lazy_tensor_decls.append( + f"auto lazy_{arg.name}_tensorlist = " + f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});" + ) + else: + lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);" + ) + elif isinstance(arg.lazy_type, OptionalCType): + assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem + # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it + # until we encounter a real world example. 
+ lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + return ("\n ").join(lazy_tensor_decls) + + def force_eager_fallback( + self, + func: NativeFunction, + schema: LazyIrSchema, + metadata: BackendMetadata, + sig: Union[DispatcherSignature, NativeSignature], + ) -> str: + if self.gen_forced_fallback_code: + return gen_fallback_code( + schema, sig, overload_name=func.func.name.overload_name + ) + return "" + + def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str: + return f"{self.metrics_counter};" + + def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + scalar_args = schema.filtered_args(values=False, scalars=True) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + optional_device = OptionalCType(BaseCType(deviceT)) + optional_devices = [ + a.name for a in scalar_args if a.lazy_type == optional_device + ] + assert ( + len(value_types_names) > 0 or len(optional_devices) > 0 + ), "Expected at least one Value or Device type" + get_device_str = ( + f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})" + ) + return f"""auto common_device = {get_device_str}; + TORCH_INTERNAL_ASSERT(common_device); + """ + + def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str: + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + all_args = schema.filtered_args() + returns_length = len(schema.returns) + # call the meta kernel if it exists, to compute output shape/dtype for our IR + # Note [Generated LTC Shape Functions] + # LTC uses meta tensors from core to do shape inference when possible, and otherwise + # we generate a shape function declaration 
that needs to be manually implemented. + # How do we detect which ops are eligible to use meta tensors? + # In general we should be able to use meta tensors not just on structured operators, + # but also on composite operators that are implemented in terms of structured kernels. + # We don't currently have a way of knowing at codegen time which ops are implemented that way. + # This is the case for all view and view_copy operators however, so we're going to + # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them). + is_view_copy_op = "view_copy" in func.tags + is_structured = func.structured or func.structured_delegate is not None + if is_structured or is_view_copy_op: + meta_out = """ +std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};""" + if returns_length > 1: + + def this_shape(i: int) -> str: + return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())" + + shapes_str = ",".join([this_shape(i) for i in range(returns_length)]) + meta_out = "std::vector shapes{" + shapes_str + "};" + + # Convert tensor args to the meta device and call it. + # (We can't pass in the input tensors directly, because they are "functional wrappers". + # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.) + # Even at::meta:: functions might redispatch, e.g. if they call into view ops. 
+ dispatcher_sig = DispatcherSignature.from_schema(func.func) + meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) + meta_call_args = [ + e.expr + for e in translate( + meta_call_ctx, dispatcher_sig.arguments(), method=False + ) + ] + if is_view_copy_op: + # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel + assert func.has_composite_explicit_autograd_non_functional_kernel + dispatch_ns = "compositeexplicitautogradnonfunctional" + else: + dispatch_ns = "meta" + aten_name = schema.aten_name + # TODO: this is trolling + if func.func.has_symint() and metadata.supports_symint(): + aten_name += "_symint" + shape_str = f"""\ + {meta_conversion_str} + auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)}); + {meta_out}""" + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, func, symint=metadata.supports_symint() + ) + shape_str = f""" + auto shapes = {shape_sig.shape_call};""" + + shape_str += f""" + TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});""" + + # Calculating which dimensions are symbolic + func_schema_str = "aten::" + str(func.func) + shape_str += f""" + if(torch::lazy::symbolicShapeEnabled()){{ + std::vector inputs = {{ {', '.join(str(a.name) for a in all_args)} }}; + const char* schema_str = "{func_schema_str}"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + }} + """ + return shape_str + + def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str: + node_ctor_input_str = node_ctor_inputs(schema) + return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str}); + if (!node) {{ + {self.shape_inference(func, schema)} + node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes)); + CacheNode(node); + }} + """ + + def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str: + # xla uses an instance method for tensor creation, for the time being + if 
self.create_from_first_tensor: + # TODO(whc) remove this if XLA switches to using static method for creation + assert ( + first_tensor_name is not None + ), "Requires first tensor to create lazy tensor" + return f"{first_tensor_name}.{self.create_tensor}" + return f"{self.backend_namespace}::{self.create_tensor}" + + def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str: + returns_length = len(schema.returns) + value_args = schema.filtered_args(values=True, scalars=False) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None + bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}( + {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));""" + + if returns_length > 1: + assert ( + len(value_types_names) > 0 + ), "Code below assumes there is at least one tensor arg" + bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors; + for (int i = 0; i < {returns_length}; i++) {{ + lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device)); + }} + auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);""" + + if schema.name.name.inplace or func.func.is_out_fn(): + assert returns_length == 1, ( + "We assumed there was no such case where an op is an in-place variant " + f"and has tuple outputs, but got tuple of len {returns_length}." 
+ ) + bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node); + auto& result = {first_tensor_name};""" + + bridge_str += """ + return result;""" + return bridge_str + + @method_with_native_function + def __call__(self, func: NativeFunction) -> List[str]: + sig = kernel_signature(func, self.backend_index) + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + schema = LazyIrSchema(func.func, symint=metadata.supports_symint()) + return [ + f"""\ + {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{ + {self.force_eager_fallback(func, schema, metadata, sig)} + {self.metrics(func, schema)} + {self.get_device(func, schema)} + {self.lazy_tensor_decls(func, schema)} + {self.build_ir_node(func, schema)} + {self.return_aten_tensor(func, schema)} + }}\n + """ + ] + + +class ComputeShapeSignature: + """ + Here we use the base name as the suffix of the signature to avoid generating for in-place variants. + """ + + def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool): + self.__schema = LazyIrSchema(f.func, symint=symint) + self.__dispatch_args = ", ".join( + [a.decl() for a in dispatcher.arguments(f.func, symint=symint)] + ) + self.__call_args = ", ".join( + [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)] + ) + self.__kernel_name = kernel_name + + def __decl_suffix(self) -> str: + return f"{self.__kernel_name}({self.__dispatch_args})" + + def __call_suffix(self) -> str: + return f"{self.__kernel_name}({self.__call_args})" + + @property + def shape_decl(self) -> str: + return f"TORCH_API std::vector compute_shape_{self.__decl_suffix()}" + + @property + def shape_call(self) -> str: + return f"torch::lazy::compute_shape_{self.__call_suffix()}" + + +@dataclass(frozen=True) +class GenLazyShapeInferenceDefinition: + backend_index: BackendIndex + tensor_class: str + + @method_with_native_function + def __call__(self, f: NativeFunction) -> List[str]: + sig = kernel_signature(f, 
self.backend_index) + metadata = self.backend_index.get_kernel(f) + assert metadata is not None + + # See Note [Generated LTC Shape Functions] + is_view_copy_op = "view_copy" in f.tags + is_structured = f.structured or f.structured_delegate is not None + if is_structured or is_view_copy_op: + return [] + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, f, symint=metadata.supports_symint() + ) + return ["\n".join([f"{shape_sig.shape_decl};"])] + + +def generate_non_native_lazy_ir_nodes( + non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR +) -> List[str]: + """Generate the non-native lazy IR node classes""" + nodes = [] + for op in non_native: + # Set default properties for Non-Native IRs + properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly") + for p in op.get("properties", []): + setattr(properties, p, True) + + # non-native is assumed to want symint bindings if you wrote symint + schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True) + schema.opkind = op.get("opkind") + nodes.append(gen_lazy_ir.gen(schema)[0]) + + return nodes diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..70161216d8e7c95e194b0d89b345e0da886ef989 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py @@ -0,0 +1,48 @@ +from torchgen.api.lazy import LazyArgument, LazyIrSchema +from torchgen.api.types import OptionalCType + + +def ts_lowering_body(schema: LazyIrSchema) -> str: + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. 
+ emplace_arguments = [] + + def get_value(arg: LazyArgument) -> str: + if isinstance(arg.lazy_type, OptionalCType): + return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr" + return "loctx->GetOutputOp(operand(i++))" + + for arg in schema.positional_args: + if arg.is_lazy_value: + emplace_arguments.append(get_value(arg)) + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments] + ) + emplace_kwarg_values = [ + f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values + ] + emplace_kwarg_scalars = [ + f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars + ] + emplace_kwarguments = "\n ".join( + [ + f"kwarguments.emplace_back({a});" + for a in emplace_kwarg_values + emplace_kwarg_scalars + ] + ) + return f"""\ + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + + return {schema.aten_name}_out; +""" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..57a9217550d9c9afbbe7f1ab544771381b1359eb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Union + +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.types import kernel_signature + +from torchgen.context import 
with_native_function_and_index +from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.utils import mapMaybe + + +@with_native_function_and_index +def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]: + sig = kernel_signature(f, backend_index) + metadata = backend_index.get_kernel(f) + if metadata is None: + return None + if "legacy::" in metadata.kernel: + return None + else: + prefix = "static" if backend_index.external else "TORCH_API" + return f"{prefix} {sig.decl(name=metadata.kernel)};" + + +@with_native_function_and_index +def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]: + meta_name = meta.name(g) + out_args = structured.impl_arguments(g) + metadata = backend_index.get_kernel(g) + if metadata is None: + return [] + prefix = "" if backend_index.external else "TORCH_API " + return [ + f"""\ +struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{ +void impl({', '.join(a.decl() for a in out_args)}); +}}; +""" + ] + + +# Generates NativeFunctions.h, a list of forward declarations of all +# actual kernel definitions we keep in aten/src/ATen/native/ +@with_native_function_and_index +def compute_native_function_declaration( + g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex +) -> List[str]: + metadata = backend_index.get_kernel(g) + if isinstance(g, NativeFunctionsGroup): + if metadata is not None and metadata.structured: + if backend_index.external: + # Structured hasn't been tested with external backends yet. + raise AssertionError( + "Structured external backend functions are not implemented yet." 
+ ) + else: + return gen_structured(g, backend_index) + else: + return list( + mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions()) + ) + else: + x = gen_unstructured(g, backend_index) + return [] if x is None else [x] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py new file mode 100644 index 0000000000000000000000000000000000000000..114b641c5b4dbf52b24b57ffd093d63546a4bbd3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py @@ -0,0 +1,989 @@ +import itertools +import textwrap +from dataclasses import dataclass +from typing import List, Literal, Optional, Tuple, Union + +import torchgen.api.cpp as cpp +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + ConstRefCType, + CppSignature, + CppSignatureGroup, + DispatcherSignature, + Expr, + kernel_signature, + MutRefCType, + NamedCType, + NativeSignature, + tensorT, +) + +from torchgen.context import method_with_native_function, native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + DeviceCheckType, + DispatchKey, + gets_generated_out_inplace_wrapper, + is_cuda_dispatch_key, + NativeFunction, + NativeFunctionsGroup, + SchemaKind, + TensorOptionsArguments, +) +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import assert_never, mapMaybe, Target + + +def gen_registration_headers( + backend_index: BackendIndex, + per_operator_headers: bool, + rocm: bool, +) -> List[str]: + if per_operator_headers: + headers = ["#include "] + else: + headers = ["#include "] + + if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta): + headers.append("#include ") + elif 
backend_index.dispatch_key == DispatchKey.CUDA: + if rocm: + headers.append("#include ") + else: + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.MPS: + headers.append("#include ") + elif per_operator_headers: + headers += [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + else: + headers.append("#include ") + + return headers + + +def gen_empty_impl_names( + backend_index: BackendIndex, +) -> Tuple[Optional[str], Optional[str]]: + empty_impl = None + empty_strided_impl = None + + if backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + ): + dispatch = str(backend_index.dispatch_key).lower() + empty_impl = f"at::detail::empty_{dispatch}" + empty_strided_impl = f"at::detail::empty_strided_{dispatch}" + elif backend_index.dispatch_key in ( + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.QuantizedCPU, + DispatchKey.QuantizedCUDA, + ): + empty_impl = "at::empty" + empty_strided_impl = "at::empty_strided" + + return empty_impl, empty_strided_impl + + +def gen_create_out_helper(backend_index: BackendIndex) -> List[str]: + if backend_index.dispatch_key == DispatchKey.Meta: + empty_options = "options.device(at::kMeta)" + else: + empty_options = "options" + + empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index) + if empty_impl is None: + return [] + + return [ + f""" +Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ + if (strides.empty()) {{ + return {empty_impl}(sizes, {empty_options}); + }} else {{ + return {empty_strided_impl}(sizes, strides, {empty_options}); + }} +}} +""" + ] + + +def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]: + _, empty_strided_impl = gen_empty_impl_names(backend_index) + return ( + [] + if empty_strided_impl is None + else [ + f""" +c10::optional maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const 
TensorOptions &options) {{ + if (out.strides() != strides) {{ + return {empty_strided_impl}(sizes, strides, options); + }} + return c10::nullopt; +}} +""" + ] + ) + + +def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]: + if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: + # The function isn't used by this key (since only functional ops have a kernel for this key), + # so we need to not include it to avoid a defined-but-not-used error. + return [] + return [ + """ +void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { + TORCH_CHECK(options.dtype() == out.dtype(), + "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead"); + TORCH_CHECK(options.device() == out.device(), + "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead"); + const bool resized = at::native::resize_output(out, sizes); + // Only restride if a resize occurred; otherwise we ignore the (advisory) + // strides from the meta function and directly use the output tensor's + // preexisting strides + if (resized) { + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + // TODO: avoid the redispatch here + out.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } +} +""" + ] + + +def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]: + return [ + """ +void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { + // These checks are needed on those operators that: + // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm') + // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod') + // For other operators (e.g. 'add'), 'TensorIterator' already checks + // these things separately. 
+ TORCH_CHECK(options.dtype() == self.dtype(), + "Bad in-place call: ", + "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match"); + TORCH_CHECK(options.device() == self.device(), + "Bad in-place call: ", + "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match"); + TORCH_CHECK(sizes == self.sizes(), + "Bad in-place call: ", + "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match"); +} +""" + ] + + +def gen_registration_helpers(backend_index: BackendIndex) -> List[str]: + return [ + *gen_create_out_helper(backend_index), + *gen_resize_out_helper(backend_index), + *gen_check_inplace_helper(backend_index), + *gen_maybe_create_proxy_helper(backend_index), + ] + + +# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). +# +# - The primary function of this file is to register all of the +# implementations for the given dispatch key to the dispatcher, +# so they are available for use in PyTorch. If dispatch is +# None, we generate schema (def) registrations and catchall +# registrations. +# - The secondary function of this file is to generate a wrapper +# around functions. In CPUType these wrappers do nothing +# (and should be removed), but in other cases they handle +# DeviceGuard. A small extra benefit of wrappers is they +# are not overloaded, so they can be used in the registration +# API without having to disambiguate which overload you want +# (as would be the case if you directly registered native:: +# functions). 
+# - The tertiary function of this file is to generate *static* +# cpp API bindings which can be used to bypass dispatcher +# directly to kernels, but with user-friendly cpp-style API +@dataclass(frozen=True) +class RegisterDispatchKey: + backend_index: BackendIndex + + target: Literal[ + Target.ANONYMOUS_DEFINITION, + Target.NAMESPACED_DEFINITION, + Target.NAMESPACED_DECLARATION, + Target.REGISTRATION, + ] + + # Selector object to determine which operators to generate + # registration code for. + selector: SelectiveBuilder + + # Whether or not we are actually code-genning for ROCm + rocm: bool + + # Whether or not to generate symint registrations or not. External users + # of codegen who don't care about symints can set this to false to get + # non-SymInt codegen + symint: bool + + # The class that all unstructured native functions live under. This is used to improve + # compiler error messages when a kernel writer adds a native function with the wrong signature. + # This is only used in unstructured kernels, since structured kernels already live in a class. + # Finally, this field is currently Optional because it is only used by external backends. + # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating + # all of the existing kernel signatures scattered across aten/src/ATen/native. + class_method_name: Optional[str] + + # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering + # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher. 
+ skip_dispatcher_op_registration: bool + + @staticmethod + def gen_device_check( + type: DeviceCheckType, args: List[Argument], method_name: str + ) -> str: + if type == DeviceCheckType.NoCheck: + return " // No device check\n" + + device_check = "c10::optional common_device = nullopt;\n" + device_check += "(void)common_device; // Suppress unused variable warning\n" + for arg in args: + # Only tensor like arguments are eligible + if arg.type.is_tensor_like(): + device_check += f""" + c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");""" + return device_check + + @method_with_native_function + def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: + if isinstance(f, NativeFunctionsGroup): + g: NativeFunctionsGroup = f + # Note: We call gen_structured() if the operator is marked structured, regardless of the backend. + # gen_structured() has special logic to handle auto-generated kernels. + if g.structured: + return self.gen_structured(g) + else: + return list( + mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()) + ) + elif isinstance(f, NativeFunction): + r = self.gen_unstructured(f) + return [] if r is None else [r] + else: + assert_never(f) + + def wrapper_kernel_sig( + self, f: NativeFunction + ) -> Union[NativeSignature, DispatcherSignature]: + # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. 
+ return DispatcherSignature.from_schema( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_", + symint=self.symint, + ) + + def gen_out_inplace_wrapper( + self, f: NativeFunction, g: Optional[NativeFunctionsGroup] + ) -> Optional[str]: + if g is None: + return None + k = f.func.kind() + if k is SchemaKind.inplace: + copy_op = "at::_copy_from" + elif k is SchemaKind.out: + copy_op = "at::_copy_from_and_resize" + else: + raise AssertionError("gen_out_inplace_wrapper called on a functional op") + + sig = self.wrapper_kernel_sig(f) + name = sig.name() + + func_res = f"{name}_tmp" + return_names = cpp.return_names(f) + if len(return_names) > 1: + updates = "\n ".join( + f"{copy_op}(std::get<{i}>({func_res}), {ret_name});" + for i, ret_name in enumerate(return_names) + ) + returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})' + elif len(return_names) == 1: + ret_name = return_names[0] + updates = f"{copy_op}({func_res}, {ret_name});" + returns = ret_name + else: + assert len(f.func.arguments.out) == 1 + returns = "" + out_arg = f.func.arguments.out[0] + if out_arg.type.is_list_like(): + updates = f"""\ + for (int64_t i = 0; i < {func_res}.size(); ++i) {{ + {copy_op}({func_res}[i], {out_arg.name}[i]); + }}""" + else: + updates = f"{copy_op}({func_res}, {out_arg.name});" + + functional_sig = self.wrapper_kernel_sig(g.functional) + wrapper_name = sig.name() + + return f"""\ +{sig.defn(name=wrapper_name)} {{ + auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))}); + {updates} + return {returns}; +}} +""" + + def gen_structured(self, g: NativeFunctionsGroup) -> List[str]: + metadata = self.backend_index.get_kernel(g) + if self.backend_index.dispatch_key == DispatchKey.Meta: + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify Meta dispatch key on structured " + "functions, they will be automatically generated for 
you" + ) + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured " + "functions, they will be automatically generated for you" + ) + elif metadata is None or not metadata.structured: + return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())) + structured_gen = StructuredRegisterDispatchKey( + self.backend_index, + self.target, + self.selector, + self.rocm, + self.symint, + self.class_method_name, + self.skip_dispatcher_op_registration, + g, + ) + return list(mapMaybe(structured_gen.gen_one, g.functions())) + + def gen_unstructured( + self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None + ) -> Optional[str]: + with native_function_manager(f): + inplace_meta = False + gets_out_inplace_wrapper = False + if not self.backend_index.has_kernel(f): + if ( + self.backend_index.dispatch_key == DispatchKey.Meta + and f.func.kind() is SchemaKind.inplace + and + # Defer to composites for meta implementation + not f.has_composite_kernel + and + # Inplace list operations are not supported + len(f.func.returns) == 1 + ): + inplace_meta = True + elif ( + not self.backend_index.use_out_as_primary + and g is not None + and gets_generated_out_inplace_wrapper(f, g, self.backend_index) + ): + # We want to generate inplace/out wrappers, that don't have a kernel for the backend. 
+ gets_out_inplace_wrapper = True + else: + return None + if f.manual_kernel_registration: + return None + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + sig = self.wrapper_kernel_sig(f) + + name = sig.name() + returns_type = sig.returns_type().cpp_type() + args = sig.arguments() + args_str = ", ".join(a.defn() for a in args) + + # See Note [Direct dispatch bindings] + cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # TODO: dedupe this with the structured codegen + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + # short circuit for inplace_meta + if inplace_meta: + assert f.func.arguments.self_arg is not None + self_arg_name = f.func.arguments.self_arg.argument.name + # TODO: handle in place on tensor list + return f""" +{returns_type} {name}({args_str}) {{ + TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(), + "Cannot inplace into non-meta tensor with meta tensor argument"); + return {self_arg_name}; +}} +""" + + # short circuit for generated inplace/out wrappers + if gets_out_inplace_wrapper: + return self.gen_out_inplace_wrapper(f, g) + + metadata = self.backend_index.get_kernel(f) + if metadata is None: + return None + if self.class_method_name is None: + impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}" + else: + impl_name = 
f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}" + + kernel_sig = kernel_signature(f, self.backend_index) + + args_exprs_str = ", ".join( + e.expr + for e in translate( + sig.arguments(), kernel_sig.arguments(), method=False + ) + ) + + device_check = " // No device check\n" + # Backends that require device guards presumably also require device checks. + if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + device_check = RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), name + ) + + device_guard = "// DeviceGuard omitted" # default + if f.device_guard and self.backend_index.device_guard: + has_tensor_options = any( + isinstance(a, TensorOptionsArguments) + for a in f.func.arguments.non_out + ) + if has_tensor_options: + # kernel is creating a tensor + device_guard = """ + const DeviceGuard device_guard(device_or_default(device));""" + + # CUDA requires special handling + if is_cuda_dispatch_key(self.backend_index.dispatch_key): + device_guard = ( + f"globalContext().lazyInitCUDA();\n{device_guard}" + ) + else: + # kernel is operating on existing tensors + + # There is precedence for which argument we use to do + # device guard. This describes the precedence order. 
+ self_arg = ( + [f.func.arguments.self_arg.argument] + if f.func.arguments.self_arg is not None + else [] + ) + candidate_args = itertools.chain( + self_arg, + f.func.arguments.out, + f.func.arguments.flat_positional, + ) + + # Only tensor like arguments are eligible + device_of = next( + ( + f"{a.name}" + for a in candidate_args + if a.type.is_tensor_like() + ), + None, + ) + if device_of is not None: + device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" + + return f"""\ +namespace {{ + +{returns_type} {name}({args_str}) {{ + {device_check} + + {device_guard} + return {impl_name}({args_exprs_str}); +}} + +}} // anonymous namespace +""" + + elif self.target is Target.REGISTRATION: + if f.manual_kernel_registration or self.skip_dispatcher_op_registration: + return None + else: + payload = f"TORCH_FN({name})" + return f'm.impl("{f.func.name}",\n{payload});\n' + else: + assert_never(self.target) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# STRUCTURED +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StructuredRegisterDispatchKey(RegisterDispatchKey): + g: NativeFunctionsGroup + + def gen_class_set_output_functions( + self, k: SchemaKind, parent_class: str, generate_super: bool + ) -> str: + if generate_super: + set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);" + else: + set_output_super = "" + + def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: + return f""" +void set_output_{name}( + int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, + TensorOptions options, DimnameList names +) override {{ +{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")} + if (!names.empty()) {{ + namedinference::propagate_names(outputs_[output_idx], names); + }} + // super must happen after, so that downstream can use maybe_get_output + // to retrieve the output 
+{textwrap.indent(set_output_super, " ")} +}} +""" + + return f""" +{gen_set_output_function("strided", maybe_create_proxy=True)} +{gen_set_output_function("raw_strided", maybe_create_proxy=False)} +""" + + def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str: + if self.backend_index.dispatch_key in [ + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ]: + maybe_set_guard = """ +auto current_device = guard_.current_device(); +if (C10_UNLIKELY(current_device.has_value())) { + TORCH_INTERNAL_ASSERT(*current_device == options.device(), + "structured kernels don't support multi-device outputs"); +} else { + guard_.reset_device(options.device()); +} +""" + maybe_set_guard_line = maybe_set_guard + "\n" + else: + maybe_set_guard_line = maybe_set_guard = "" + + if maybe_create_proxy: + create_proxy = """ +auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options); +if (C10_UNLIKELY(maybe_proxy.has_value())) { + proxy_outputs_[output_idx] = std::move(maybe_proxy).value(); +} +""" + else: + create_proxy = "" + + if k is SchemaKind.functional: + assert self.backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ) + return f"""{maybe_set_guard_line} +outputs_[output_idx] = create_out(sizes, strides, options);""" + elif k is SchemaKind.inplace: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +check_inplace(out, sizes, options); +{create_proxy}""" + elif k is SchemaKind.out: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +resize_out(out, sizes, strides, options); +{create_proxy}""" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + # returns the definition of a ctor, as well as how to construct + 
# this class to a variable named op + def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str: + if k is SchemaKind.functional: + return "" + elif k is SchemaKind.inplace: + # TODO: Make sure out argument is guaranteed to be self + return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}" + elif k is SchemaKind.out: + out_args = ", ".join(f"Tensor& out{i}" for i in range(returns)) + out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns)) + return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + def gen_class( + self, + f: NativeFunction, + k: SchemaKind, + *, + class_name: str, + parent_class: str, + generate_super: bool, + ) -> str: + if k is SchemaKind.functional: + output_type = "Tensor" + output_value = "outputs_[output_idx]" + proxy_field = "" + elif k is SchemaKind.inplace: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array, {len(f.func.returns)}> proxy_outputs_;" + elif k is SchemaKind.out: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array, {len(f.func.returns)}> proxy_outputs_;" + + if self.backend_index.dispatch_key == DispatchKey.CUDA: + if self.rocm: + guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;" + else: + guard_field = "c10::cuda::OptionalCUDAGuard guard_;" + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + guard_field = "c10::OptionalDeviceGuard guard_;" + elif self.backend_index.dispatch_key == DispatchKey.MPS: + # TODO: Move to OptionalMPSGuard. 
+ guard_field = "c10::OptionalDeviceGuard guard_;" + else: + guard_field = "" + + indent = " " * 4 + class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns)) + lines = ( + f"struct {class_name} final : public {parent_class} {{", + f"{textwrap.indent(class_ctor_str, indent)}", + f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}", + " const Tensor& maybe_get_output(int64_t output_idx) override {", + f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit + " }", + f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", # type: ignore[possibly-undefined] # TODO: audit + f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit + f"{textwrap.indent(guard_field, indent)}", + "};", + ) + return "\n".join(line for line in lines if line) + + @method_with_native_function + def gen_one(self, f: NativeFunction) -> Optional[str]: + assert not f.manual_kernel_registration + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + # TODO: Now, there is something interesting going on here. In the code below, + # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace + # based on the out implementation. But in fact, out is definable by + # functional too (just not very efficiently), and this is honestly the + # MORE likely situation for a backend implementor. How do we pick? + # Well, taking a page from Haskell type classes and default methods, + # we could conceivably register a circular definition (out in terms + # of functional, and functional in terms of out) and just require + # someone to implement one or the other. We'd have to do a little bit + # of work to not register one of these "weak" definitions unless there + # is a strong definition somewhere in the DAG! So it's not implemented yet. 
+ if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + and f.func.kind() is SchemaKind.out + ): + # Never generate a default implementation for out, that's what you + # have to define as a backend implementor + return None + + # Note [Direct dispatch bindings] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Signature of the non-dispatched function we'll expose in a header + # (e.g., at::cpu::add). We don't generate methods (TODO: do this + # when CPUTensor class is a thing); nor do we generate fallback + # bindings for manual_cpp_binding functions. + cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # Signature of the wrapper function we'll register to the dispatcher + kern = self.backend_index.get_kernel(f) + sig = NativeSignature( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_", + symint=kern is not None and kern.supports_symint(), + ) + + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + k = f.func.kind() + + # Construct the body of the wrapper function with signature sig + sig_body = [] + # We'll use context to keep track of any variables we've brought + # into scope while generating code + context: List[Union[Binding, Expr]] = list(sig.arguments()) + + # Initialize the class corresponding to this structured + # operator; feeding it the output argument(s) if it is known + if 
self.backend_index.dispatch_key is DispatchKey.Meta: + class_name = f"structured_{meta.name(self.g)}_meta_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + elif ( + self.backend_index.dispatch_key + is DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: dedup this branch + class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + else: + metadata = self.backend_index.get_kernel(self.g) + assert metadata is not None + class_name = f"structured_{metadata.kernel}_{k.name}" + parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}" + + if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + sig_body.append( + RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), sig.name() + ) + ) + + if k is SchemaKind.functional: + sig_body.append(f"{class_name} op;") + elif k is SchemaKind.inplace: + sig_body.append(f"{class_name} op(self);") + elif k is SchemaKind.out: + out_args_str = ", ".join(a.name for a in f.func.arguments.out) + sig_body.append(f"{class_name} op({out_args_str});") + + # Translate the input native arguments into structured + # arguments for the meta call + meta_exprs = ", ".join( + e.expr + for e in translate( + context, structured.meta_arguments(self.g), method=False + ) + ) + + if self.g.out.precomputed: + # If this function group has precomputed elements, the meta function + # returns a struct containing them which must be saved so that it + # can be unpacked when generating code to call the impl. + sig_body.append(f"auto precompute = op.meta({meta_exprs});") + + # Put all of the contents of the precompute struct into the context + # so that translate will be able to return the correct args for the + # call to the impl. 
+ precomputed_values = [ + *self.g.out.precomputed.replace.values(), + self.g.out.precomputed.add, + ] + for precomputed_elems in precomputed_values: + for arg in precomputed_elems: + context.append( + Expr( + expr=f"precompute.{arg.name}", + type=structured.argument_type(arg, binds=arg.name), + ) + ) + + # Add a use of the precompute struct so FB internal compilers don't + # complain that there is an unused variable. + sig_body.append("(void)precompute;") + else: + sig_body.append(f"op.meta({meta_exprs});") + + # After running meta, op.outputs_ is guaranteed to be valid; + # add it to the context + out_args = structured.out_arguments(self.g) + for i, out_arg in enumerate(out_args): + assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type + + if k is SchemaKind.out: + expr = f"op.maybe_get_output({i})" + else: + expr = f"op.outputs_[{i}]" + + context.append( + Expr( + expr=expr, + # TODO: Stop hardcoding that the output type is a Tensor. Note + # that for the codegen here this is fine because outputs_ is + # hardcoded to be tensor already + type=NamedCType( + out_arg.nctype.name, MutRefCType(BaseCType(tensorT)) + ), + ) + ) + + # With the expanded context, do the impl call (if not a meta + # function) + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: https://github.com/pytorch/pytorch/issues/53023 + out_sig_group = CppSignatureGroup.from_native_function( + self.g.out, method=False, fallback_binding=f.manual_cpp_binding + ) + out_sig = out_sig_group.most_faithful_signature() + api_name = out_sig.name() + out_exprs = ", ".join( + e.expr + for e in translate(context, out_sig.arguments(), method=False) + ) + # TODO: I think this means structured won't work with method + # only functions (but maybe you're saved by faithful? iunno.) 
+ # NB: Originally I wrote this as an at::redispatch call, but + # I got in trouble because that meant I needed a DispatchKeySet + # in the wrapper function, which meant I needed a DispatchKeySet + # in the DispatchKeyFunctions declarations, but the defined API + # there does NOT permit a dispatch key set. I think you can + # probably unwind this by calling some function to do the TLS + # fetch and get the DispatchKeySet when you don't have it, but + # I didn't do it for this version + sig_body.append(f"at::{api_name}({out_exprs});") + elif self.backend_index.dispatch_key != DispatchKey.Meta: + impl_exprs = ", ".join( + e.expr + for e in translate( + context, structured.impl_arguments(self.g), method=False + ) + ) + sig_body.append(f"op.impl({impl_exprs});") + + # Go over each output, and check if there is a proxy created for it. + # If so, copy it over to the original output. + if k is SchemaKind.out or k is SchemaKind.inplace: + for i in range(len(f.func.returns)): + sig_body.append( + f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" + ) + + # Destructively return the final tensors + # TODO: Do this in translate instead + if k is SchemaKind.functional: + if len(f.func.returns) == 1: + ret_expr = "std::move(op.outputs_[0])" # small optimization + else: + moved = ", ".join( + f"std::move(op.outputs_[{i}])" + for i in range(len(f.func.returns)) + ) + ret_expr = f"std::make_tuple({moved})" + elif k is SchemaKind.inplace: + ret_expr = "self" + elif k is SchemaKind.out: + if len(f.func.returns) == 1: + ret_expr = f.func.arguments.out[0].name + else: + refs = ", ".join(a.name for a in f.func.arguments.out) + ret_expr = f"std::forward_as_tuple({refs})" + sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit + + sig_body_str = "\n".join(sig_body) + + # For an overview of what this template code looks like, see + # https://github.com/pytorch/rfcs/pull/9 + return f"""\ +{self.gen_class( +f, 
k, +class_name=class_name, +parent_class=parent_class, +generate_super=self.g.out.structured_inherits is not None +)} + +{sig.defn()} {{ +{sig_body_str} +}} +""" + + elif self.target is Target.REGISTRATION: + return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));' + else: + assert_never(self.target) + # Silence mypy's "Missing return statement" error + return None diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..da42149c596b67d7bb13ac673e8d3c6cf141339b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py @@ -0,0 +1,545 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torchgen.api.ufunc as ufunc +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + CType, + Expr, + NamedCType, + opmath_t, + scalar_t, + StructuredImplSignature, + VectorizedCType, +) +from torchgen.api.ufunc import UfunctorBindings +from torchgen.context import with_native_function +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + NativeFunctionsGroup, + ScalarType, + UfuncKey, +) +from torchgen.utils import OrderedSet + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CUDA STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# NB: not bothering to generate dispatch stub forward declaration in header, +# we can just paste it whereever necessary + +# TODO: use BackendIndex +# dispatch_key: DispatchKey # only CPU/CUDA right now + + +# Represents functors for implementing CUDA ufuncs. +# Functors are templated by scalar_t because when USERS instantiate functors +# they are templated. 
A functor looks something like this: +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) +# : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +@dataclass(frozen=True) +class UfunctorSignature: + g: NativeFunctionsGroup + scalar_tensor_idx: Optional[int] + name: str + + def arguments(self) -> UfunctorBindings: + return ufunc.ufunctor_arguments( + self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t + ) + + def fields(self) -> List[Binding]: + # fields are renamed to have a trailing underscore, as is conventional + return [b.rename(f"{b.name}_") for b in self.arguments().ctor] + + def returns_type(self) -> CType: + # TODO: don't hardcode; return type will be inferred based on tags on + # the native function + return BaseCType(scalar_t) + + def decl_fields(self) -> str: + return "\n".join(f"{f.type} {f.name};" for f in self.fields()) + + def inline_defn_ctor(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().ctor) + # NB: hypothetically could do this with translate but the + # transition here is very regular + init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor) + return f"{self.name}({args_str}) : {init_str} {{}}" + + def decl_apply(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().apply) + return f"{self.returns_type().cpp_type()} operator()({args_str}) const" + + +@dataclass(frozen=True) +class UfuncSignature: + g: NativeFunctionsGroup + name: str + compute_t: CType + + def arguments(self) -> List[Binding]: + return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) + + def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str: + return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +# steps: +# 1. 
take the functional signature +# 2. use api.ufunc to convert it to template signature. this establishes +# the type of the template function +# 3. use api.ufunc (II) to generate a split struct / operator() signature. +# this establish context in which we call the template signature +# +# StructuredImplSignature context +# ~> functor constructor sig +# +# Functor constructor context +# ~> functor fields sig +# +# Functor apply context (functor fields + functor apply sig) +# ~> template sig +# + + +def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: + num_tensors = sum( + 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like() + ) + return num_tensors == 2 + + +def compute_ufunc_cuda_functors( + g: NativeFunctionsGroup, +) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]: + # First, build the functors. + ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: List[str] = [] + loops = g.out.ufunc_inner_loop + scalar_tensor_idx_lookup = { + UfuncKey.CUDAFunctorOnSelf: 1, + UfuncKey.CUDAFunctorOnOther: 0, + UfuncKey.CUDAFunctor: None, + } + if eligible_for_binary_scalar_specialization(g): + keys = [ + UfuncKey.CUDAFunctorOnSelf, + UfuncKey.CUDAFunctorOnOther, + UfuncKey.CUDAFunctor, + ] + else: + keys = [UfuncKey.CUDAFunctor] + for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]: + assert k not in loops, f"cannot use {k} on non-binary function" + for k in keys: + # If the key was directly defined, skip functor codegen; we assume the + # user already done it for us + if k in loops: + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name + ) + for dtype in loops[k].supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + continue + + # Note [ScalarOnly and Generic must match names for CUDA] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Otherwise, look in ANY of the generic 
entries. For simplicity of + # codegen, both ScalarOnly and Generic are defined, the ufunc name + # must match (if they didn't match, we'd have to generate distinct + # functors per dtype, which is awful, so we're not going to do it unless + # someone really forces us to) + ufunc_name = None + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() + for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: + if lk not in loops: + continue + if ufunc_name is None: + ufunc_name = loops[lk].name + else: + # See Note [ScalarOnly and Generic must match names for CUDA] + assert ( + ufunc_name == loops[lk].name + ), "ScalarOnly and Generic must have same ufunc name" + supported_dtypes |= loops[lk].supported_dtypes + assert ufunc_name is not None + + name = f"{k}_{ufunc_name}" + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name + ) + for dtype in supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + + ufunc_sig = UfuncSignature( + g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t) + ) + apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply + ufunctors.append( + f""" +template +struct {ufunctor_sig.name} {{ + using opmath_t = at::opmath_type; + {ufunctor_sig.decl_fields()} + {ufunctor_sig.inline_defn_ctor()} + __device__ {ufunctor_sig.decl_apply()} {{ + return {ufunc_sig.call(apply_ctx)}; + }} +}}; +""" + ) + + return ufunctor_sigs, "\n".join(ufunctors) + + +@dataclass(frozen=True) +class BinaryScalarSpecializationConfig: + scalar_idx: int + ctor_tensor: str + ufunc_key: UfuncKey + + +BinaryScalarSpecializationConfigs = [ + BinaryScalarSpecializationConfig( + scalar_idx=0, + ctor_tensor="self", + ufunc_key=UfuncKey.CUDAFunctorOnOther, + ), + BinaryScalarSpecializationConfig( + scalar_idx=1, + ctor_tensor="other", + ufunc_key=UfuncKey.CUDAFunctorOnSelf, + ), +] + + +def compute_ufunc_cuda_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: Dict[UfuncKey, 
UfunctorSignature], + parent_ctx: Sequence[Binding], +) -> str: + body = "using opmath_t = at::opmath_type;" + body += "if (false) {}\n" # for ease of codegen + for config in BinaryScalarSpecializationConfigs: + if config.ufunc_key not in inner_loops: + continue + ufunctor_sig = inner_loops[config.ufunc_key] + scalar_idx = config.scalar_idx + 1 + # Make a copy and at the same time widen the type (not permissible + # without copy; we don't want to mutate the input argument anyway) + ctx: List[Union[Expr, Binding]] = list(parent_ctx) + ctx.append( + Expr( + expr=f"iter.scalar_value({scalar_idx})", + type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)), + ) + ) + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor) + ) + + # NB: ufunctor must be allocated before iter.remove_operand is called, + # as it relies on iter + body += f"""\ +else if (iter.is_cpu_scalar({scalar_idx})) {{ + {ufunctor_sig.name} ufunctor({ufunctor_ctor_exprs_str}); + iter.remove_operand({scalar_idx}); + gpu_kernel(iter, ufunctor); +}}""" + + ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor] + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor) + ) + body += f""" +else {{ + gpu_kernel(iter, {ufunctor_sig.name}({ufunctor_ctor_exprs_str})); +}} + """ + return body + + +@with_native_function +def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: + # First, build the functors, indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunctor_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + + stub_sig = StubSignature(g) + + 
return f""" +{ufunctors} + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}", + {dtype_cases_str} + ); +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); + +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CPU STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StubSignature: + g: NativeFunctionsGroup + + @property + def name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_stub" + + @property + def kernel_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_kernel" + + @property + def type_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_fn" + + def arguments(self) -> List[Binding]: + return ufunc.stub_arguments(self.g) + + def type(self) -> str: + cpp_args = self.arguments() + return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" + + def dispatch_decl(self) -> str: + return f"DECLARE_DISPATCH({self.type_name}, {self.name})" + + def dispatch_defn(self) -> str: + return f"DEFINE_DISPATCH({self.name})" + + def kernel_defn(self) -> str: + return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" + + def type_defn(self) -> str: + return f"using {self.type_name} = {self.type()}" + + # must be called from context where this is TensorIteratorBase* + def call(self, ctx: Sequence[Binding]) -> str: + return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + # used in CUDA to skip the unnecessary dynamic dispatch + def direct_call(self, ctx: Sequence[Binding]) -> str: + return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +@with_native_function +def 
compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) + + return f""" +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +{stub_sig.dispatch_defn()}; + +{sig.defn()} {{ + {stub_sig.call(sig.arguments())}; +}} +""" + + +def compute_ufunc_cpu_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: Dict[UfuncKey, UfuncSignature], + parent_ctx: Sequence[Binding], +) -> str: + assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" + assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} + scalar_loop = inner_loops[UfuncKey.CPUScalar] + vec_loop = None + if UfuncKey.CPUVector in inner_loops: + vec_loop = inner_loops[UfuncKey.CPUVector] + + # NB: We DON'T use translate here, because translate is + # incapable of CSE'ing the scalar accesses in case it is also + # used by Vectorized; also, the unpacking here is very simple + # and only affects Scalar; everything else is implicitly captured + # by the lambda + + # Setup scalar in scope + body = [] + ctx = [] + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append(f"auto _s_{b.name} = {b.name}.to();") + ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) + if vec_loop is not None: + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append( + f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});" + ) + ctx.append( + Expr( + f"_v_{b.name}", + NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))), + ) + ) + + # Setup lambda signature + # NB: simplified version of ufunctor_arguments + scalar_bindings = [] + vec_bindings = [] + for a in g.functional.func.arguments.flat_non_out: + if not a.type.is_tensor_like(): + continue + assert a.type == BaseType(BaseTy.Tensor) + 
scalar_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, BaseCType(scalar_t)), + argument=a, + ) + ) + if vec_loop is not None: + vec_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), + argument=a, + ) + ) + + def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: + r: List[Union[Expr, Binding]] = [] + r.extend(ctx) + r.extend(b) + return r + + body_str = "\n".join(body) + if vec_loop is not None: + return f""" +{body_str} +cpu_kernel_vec(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, + [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} +); +""" + else: + return f""" +{body_str} +cpu_kernel(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }} +); +""" + + +@with_native_function +def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + + # Reindex the ufunc by dtypes; processing generic/scalaronly as well + loops = g.out.ufunc_inner_loop + ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {} + for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: + lks = [] + # ORDER MATTERS: this specifies overriding precedence + if k in loops: # should happen rarely + lks.append(k) + if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar: + lks.append(UfuncKey.ScalarOnly) + if UfuncKey.Generic in loops: + lks.append(UfuncKey.Generic) + # TODO: don't hardcode ufunc:: namespace here, should be centralized smh + for lk in lks: + for dtype in loops[lk].supported_dtypes: + compute_t: CType + if k is UfuncKey.CPUScalar: + compute_t = BaseCType(scalar_t) + elif k is UfuncKey.CPUVector: + compute_t = VectorizedCType(BaseCType(scalar_t)) + else: + raise AssertionError() + inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {}) + if k not in inner_ufunc_sigs: + 
inner_ufunc_sigs[k] = UfuncSignature( + g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t + ) + + # Build the conditionals + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunc_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + return f""" +namespace {{ + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}", + {dtype_cases_str} + ); +}} + +}} // anonymous namespace + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +""" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94aa74ba9e559b6651635f71074e83f424aa513 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae718a5ee43012a7091c42917ffbf945287e7e75 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5e6a7b97263236bfa5da6b8236cab5f1d686616 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a97348b73a7ea99ecc522a3fcafcc0e73cd0e3bf Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b4f0ead23750b709656cb7a73430503db691b3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d11f1300bb8b7ccb7d6b4bbd372a70f2e6fb219 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py @@ -0,0 +1,142 @@ +from collections import defaultdict + +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +from torchgen import dest + +# disable import sorting to avoid circular dependency. +from torchgen.api.types import DispatcherSignature # isort:skip +from torchgen.context import method_with_native_function +from torchgen.executorch.model import ETKernelIndex +from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import concatMap, Target + + +# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at +# model authoring side. +@dataclass(frozen=True) +class ComputeNativeFunctionStub: + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + if Variant.function not in f.variants: + return None + + sig = DispatcherSignature.from_schema( + f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False + ) + assert sig is not None + if len(f.func.returns) == 0: + ret_name = "" + elif len(f.func.returns) == 1: + if f.func.arguments.out: + ret_name = f.func.arguments.out[0].name + else: + ret_name = next( + ( + a.name + for a in f.func.arguments.flat_non_out + if a.type == f.func.returns[0].type + ), + "", + ) + if not ret_name: + # if return type is tensor + if f.func.returns[0].type == BaseType(BaseTy.Tensor): + # Returns an empty tensor + ret_name = "at::Tensor()" + else: + raise Exception(f"Can't handle this return type {f.func}") + elif len(f.func.arguments.out) == len(f.func.returns): + # Returns a tuple of out arguments + tensor_type = "at::Tensor &" + comma = ", " + ret_name = 
f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( + {comma.join([r.name for r in f.func.arguments.out])} + )""" + else: + assert all( + a.type == BaseType(BaseTy.Tensor) for a in f.func.returns + ), f"Only support tensor returns but got {f.func.returns}" + # Returns a tuple of empty tensors + tensor_type = "at::Tensor" + comma = ", " + ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( + {comma.join(["at::Tensor()" for _ in f.func.returns])} + )""" + ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else "" + return f""" +{sig.defn()} {{ + {ret_str} +}} + """ + + +def gen_custom_ops_registration( + *, + native_functions: Sequence[NativeFunction], + selector: SelectiveBuilder, + kernel_index: ETKernelIndex, + rocm: bool, +) -> Tuple[str, str]: + """ + Generate custom ops registration code for dest.RegisterDispatchKey. + + :param native_functions: a sequence of `NativeFunction` + :param selector: for selective build. + :param kernel_index: kernels for all the ops. + :param rocm: bool for dest.RegisterDispatchKey. + :return: generated C++ code to register custom operators into PyTorch + """ + + # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. + # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. 
+ + dispatch_key = DispatchKey.CPU + backend_index = kernel_index._to_backend_index() + static_init_dispatch_registrations = "" + ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) + for native_function in native_functions: + ns_grouped_native_functions[native_function.namespace].append(native_function) + + for namespace, functions in ns_grouped_native_functions.items(): + if len(functions) == 0: + continue + dispatch_registrations_body = "\n".join( + list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.REGISTRATION, + selector, + rocm=rocm, + symint=False, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + functions, + ) + ) + ) + static_init_dispatch_registrations += f""" +TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ +{dispatch_registrations_body} +}};""" + anonymous_definition = "\n".join( + list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=rocm, + symint=False, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + native_functions, + ) + ) + ) + return anonymous_definition, static_init_dispatch_registrations diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..24dda58ecdbc4884b8502d0d44dba29098e080af --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py @@ -0,0 +1,368 @@ +from typing import List, Optional, Sequence, Set, Union + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + BaseCType, + Binding, + ConstRefCType, + CType, + MutRefCType, + NamedCType, + SpecialArgName, + TupleCType, + VectorCType, + voidT, +) +from torchgen.model import ( + Argument, + Arguments, + BaseTy, + 
BaseType, + ListType, + NativeFunction, + OptionalType, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never +from .types import ( + ArrayRefCType, + BaseTypeToCppMapping, + OptionalCType, + scalarT, + tensorListT, + tensorT, +) + +""" +This file describes the translation of JIT schema to the public C++ API, which is what people use when they call +functions like at::add. It also serves as a native function API, which is the signature of kernels, +since in Executorch CppSignature is the same as NativeSignature. + +Difference between this file and torchgen.api.cpp.py: + + - Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with + torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch). + + - Executorch doesn't support Dimname. + + - Executorch runtime doesn't support SymInt, will treat it as int. +""" + + +# Translation of "value types" in JIT schema to C++ API type. Value +# types look the same no matter if they are argument types or return +# types. Returns None if the type in question is not a value type. +def valuetype_type( + t: Type, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, +) -> Optional[NamedCType]: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: + return None + # For SymInt we simply treat it as int. + elif str(t) == "SymInt": + return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int])) + if remove_non_owning_ref_types: + if t.name == BaseTy.str: + raise AssertionError( + "string ref->value conversion: not implemented yet" + ) + # All other BaseType currently map directly to BaseCppTypes. 
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) + elif isinstance(t, OptionalType): + elem = valuetype_type(t.elem, binds=binds) + if elem is None: + return None + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if str(t.elem) == "bool": + assert t.size is not None + return NamedCType( + binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size) + ) + else: + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type. +# For example, we'll return std::vector instead of IntArrayRef. +# See Note [translation from C++ reference to value types] +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, +) -> NamedCType: + # If it's a value type, do the value type translation + r = valuetype_type( + t, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + if r is not None: + return r + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(BaseCType(tensorT))) + else: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType( + binds, MutRefCType(BaseCType(tensorT)) + ) # TODO: fix this discrepancy + else: + return NamedCType( + binds, ConstRefCType(OptionalCType(BaseCType(tensorT))) + ) + elif str(t.elem) == "Scalar": + return NamedCType(binds, 
ConstRefCType(OptionalCType(BaseCType(scalarT)))) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels. + if str(t.elem) == "Tensor": + return NamedCType(binds, BaseCType(tensorListT)) + elif str(t.elem) == "Dimname": + raise NotImplementedError("Executorch doesn't support Dimname") + elif str(t.elem) == "Tensor?": + return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT)))) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translate a JIT argument into its C++ type +def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + + +# Translation of a (non-multi) return type from JIT to C++ +# N.B: returntype_type returns a CType, not a NamedCType. +# This is mostly because of the mismatch between return types and return names. +# e.g. a function with a return type of 'void' has 0 return names, +# and a function with a return type of 'std::tuple' has >1 return name. +def returntype_type(t: Type, *, mutable: bool) -> CType: + # placeholder is ignored + r = valuetype_type(t, binds="__placeholder__") + if r is not None: + return r.type + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable: + if local.use_const_ref_for_mutable_tensors(): + return ConstRefCType(BaseCType(tensorT)) + else: + return MutRefCType(BaseCType(tensorT)) + else: + # Note [Tensor Copy Returns] + # Currently, we use "Argument.is_write" to determine + # whether or not Tensor return types should be copies or references. + # If that ever changes, take a look at other locations of this note! 
+ return BaseCType(tensorT) + elif t.name == BaseTy.Scalar: + return BaseCType(scalarT) + elif isinstance(t, ListType): + assert ( + not mutable + ), "Native functions should never return a mutable tensor list. They should return void." + elem = returntype_type(t.elem, mutable=False) + assert t.size is None, f"fixed size list returns not supported: {t}" + return VectorCType(elem) + + raise AssertionError(f"unrecognized return type {t}") + + +# Translation of a single return to its C++ type +def return_type(r: Return) -> CType: + return returntype_type(r.type, mutable=r.is_write) + + +# Translation of a full (possibly multi) return from JIT to its C++ type +def returns_type(rs: Sequence[Return]) -> CType: + if len(rs) == 0: + return BaseCType(voidT) + elif len(rs) == 1: + return return_type(rs[0]) + else: + return TupleCType([return_type(r) for r in rs]) + + +def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: + returns: List[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = "self" + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... 
+ elif r.name: + name_conflict = any( + r.name == a.name for a in f.func.schema_order_arguments() + ) + if name_conflict and not f.func.is_out_fn(): + name = f"{r.name}_return" + else: + name = r.name + # If there is no explicit name and no fallback name was passed in, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" + returns.append(name) + return returns + + +JIT_TO_CPP_DEFAULT = { + "False": "false", + "True": "true", + "None": "torch::executorch::nullopt", # UGH this one is type directed + "[]": "{}", + "contiguous_format": "torch::executorch::MemoryFormat::Contiguous", + "long": "torch::executorch::kLong", +} + + +# Convert a JIT default into C++ expression representing the default +def default_expr(d: str, t: Type) -> str: + if d == "None" and str(t) == "Tensor?": + return "{}" + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = "" + i = 1 + while i + 1 < len(d): + if d[i] != "\\": + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i : i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == "None": + return "torch::executor::nullopt" + + return default_expr(d, t.elem) + + if isinstance(t, ListType): + if d.startswith("[") and d.endswith("]"): + return "{" + d[1:-1] + "}" + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + + return JIT_TO_CPP_DEFAULT.get(d, d) + + +# Convert an argument into its C++ API form + + +def argument( + a: Union[Argument, TensorOptionsArguments, SelfArgument], + *, + cpp_no_default_args: Set[str], + method: bool, + faithful: bool, + has_tensor_options: bool, +) -> List[Binding]: + def 
sub_argument( + a: Union[Argument, TensorOptionsArguments, SelfArgument] + ) -> List[Binding]: + return argument( + a, + cpp_no_default_args=cpp_no_default_args, + method=method, + faithful=faithful, + has_tensor_options=has_tensor_options, + ) + + if isinstance(a, Argument): + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: Optional[str] = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type) + return [ + Binding( + nctype=argument_type(a, binds=binds), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, TensorOptionsArguments): + raise NotImplementedError("Need to implement type resolution for TensorOptions") + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! + return [] + else: + return sub_argument(a.argument) + else: + assert_never(a) + + +def arguments( + arguments: Arguments, + *, + faithful: bool, + method: bool, + cpp_no_default_args: Set[str], +) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) + else: + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r + for a in args + for r in argument( + a, + faithful=faithful, + method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args, + ) + ] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5e802634f82e1557f9245bf857d9e54b748d31 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py @@ -0,0 +1,2 @@ +from .types import * +from .signatures import * # isort:skip diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d53e17e0c057f353ffa92c2d3807b4ef2c745c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/parse.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..89b4b93558a6a22b21beafba722bff76372be9c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/parse.py @@ -0,0 +1,151 @@ +from collections import defaultdict, namedtuple +from typing import Any, Dict, List, Optional, Set, Tuple + +import yaml + +from torchgen.executorch.model import ETKernelIndex, ETKernelKey + +from torchgen.gen import LineLoader, parse_native_yaml +from torchgen.model import ( + BackendMetadata, + DispatchKey, + FunctionSchema, + NativeFunction, + OperatorName, +) +from torchgen.utils import NamespaceHelper + +# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices. 
+ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"]) + +# Fields in native_functions.yaml used to determine which kernels should be used +ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"] + + +def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]: + """Given a loaded yaml representing kernel assignment information, extract the + mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance) + + Args: + ei: Dict keys {kernels, type_alias, dim_order_alias} + See ETKernelKey for description of arguments + """ + e = ei.copy() + if (kernels := e.pop("kernels", None)) is None: + return {} + + type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment] + dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] + dim_order_alias.pop("__line__", None) + + kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {} + + for entry in kernels: # type: ignore[attr-defined] + arg_meta = entry.get("arg_meta") + if arg_meta is not None: + arg_meta.pop("__line__") + + kernel_name = entry.get("kernel_name") + namespace_helper = NamespaceHelper.from_namespaced_entity( + kernel_name, max_level=3 + ) + kernel_namespace = namespace_helper.get_cpp_namespace(default="at") + backend_metadata = BackendMetadata( + kernel=namespace_helper.entity_name, + structured=False, + cpp_namespace=(kernel_namespace + "::native"), + ) + + kernel_keys = ( + [ETKernelKey((), default=True)] + if arg_meta is None + else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type] + ) + + for kernel_key in kernel_keys: + assert kernel_key not in kernel_mapping, ( + "Duplicate kernel key: " + str(kernel_key) + " " + str(e) + ) + kernel_mapping[kernel_key] = backend_metadata + + return kernel_mapping + + +def parse_et_yaml_struct(es: object) -> ETKernelIndex: + """Given a loaded yaml representing a list of operators, for each 
op extract the mapping + of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance + that should be used by the kernel key). + """ + indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {} + for ei in es: # type: ignore[attr-defined] + e = ei.copy() + + funcs = e.pop("func") + assert isinstance(funcs, str), f"not a str: {funcs}" + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + opname = FunctionSchema.parse(namespace_helper.entity_name).name + + assert opname not in indices, f"Duplicate func found in yaml: {opname} already" + + if len(index := parse_from_yaml(e)) != 0: + indices[opname] = index + + return ETKernelIndex(indices) + + +def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]: + """Given a loaded yaml representing a list of operators, extract the + kernel key related fields indexed by the operator name. + """ + fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict) + for ei in es: # type: ignore[attr-defined] + funcs = ei.get("func") + assert isinstance(funcs, str), f"not a str: {funcs}" + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + opname = FunctionSchema.parse(namespace_helper.entity_name).name + + for field in ET_FIELDS: + if (value := ei.get(field)) is not None: + fields[opname][field] = value + + return fields + + +def parse_et_yaml( + path: str, + tags_yaml_path: str, + ignore_keys: Optional[Set[DispatchKey]] = None, + skip_native_fns_gen: bool = False, +) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]: + """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict + of fields to persist from native_functions.yaml to functions.yaml + """ + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + + et_kernel = extract_kernel_fields(es) + + # Remove ET specific fields from entries for BC compatibility + strip_et_fields(es) + 
+ native_yaml = parse_native_yaml( + path, + tags_yaml_path, + ignore_keys, + skip_native_fns_gen=skip_native_fns_gen, + loaded_yaml=es, + ) + return native_yaml.native_functions, et_kernel + + +def strip_et_fields(es: object) -> None: + """Given a loaded yaml representing a list of operators, + remove ET specific fields from every entries for BC compatibility + """ + for entry in es: # type: ignore[attr-defined] + for field in ET_FIELDS: + entry.pop(field, None) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/gen.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/gen.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb980fb29dd09cfb6429a8835582f21fe9c082c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/gen.py @@ -0,0 +1,2937 @@ +import argparse +import functools +import json +import os +import pathlib +from collections import defaultdict, namedtuple, OrderedDict +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +import yaml + +import torchgen.api.dispatcher as dispatcher +import torchgen.api.meta as meta +import torchgen.api.native as native +import torchgen.api.structured as structured +import torchgen.dest as dest + +from torchgen.api import cpp +from torchgen.api.translate import translate +from torchgen.api.types import ( + Binding, + CppSignature, + CppSignatureGroup, + DispatcherSignature, + NamedCType, + NativeSignature, + SpecialArgName, +) +from torchgen.context import ( + method_with_native_function, + native_function_manager, + with_native_function, + with_native_function_and_indices, +) +from torchgen.gen_aoti_c_shim import ( + gen_aoti_c_shim, + gen_static_dispatch_backend_call_signature, + get_backend_index_for_aoti, +) +from torchgen.gen_functionalization_type import ( + 
gen_functionalization_definition, + gen_functionalization_registration, + gen_functionalization_view_inverse_declaration, + GenCompositeViewCopyKernel, +) +from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing + +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseOperatorName, + DEFAULT_KERNEL_NAMESPACE, + DispatchKey, + FRAGMENT_NAMESPACES, + FunctionSchema, + is_cuda_dispatch_key, + is_generic_dispatch_key, + is_ufunc_dispatch_key, + Location, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OperatorName, + OptionalType, + SchemaKind, + SelfArgument, + STRUCTURED_DISPATCH_KEYS, + TensorOptionsArguments, + Type, + Variant, + ViewSchemaKind, +) +from torchgen.native_function_generation import ( + add_generated_native_functions, + gen_composite_functional_kernel, + gen_composite_out_kernel, + pre_group_native_functions, +) +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import ( + assert_never, + concatMap, + context, + FileManager, + make_file_manager, + mapMaybe, + NamespaceHelper, + Target, +) +from torchgen.yaml_utils import YamlDumper, YamlLoader + +T = TypeVar("T") + +# Welcome to the ATen code generator v2! The ATen code generator is +# responsible for parsing native_functions.yaml and then generating +# various generated files (e.g., TypeDefault.cpp) based on the operators +# defined in this file. This means that the code generator knows how to +# parse function schema, and then translate this into various C++ types +# and boilerplate code. +# +# Some things to know about this file when you modify it: +# +# - This file has STRICT mypy typechecking. Typecheck it with +# `mypy --config mypy-strict.ini` in the root source directory +# +# - Most of the heavy lifting lives in external modules: +# - 'model' has the data model for native_functions.yaml. 
The classes +# in those file represent what you see when you look at +# a native_functions.yaml +# - 'api' has conversions for how to translate JIT schema into +# the various C++ APIs that the codegen interacts with. There +# are in fact THREE different C++ APIs: the public C++ API, +# the dispatcher API, and the legacy dispatcher API. See each +# of these respective files for more information + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# HELPER FUNCTIONS +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# A custom loader for YAML to let us also keep track of line numbers +# of each entry in the YAML file +class LineLoader(YamlLoader): + def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] + mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call] + # Add 1 so line numbering starts at 1 + mapping["__line__"] = node.start_mark.line + 1 + return mapping + + +# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. 
+ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"]) + + +_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {} +_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {} + + +def parse_native_yaml_struct( + es: object, + valid_tags: Set[str], + ignore_keys: Optional[Set[DispatchKey]] = None, + path: str = "", + skip_native_fns_gen: bool = False, +) -> ParsedYaml: + assert isinstance(es, list) + rs: List[NativeFunction] = [] + bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) + for e in es: + assert isinstance(e.get("__line__"), int), e + loc = Location(path, e["__line__"]) + funcs = e.get("func") + with context(lambda: f"in {loc}:\n {funcs}"): + func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys) + rs.append(func) + BackendIndex.grow_index(bs, m) + error_check_native_functions(rs) + # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. + indices: Dict[DispatchKey, BackendIndex] = defaultdict( + lambda: BackendIndex( + dispatch_key=DispatchKey.Undefined, + use_out_as_primary=True, + external=False, + device_guard=False, + # I'm actually not sure about this; undefined could be hit on + # empty TensorList, hypothetically that could have sizes in it + index={}, + ) + ) + if not skip_native_fns_gen: + add_generated_native_functions(rs, bs) + for k, v in bs.items(): + # All structured in-tree operators are implemented in terms of their out operator. 
+ indices[k] = BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + # Only cuda-like devices in tree require device guards + device_guard=is_cuda_dispatch_key(k), + index=v, + ) + return ParsedYaml(rs, indices) + + +def parse_tags_yaml_struct(es: object, path: str = "") -> Set[str]: + assert isinstance(es, list) + rs: Set[str] = set() + for e in es: + assert isinstance(e.get("__line__"), int), e + loc = Location(path, e["__line__"]) + tags = e.get("tag") + with context(lambda: f"in {loc}:\n {tags}"): + e_i = e.copy() + name = e_i.pop("tag") + desc = e_i.pop("desc", "") + # ensure that each tag has a non-empty description + assert desc != "" + rs.add(name) + return rs + + +@functools.lru_cache(maxsize=None) +def parse_tags_yaml(path: str) -> Set[str]: + global _GLOBAL_PARSE_TAGS_YAML_CACHE + if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path) + + return _GLOBAL_PARSE_TAGS_YAML_CACHE[path] + + +def parse_native_yaml( + path: str, + tags_yaml_path: str, + ignore_keys: Optional[Set[DispatchKey]] = None, + *, + skip_native_fns_gen: bool = False, + loaded_yaml: Optional[object] = None, +) -> ParsedYaml: + global _GLOBAL_PARSE_NATIVE_YAML_CACHE + if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: + valid_tags = parse_tags_yaml(tags_yaml_path) + + # if a loaded yaml is provided, use that instead of reading from path + if loaded_yaml is None: + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + else: + es = loaded_yaml + + _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct( + es, + valid_tags, + ignore_keys, + path=path, + skip_native_fns_gen=skip_native_fns_gen, + ) + + return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] + + +# Some assertions are already performed during parsing, but those are only within a single NativeFunction. +# Assertions here are meant to be performed across NativeFunctions. 
+def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: + func_map: Dict[OperatorName, NativeFunction] = {} + base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) + for f in funcs: + func_map[f.func.name] = f + base_func_map[f.func.name.name].append(f) + for f in funcs: + if f.structured_delegate is not None: + delegate_func = func_map[f.structured_delegate] + assert delegate_func.structured, ( + f"{f.func.name} is marked as a structured_delegate pointing to " + f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. " + f"Consider adding 'structured=True' to the delegated operator" + ) + # See Note [resize_ in Functionalization] + # resize_() is technically an inplace view op (and therefore needs the tag), + # but it would be overkill to add a true "view" variant of resize. + # Instead, resize_() gets special treatment in functionalization, + # and we have a resize() op that is non-aliasing + functional. + if ( + "inplace_view" in f.tags + and str(f.func.name) != "resize_" + and str(f.func.name) != "resize_as_" + and str(f.func.name.name) != "set_" + ): + base_name = f.func.name.name + overload_name = f.func.name.overload_name + assert base_name.inplace, ( + f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming " + "convention for inplace ops - the codegen expects the base name to have a trailing underscore. " + ) + out_of_place_base_name = BaseOperatorName( + base_name.base, False, base_name.dunder_method + ) + assert len(base_func_map[out_of_place_base_name]) > 0, ( + f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding " + f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. 
" + ) + + +def cpp_string(s: str) -> str: + """Convert a python string into a c++ string literal""" + s = s.replace("\\", "\\\\") + s = s.replace('"', '\\"') + s = s.replace("\a", "\\a") + s = s.replace("\b", "\\b") + s = s.replace("\f", "\\f") + s = s.replace("\n", "\\n") + s = s.replace("\v", "\\v") + s = s.replace("\t", "\\t") + return f'"{s}"' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ CODE GENERATION +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# Most functions in this section are curried: they consist of a function +# that takes some parameters (e.g., what is to be generated) which itself +# returns a function that actually maps NativeFunction to the code +# to be generated. This pattern makes it convenient to use map, concatMap +# and similar functional combinators. + + +def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]: + if len(backends) == 0: + return [] + else: + return [backend.dispatch_key for backend in backends] + [ + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ] + + +def get_static_dispatch_backend( + f: NativeFunction, backend_index: BackendIndex +) -> Optional[DispatchKey]: + if f.structured_delegate is not None or backend_index.has_kernel(f): + # TODO: for ops with structured_delegate it should check the dispatch table of + # the out variant instead. For now, these structured ops all have CPU/CUDA kernels + # so we always dispatch to the `backend`, but this could be wrong when we + # migrate math/default_backend ops to use structured delegate. 
+ return backend_index.dispatch_key + elif f.has_composite_explicit_autograd_kernel: + return DispatchKey.CompositeExplicitAutograd + elif f.has_composite_explicit_autograd_non_functional_kernel: + return DispatchKey.CompositeExplicitAutogradNonFunctional + elif f.has_composite_implicit_autograd_kernel: + return DispatchKey.CompositeImplicitAutograd + elif f.has_composite_implicit_autograd_nested_tensor_kernel: + return DispatchKey.CompositeImplicitAutogradNestedTensor + return None + + +def static_dispatch_ops_header( + f: NativeFunction, backend_index: List[BackendIndex] +) -> Optional[str]: + if backend_index is None or f.manual_kernel_registration: + return None + + output = [] + for index in backend_index: + dispatch_key = get_static_dispatch_backend(f, index) + if dispatch_key is not None: + output.append( + f"#include " + ) + return "\n".join(output) + + +def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]: + return [ + f"#include " + for dispatch_key in static_dispatch_keys(backends) + ] + + +# Translates arguments of `sig` to CppSignature bindings. +# Note that we have a special case for `memory_format` argument and this case is not covered by +# tools.codegen.api.translate() yet as its application is limited to static dispatch. 
+def translate_args( + sig: Union[CppSignature, DispatcherSignature], + cpp_sig: CppSignature, +) -> str: + # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings + def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]: + output_bindings: List[Binding] = [] + for binding in input_bindings: + if binding.name == "memory_format": + spl_mem_format_binding = Binding( + nctype=NamedCType( + SpecialArgName.possibly_redundant_memory_format, + binding.nctype.type, + ), + name=binding.name, + default=binding.default, + argument=binding.argument, + ) + output_bindings.append(spl_mem_format_binding) + else: + output_bindings.append(binding) + return output_bindings + + src_bindings = list(sig.arguments()) + goal_bindings = list(cpp_sig.arguments()) + # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType, + # get memory_format bindings of dispatcher signature to have the same NCType as well + for arg in goal_bindings: + if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format: + src_bindings = add_spl_memory_format_binding(src_bindings) + break + exprs = translate(src_bindings, goal_bindings) + return ", ".join(a.expr for a in exprs) + + +def generate_static_dispatch_backend_call( + sig: Union[CppSignature, DispatcherSignature], + f: NativeFunction, + backend_index: BackendIndex, +) -> str: + cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) + name = cpp_sig.name() + exprs = translate_args(sig, cpp_sig) + backend_metadata = backend_index.get_kernel(f) + kernel_ns = ( + backend_metadata.cpp_namespace + if backend_metadata and backend_metadata.cpp_namespace + else DEFAULT_KERNEL_NAMESPACE + ) + ns = kernel_ns.replace("::native", "") + return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});" + + +def generate_static_dispatch_fallback_call( + sig: Union[CppSignature, DispatcherSignature], + f: NativeFunction, + backend_indices: 
List[BackendIndex], +) -> str: + cpp_sigs = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + if sig.symint and f.func.has_symint(): + cpp_sig = cpp_sigs.symint_signature + else: + cpp_sig = cpp_sigs.signature + assert cpp_sig is not None + name = cpp_sig.name() + exprs = translate_args(sig, cpp_sig) + ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") + if f.has_composite_explicit_autograd_kernel: + return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" + elif f.has_composite_explicit_autograd_non_functional_kernel: + return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});" + elif f.has_composite_implicit_autograd_kernel: + return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});" + elif f.has_composite_implicit_autograd_nested_tensor_kernel: + return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});" + else: + return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\ +{', '.join([str(index.dispatch_key)for index in backend_indices])} ");""" + + +def static_dispatch( + sig: Union[CppSignature, DispatcherSignature], + f: NativeFunction, + backend_indices: List[BackendIndex], +) -> str: + """ + For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one + backends exsit, fallback to static dispatch by determining dispatch key from inputs. + Arguments: + sig: A CppSignature or DispatcherSignature for this native function we want to use. + f: NativeFunction to generate static dispatch. + backend_indices: All available backends. 
+ Return: + C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);" + """ + if len(backend_indices) == 0 or f.manual_kernel_registration: + return "" + + keys = [ + b + for b in backend_indices + if b.has_kernel(f) + or ( + f.structured_delegate is not None + and b.dispatch_key in STRUCTURED_DISPATCH_KEYS + ) + ] + if len(keys) == 1: + return generate_static_dispatch_backend_call(sig, f, keys[0]) + elif len(keys) == 0: + return generate_static_dispatch_fallback_call(sig, f, backend_indices) + + native_tensor_args = [ + a.name + for a in sig.arguments() + if isinstance(a.argument, SelfArgument) + or isinstance(a.argument, Argument) + and a.argument.type.is_tensor_like() + ] + tensor_args = ", ".join(native_tensor_args) + tensor_opts = f.func.arguments.tensor_options + + stmts = [] + subexprs: List[str] = [] + if tensor_opts is not None: + subexprs.append( + "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))" + ) + if tensor_args != "": + subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})") + stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""") + stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);") + + dispatch_code = [] + for index in keys: + dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""") + dispatch_code.append( + f"""\t{generate_static_dispatch_backend_call(sig, f, index)};""" + ) + + fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices) + connector = "\n\t\t" + + return f""" + {connector.join(stmts)} + switch (_dk) {{ + {connector.join(dispatch_code)} + default: + {fallback} + }} + """ + + +# Generates RegisterSchema.cpp. 
Depending on the selector, either +# all schemas are registered, or only some are (in the case of +# selective build) +@dataclass(frozen=True) +class RegisterSchema: + selector: SelectiveBuilder + known_tags: Dict[str, int] = field(default_factory=dict) + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + if not self.selector.is_native_function_selected(f): + return None + tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" + if tags == "{}": + return f"m.def({cpp_string(str(f.func))}, {{}});\n" + maybe_tags = "" + if tags not in self.known_tags: + idx = len(self.known_tags) + self.known_tags[tags] = idx + maybe_tags = f"const std::vector tags_{idx} = {tags};\n" + return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n" + + +# Generates Operators.h and Operators.cpp. +# These provide macros that, given an operator and overload name, allow users +# to access an "un-overloaded" function version of the operator. This +# is useful for extension writers who want to (1) want to decltype the operator +# and (2) don't want to worry about method-only operators. +@dataclass(frozen=True) +class ComputeOperators: + target: Literal[Target.DECLARATION, Target.DEFINITION] + static_dispatch_backend_indices: List[BackendIndex] + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str: + sig = DispatcherSignature.from_schema(f.func) + name = f.func.name.unambiguous_name() + + if self.target is Target.DECLARATION: + # Note [The ATen Operators API] + # The ATen Operators API lives in the at::_ops namespace, and contains compile-time + # metadata about each operator + entry points into the Dispatcher. + # The C++ function, method, and redispatch API's are all implemented as wrappers + # into various bits of the structs defined here. + # + # Important characteristics about the Operators API: + # (1) It follows the Dispatcher API. 
+ # This is kind of necessary to avoid overhead. + # For example: if it followed the C++ API, then all of the faithful C++ factory functions + # would need to wrap their arguments into TensorOptions only to unwrap them again. + # (2) Overload names are disambiguated. + # This is helpful for pytorch extenders who would like to decltype() an aten operator, + # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call) + # (3) No argument defaulting is allowed. + # This is more of an implementation detail to avoid #include cycles, + # since TensorBody.h (which defines the Tensor class) needs to include this file. + # (4) manual_cpp_bindings and faithful names are not included in the API. + # This applies to stuff like __dispatch__is_complex(), and add_outf(). + # These aren't "real aten ops", they're just additional functions provided by the C++ API. + # They're implemented as wrappers in Functions.h that call into the actual operators + # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call(). + # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher. 
+ return f""" +struct TORCH_API {name} {{ + using schema = {sig.type()}; + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))}) + static {sig.defn(name="call", is_redispatching_fn=False)}; + static {sig.defn(name="redispatch", is_redispatching_fn=True)}; +}};""" + + elif self.target is Target.DEFINITION: + defns = f""" +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}") +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}") +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))}) + +// aten::{f.func} +static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{ + return c10::Dispatcher::singleton() + .findSchemaOrThrow({name}::name, {name}::overload_name) + .typed<{name}::schema>(); +}} +""" + for is_redispatching_fn in [False, True]: + if is_redispatching_fn: + dispatcher_exprs_str = ", ".join( + ["dispatchKeySet"] + [a.name for a in sig.arguments()] + ) + method_base = "redispatch" + else: + dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()]) + method_base = "call" + + dispatcher_call = method_base + method_name = f"{name}::{method_base}" + + fn_body = f""" + static auto op = create_{name}_typed_handle(); + return op.{dispatcher_call}({dispatcher_exprs_str});""" + + if ( + not is_redispatching_fn + and len(self.static_dispatch_backend_indices) > 0 + ): + # call() should go through static dispatch + fn_body = static_dispatch( + sig, f, backend_indices=self.static_dispatch_backend_indices + ) + defns += f""" +// aten::{f.func} +{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{ + {fn_body} +}} +""" + return 
defns + else: + assert_never(self.target) + + +# Generates Functions.h, which provides the functional public C++ API, +# and the scaffolding to call into the dispatcher from these functions. +@dataclass(frozen=True) +class ComputeFunction: + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ) + has_symint = f.func.has_symint() + + result = "" + for sig in sig_group.signatures(): + # See Note [The ATen Operators API] + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments()) + exprs_str = ", ".join([e.expr for e in exprs]) + + if sig.symint: + intlike_t = "c10::SymInt" + else: + intlike_t = "int64_t" + + if Variant.function in f.variants: + result += f""" +// aten::{f.func} +inline {sig.decl()} {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); +}}""" + + # The template function can be used from template situations + # where you want to switch between the symint or not version + # depending on a template argument + # + # NB: we ALWAYS generate this even for methods. But we put it in + # this header so it can take advantage of per-op headers + if has_symint: + result += f""" +namespace symint {{ + template ::value>> + {sig.decl(suppress_symint_suffix=True)} {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); + }} +}} +""" + return result + + +# Generates TensorBody.h. This file provides the object-oriented (method-based) +# public C++ API, and the scaffolding to call into the dispatcher from these functions. 
+@dataclass(frozen=True) +class ComputeTensorMethod: + target: Literal[Target.DECLARATION, Target.DEFINITION] + static_dispatch_backend_indices: List[BackendIndex] + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + if Variant.method not in f.variants: + return None + + assert not f.func.is_out_fn() + assert f.func.arguments.self_arg is not None + + sig_group = CppSignatureGroup.from_native_function( + f, method=True, fallback_binding=f.manual_cpp_binding + ) + + if self.target is Target.DECLARATION: + result = "" + for sig in sig_group.signatures(): + result += f"{sig.decl()} const;\n" + return result + + if self.target is not Target.DEFINITION: + assert_never(self.target) + + result = "" + + for sig in sig_group.signatures(): + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments(), method=True) + exprs_str = ", ".join([e.expr for e in exprs]) + + result += f""" +// aten::{f.func} +inline {sig.defn(prefix="Tensor::")} const {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); +}} +""" + + return result + + +# Generates RedispatchFunctions.h. +# This is similar to the C++ API defined in Functions.h, but provides access +# to the dispatcher's redispatch API. +@dataclass(frozen=True) +class ComputeRedispatchFunction: + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + # We unconditionally generate function variants of the redispatch API. 
+ # This is mainly because we can namespace functions separately, but not methods, + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ) + + result = "" + for sig in sig_group.signatures(): + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments()) + exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs]) + + result += f""" +// aten::{f.func} +inline {sig.decl(is_redispatching_fn=True)} {{ + return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str}); +}} +""" + + return result + + +# Generates ATenOpList.cpp, a runtime accessible list of all aten +# operators. +# TODO: This was historically used to help some JIT interop code +# figure out whether or not to treat aten namespace'd operators +# one way or another, we should reevaluate if this is actually needed. +@with_native_function +def compute_aten_op(f: NativeFunction) -> str: + return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},' + + +# Generates MetaFunctions.h +def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]: + if not g.structured: + return None + with native_function_manager(g.out): + name = meta.name(g) + args = structured.meta_arguments(g) + args_str = ", ".join(a.decl() for a in args) + parent_class = g.out.structured_inherits + if parent_class is None: + parent_class = "at::impl::MetaBase" + meta_return = "void" + precomputed = g.out.precomputed if g.structured else None + + if precomputed: + # Generate the template declaration with one bool parameter for each + # precomputed element. Each parameter is true if the corresponding (in + # terms of position) precomputed element has been set. 
+ precomputed_values = [*precomputed.replace.values(), precomputed.add] + precomputed_elements = [ + elem for replace_list in precomputed_values for elem in replace_list + ] + precomputed_template_parameters = [ + elem.name.upper() for elem in precomputed_elements + ] + precomputed_template_params_str = ", ".join( + f"bool {param} = false" for param in precomputed_template_parameters + ) + precompute_template_decl = f"template <{precomputed_template_params_str}>" + + # Generate a string containing declarations of all precomputed elements. + precomputed_elements_with_cpp_types = [ + structured.argument_type(elem, binds=elem.name) + for elem in precomputed_elements + ] + + precomputed_elements_decl = ";\n".join( + f"{elem.cpp_type(strip_ref=True)} {elem.name}" + for elem in precomputed_elements_with_cpp_types + ) + + # Generate "setter" methods for each precomputed element. Each method will return + # a new instance of precompute_out with the template parameter that corresponds to + # the member set by the method to true (to indicate that it has been set). + setter_methods = [] + for i, elem in enumerate(precomputed_elements): + # Generate the signature. The return type will be the same + # as the type of `this` but with the template parameter + # corresponding to the element set by this method set to true. + # The assert generated below will ensure that this template + # parameter is false on the type of `this`. + return_ty_templates = ", ".join( + precomputed_template_parameters[:i] + + ["true"] + + precomputed_template_parameters[i + 1 :] + ) + return_ty = f"precompute_out<{return_ty_templates}>" + elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type( + strip_ref=True + ) + signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)" + + # Generate an assert which checks that the + # template parameter corresponding to the precomputed + # element that is set by this method is false on the + # class corresponding to the object that `this` points to. 
+ # This ensures that each element can be set only once. + assert_msg = f'"{precomputed_elements[i].name} already set"' + assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});" + + # Generate the new object construction block. All state + # except the element that this method sets is copied from the + # object that `this` points to. The value for the element that + # the method sets is taken from a method parameter. + construction_stmts = [] + construction_stmts.append(f"{return_ty} ret;") + + for j, elem in enumerate(precomputed_elements): + if i == j: + construction_stmts.append(f"ret.{elem.name} = value;") + else: + construction_stmts.append( + f"ret.{elem.name} = this->{elem.name};" + ) + + construction_stmts.append("return ret;") + construction_block = "\n".join(construction_stmts) + + setter_methods.append( + f""" + {signature} {{ + {assert_stmt} + {construction_block} + }} + """ + ) + setter_methods_decl = "\n".join(setter_methods) + + # Meta should return an instance of the struct containing the precomputed elements. + meta_return_template_params = ", ".join( + ["true"] * len(precomputed_template_parameters) + ) + # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return + # type (which has a variable number of template parameters). 
+ meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;" + meta_return = "meta_return_ty" + precomputed_decl = f""" + {precompute_template_decl} + struct TORCH_API precompute_out {{ + {setter_methods_decl} + {precomputed_elements_decl}; + }};""" + else: + meta_return_typedef = "" + precomputed_decl = "" + + return f"""\ +struct TORCH_API structured_{name} : public {parent_class} {{ + {precomputed_decl} + {meta_return_typedef} + {meta_return} meta({args_str}); +}}; +""" + + +def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool: + name = str(f.func.name.name) + if name.endswith("_like") or name.startswith("new_"): + return False + if f.func.arguments.tensor_options is None: + return False + return selector.is_native_function_selected(f) + + +# Generates RegisterBackendSelect.cpp, a series of kernels which provide +# specialized computation of dispatch key for operator signatures which cannot +# be easily done automatically using templating. +@dataclass(frozen=True) +class ComputeBackendSelect: + target: Literal[Target.DEFINITION, Target.REGISTRATION] + + # Selector object to determine which operators to generate + # registration code for. 
+ selector: SelectiveBuilder + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + if not needs_backend_select(f, self.selector): + return None + + name = native.name(f.func) + # BackendSelect can go to Meta, so it must preserve symints + native_sig = NativeSignature(f.func, symint=True) + + native_tensor_args = [ + a + for a in native_sig.arguments() + if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like() + ] + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + + sig: Union[NativeSignature, DispatcherSignature] + sig = dispatcher_sig + dispatcher_exprs = dispatcher_sig.exprs() + dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" + + if self.target is Target.DEFINITION: + # I don't think there's actually a good reason to generate + # these two cases differently + # The first case could probably be improved though- it calls computeDispatchKeySet(), + # which looks at TLS dispatch keys- there should not be any by the time we reach backend select. 
+ if native_tensor_args: + assert f.func.arguments.has_tensor_arg() + tensor_args = ", ".join(a.name for a in native_tensor_args) + compute_dk = f"""\ +DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args}); +DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect); +DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);""" + else: + assert not f.func.arguments.has_tensor_arg() + compute_dk = ( + f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});" + ) + return f"""\ +// aten::{f.func} +C10_ALWAYS_INLINE +{sig.defn(name)} {{ + {compute_dk} + return at::_ops::{f.func.name.unambiguous_name()}::redispatch( + _dk, {', '.join(a.expr for a in dispatcher_exprs)}); +}} +""" + elif self.target is Target.REGISTRATION: + return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" + else: + assert_never(self.target) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# YAML CODE GENERATION +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def format_yaml(data: object) -> str: + # Ignore alias in Dumper + YamlDumper.ignore_aliases = lambda self, data: True # type: ignore[assignment] + + # Support serializing OrderedDict + def dict_representer(dumper: Any, data: Any) -> Any: + return dumper.represent_dict(data.items()) + + YamlDumper.add_representer(OrderedDict, dict_representer) # type: ignore[no-untyped-call] + # Some yaml parsers (e.g. Haskell's) don't understand line breaks. + # width=1e9 turns off optional line breaks and improves + # the portability of the outputted yaml. + return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9) # type: ignore[no-any-return, call-overload] + + +# For some reason, some defaults we write to YAML are written as native +# YAML objects, rather than doing them uniformly as strings. 
This +# function detects those cases and converts them into native Python +# objects. +def pythonify_default(s: str) -> object: + if s == "true": + return True + elif s == "false": + return False + + try: + return int(s) + except ValueError: + try: + return float(s) + except ValueError: + return s + + +# What is a dynamic type? Over time, the semantic meaning of +# dynamic type has degraded to meaninglessness (in the old days, +# it captured dtype-ness of types, but that has gone away with +# the removal of TH). These days, it's mostly the same thing as +# the C++ API argument type, except that Tensor and Tensor? +# arguments simply present as Tensor. +# +# TODO: Get rid of dynamic_type, after getting tools/autograd +# to use the new codegen framework +def dynamic_type(t: Type) -> str: + if isinstance(t, OptionalType): + return dynamic_type(t.elem) + # Note we don't use t.is_tensor_like() here because it would + # also include Tensor[] + if str(t) == "Tensor": + return "at::Tensor" + # This is a legacy concept, so never report SymInt + return cpp.argumenttype_type( + t, mutable=False, binds="__placeholder__", symint=False + ).cpp_type() + + +def compute_method_of_yaml(variants: Set[Variant]) -> List[str]: + # This is written out explicitly to ensure that Tensor and + # namespace are put into the list in the right order + method_of = ["Type"] + if Variant.method in variants: + method_of.append("Tensor") + if Variant.function in variants: + method_of.append("namespace") + return method_of + + +def compute_returns_yaml( + f: NativeFunction, +) -> Tuple[List[Dict[str, str]], Dict[str, str]]: + # Note [name and field_name] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~ + # To understand name_to_field_name, we must first talk about this + # schema: + # + # lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) 
QR) + # + # There is something very odd about this schema: it is an out + # variant of the function (that is to say, it will convert into + # at::lstsq_out() in the C++ API), but the names of the output + # return arguments don't match the keyword argument names of + # the inputs. It TURNS OUT that in this situation, the historical + # Declarations.yaml we want to output is this (abbreviated to + # only show relevant fields): + # + # arguments: + # ... + # - field_name: solution + # name: X + # - field_name: QR + # name: qr + # ... + # + # returns: + # - field_name: solution + # name: X + # - field_name: QR + # name: qr + # + # The name of the return fields is stored in 'field_name', and the + # name of the arguments is stored in 'name'. So when we process + # arguments, we need a way to get at the corresponding return. At + # the moment, this is most conveniently done by constructing a + # mapping from name (the argument concept) to field_name (the + # return concept) while processing return arguments, since we don't + # directly maintain this correspondence in the modeling of function + # schema itself. 
+ # + # See also https://github.com/pytorch/pytorch/issues/43114 + name_to_field_name: Dict[str, str] = {} + + # Compute the returns field of the YAML entry + names = cpp.return_names(f) + returns = [] + for i, (r, name) in enumerate(zip(f.func.returns, names)): + ret = { + "dynamic_type": dynamic_type(r.type), + "name": name, + # legacy, report ints + "type": cpp.return_type(r, symint=False).cpp_type(), + } + + if r.name: + # See Note [name and field_name] + ret["field_name"] = r.name + if f.func.is_out_fn(): + name_to_field_name[f.func.arguments.out[i].name] = r.name + + returns.append(ret) + + return returns, name_to_field_name + + +# arguments in yaml roughly corresponds to the public C++ API +def compute_cpp_argument_yaml( + cpp_a: Binding, + *, + schema_order: bool, + kwarg_only_set: Set[str], + out_arg_set: Set[str], + name_to_field_name: Dict[str, str], +) -> object: + if isinstance(cpp_a.argument, TensorOptionsArguments): + arg: Dict[str, object] = { + "annotation": None, + "dynamic_type": "at::TensorOptions", + "is_nullable": False, + "name": cpp_a.name, + "type": cpp_a.type, + "kwarg_only": True, + } + if cpp_a.default is not None: + arg["default"] = cpp_a.default + return arg + elif isinstance(cpp_a.argument, SelfArgument): + raise AssertionError() + elif isinstance(cpp_a.argument, Argument): + return compute_argument_yaml( + cpp_a.argument, + schema_order=schema_order, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set, + name_to_field_name=name_to_field_name, + ) + + +def compute_argument_yaml( + a: Argument, + *, + schema_order: bool, + kwarg_only_set: Set[str], + out_arg_set: Set[str], + name_to_field_name: Dict[str, str], +) -> object: + arg: Dict[str, object] = { + "annotation": str(a.annotation) if a.annotation else None, + "dynamic_type": dynamic_type(a.type), + "is_nullable": a.type.is_nullable(), + "name": a.name, + # legacy, report ints + "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(), + } + if 
a.default is not None: + arg["default"] = pythonify_default( + cpp.default_expr(a.default, a.type, symint=False) + ) + if a.name in kwarg_only_set: + arg["kwarg_only"] = True + if a.name in out_arg_set: + arg["output"] = True + arg["allocate"] = True + # See Note [name and field_name] + if a.name in name_to_field_name: + arg["field_name"] = name_to_field_name[a.name] + # Historically, booleans don't get their size recorded, because it + # is already built into the cpp type (e.g., std::array) + l = a.type.is_list_like() + if l is not None and l.size is not None and str(l.elem) != "bool": + arg["size"] = l.size + return arg + + +@with_native_function +def compute_declaration_yaml(f: NativeFunction) -> object: + returns, name_to_field_name = compute_returns_yaml(f) + + # These sets are used to conveniently test if an argument is a + # kwarg-only or out argument + kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in f.func.arguments.out} + + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + cpp_args = sig_group.signature.arguments() + arguments = [ + compute_cpp_argument_yaml( + cpp_a, + schema_order=False, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set, + name_to_field_name=name_to_field_name, + ) + for cpp_a in cpp_args + ] + + schema_order_jit_arguments = list(f.func.schema_order_arguments()) + + schema_order_arguments = [ + compute_argument_yaml( + a, + schema_order=True, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set, + name_to_field_name=name_to_field_name, + ) + for a in schema_order_jit_arguments + ] + + cpp_schema_order_types = [ + # NB: method here doesn't matter + r.type + for a in schema_order_jit_arguments + for r in cpp.argument( + a, + method=False, + cpp_no_default_args=set(), + faithful=False, + symint=False, + has_tensor_options=False, + ) + ] + + # legacy, report ints + cpp_returns = cpp.returns_type(f.func.returns, 
symint=False).cpp_type() + schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})" + + is_factory_method = ( + any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) + and Variant.method not in f.variants + ) + + return OrderedDict( + [ + ("name", cpp.name(f.func)), + ("operator_name", str(f.func.name.name)), + ("overload_name", str(f.func.name.overload_name)), + ("manual_kernel_registration", f.manual_kernel_registration), + ( + "category_override", + f.category_override if f.category_override is not None else "", + ), + ("schema_string", f"aten::{f.func}"), + ("arguments", arguments), + ("schema_order_cpp_signature", schema_order_cpp_signature), + ("schema_order_arguments", schema_order_arguments), + ("method_of", compute_method_of_yaml(f.variants)), + ("mode", "native"), + ("python_module", "" if f.python_module is None else f.python_module), + ("returns", returns), + ("inplace", f.func.name.name.inplace), + ("is_factory_method", is_factory_method), + ("abstract", f.is_abstract), + ("device_guard", f.device_guard), + ("with_gil", False), + ("deprecated", False), + ("has_math_kernel", f.has_composite_implicit_autograd_kernel), + ] + ) + + +# See Note [Auto generated composite kernels] +def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: + return (f.structured or f.structured_delegate is not None) and ( + f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace + ) + + +@with_native_function_and_indices +def compute_registration_declarations( + f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex] +) -> str: + name = dispatcher.name(f.func) + returns_type = dispatcher.returns_type( + f.func.returns + ).cpp_type_registration_declarations() + args = dispatcher.arguments(f.func) + args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args) + comment_data: Dict[str, str] = { + "schema": f"aten::{f.func}", + # TODO: What exactly is the semantics of the 
'dispatch' field? + "dispatch": str( + {k for k, v in backend_indices.items() if v.has_kernel(f)} + != {DispatchKey.CompositeImplicitAutograd} + and {k for k, v in backend_indices.items() if v.has_kernel(f)} + != { + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + } + ), + "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)), + } + return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} +""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# RUN IT ALL +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def get_custom_build_selector( + provided_op_registration_allowlist: Optional[List[str]], + op_selection_yaml_path: Optional[str], +) -> SelectiveBuilder: + assert not ( + provided_op_registration_allowlist is not None + and op_selection_yaml_path is not None + ), ( + "Both provided_op_registration_allowlist and " + + "op_selection_yaml_path can NOT be provided at the " + + "same time." 
+ ) + + op_registration_allowlist: Optional[Set[str]] = None + if provided_op_registration_allowlist is not None: + op_registration_allowlist = set(provided_op_registration_allowlist) + + if op_registration_allowlist is not None: + selector = SelectiveBuilder.from_legacy_op_registration_allow_list( + op_registration_allowlist, + True, + False, + ) + elif op_selection_yaml_path is not None: + selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path) + else: + selector = SelectiveBuilder.get_nop_selector() + + return selector + + +def get_grouped_by_view_native_functions( + native_functions: Sequence[NativeFunction], +) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]: + def maybe_create_view_group( + d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction] + ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]: + funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = [] + if ViewSchemaKind.aliasing in d: + view = d.pop(ViewSchemaKind.aliasing) + view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) + view_copy = d.pop(SchemaKind.functional, None) + + funcs.append( + NativeFunctionsViewGroup( + view=view, + view_copy=view_copy, + view_inplace=view_inplace, + ) + ) + # Take the remaining functions that weren't part of the view group + # and emit them separately + funcs.extend(d.values()) + return funcs + + grouped_by_views: Dict[ + FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction] + ] = defaultdict(dict) + for f in native_functions: + schema = f.func.view_signature() + view_kind: ViewSchemaKind = f.view_schema_kind + # We need to group up ops relevant to the same "view", consisting of: + # view op (ViewSchemaKind.aliasing) + # view_inplace op (ViewSchemaKind.aliasing_inplace) + # view_copy op (SchemaKind.functional) + if view_kind == ViewSchemaKind.non_aliasing: + kind = f.func.kind() + assert kind not in grouped_by_views[schema] + grouped_by_views[schema][kind] = f + else: + assert view_kind not in 
grouped_by_views[schema] + grouped_by_views[schema][view_kind] = f + + return list(concatMap(maybe_create_view_group, grouped_by_views.values())) + + +def get_grouped_native_functions( + native_functions: Sequence[NativeFunction], +) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: + def flatten_pre_group( + d: Dict[SchemaKind, NativeFunction] + ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: + r = NativeFunctionsGroup.from_dict(d) + if r is None: + # Invariant: any NativeFunctions that are code-generated + # should have been grouped into NativeFunctionsGroup objects + assert not any("generated" in f.tags for f in d.values()) + return list(d.values()) + else: + return [r] + + # TODO: how come ValuesView isn't a Sequence lol + pre_grouped_native_functions = pre_group_native_functions(native_functions) + return list( + concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())) + ) + + +def get_ns_grouped_kernels( + *, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], + native_function_decl_gen: Callable[ + [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] + ] = dest.compute_native_function_declaration, +) -> Dict[str, List[str]]: + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) + for f in grouped_native_functions: + native_function_namespaces = set() + dispatch_keys = set() + for dispatch_key, backend_idx in backend_indices.items(): + backend_metadata = backend_idx.get_kernel(f) + if backend_metadata: + namespace = backend_metadata.cpp_namespace + dispatch_keys.add(dispatch_key) + native_function_namespaces.add(namespace) + else: + namespace = DEFAULT_KERNEL_NAMESPACE + assert ( + len(native_function_namespaces) <= 1 + ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" + ns_grouped_kernels[namespace].extend( + native_function_decl_gen(f, backend_idx) + ) + 
return ns_grouped_kernels + + +def get_native_function_declarations_from_ns_grouped_kernels( + *, + ns_grouped_kernels: Dict[str, List[str]], +) -> List[str]: + declarations: List[str] = [] + newline = "\n" + for namespace, kernels in ns_grouped_kernels.items(): + ns_helper = NamespaceHelper( + namespace_str=namespace, + entity_name="", + max_level=4, + ) + # Convert to a set first to remove duplicate kernel names. Backends are + # allowed to repeat kernel names; only generate the declaration once! + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + +# Return native function declarations grouped by their namespaces. +def get_native_function_declarations( + *, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], + native_function_decl_gen: Callable[ + [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] + ] = dest.compute_native_function_declaration, +) -> List[str]: + """ + Generate kernel declarations, in `NativeFunction(s).h`. + :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`. + :param backend_indices: kernel collections grouped by dispatch key. + :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`. + :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline. 
+ """ + + ns_grouped_kernels = get_ns_grouped_kernels( + grouped_native_functions=grouped_native_functions, + backend_indices=backend_indices, + native_function_decl_gen=native_function_decl_gen, + ) + return get_native_function_declarations_from_ns_grouped_kernels( + ns_grouped_kernels=ns_grouped_kernels + ) + + +def get_kernel_namespace( + *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex +) -> str: + backend_metadata = backend_idx.get_kernel(f) + assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( + f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} " + f"with dispatch key {backend_idx.dispatch_key}" + f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'." + ) + return ( + backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE + ) + + +# Return native function definitions grouped by dispatch key and custom namespace. +# Used in RegisterDispatchKey.cpp and etc. 
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> List[str]:
    """Render the kernel definitions and registrations for one dispatch key.

    For every entry of ``grouped_native_functions`` three generators built
    from ``dest.RegisterDispatchKey`` are run (namespaced definitions,
    anonymous definitions and registrations).  Their output is bucketed by
    the kernel's C++ namespace (with the ``::native`` suffix removed), and
    registrations are additionally bucketed by the operator's own namespace
    so that one ``TORCH_LIBRARY_IMPL`` block is emitted per operator
    namespace.  Each non-empty kernel namespace is then rendered through the
    ``RegisterDispatchDefinitions.ini`` template.

    :param fm: file manager used to render the template.
    :param grouped_native_functions: functions / function groups to process.
    :param dispatch_key: key written into ``TORCH_LIBRARY_IMPL`` and, lowered,
        used as ``dispatch_namespace``.
    :param backend_idx: backend index handed to the generators and used to
        look up each kernel's namespace.
    :param selector: forwarded to ``dest.RegisterDispatchKey``.
    :param rocm: forwarded to ``dest.RegisterDispatchKey``.
    :param symint: forwarded to ``dest.RegisterDispatchKey``.
    :param skip_dispatcher_op_registration: forwarded to the generators; when
        true the static registration body in the template is left empty.
    :param gen_dispatch_helpers: when true, the output of
        ``dest.gen_registration_helpers(backend_idx)`` is included.
    :return: the rendered template output of every kernel namespace, split
        into individual lines.
    """
    definitions: List[str] = []
    ns_definitions: Dict[str, List[str]] = defaultdict(list)
    anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines
    registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
    newline = "\n"

    def make_generator(target: Target) -> dest.RegisterDispatchKey:
        # The three generators differ only in their target.
        return dest.RegisterDispatchKey(
            backend_idx,
            target,
            selector,
            rocm=rocm,
            symint=symint,
            class_method_name=None,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
        )

    ns_gen = make_generator(Target.NAMESPACED_DEFINITION)
    anonymous_gen = make_generator(Target.ANONYMOUS_DEFINITION)
    reg_gen = make_generator(Target.REGISTRATION)

    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(ns_gen(f))
        anonymous_definitions[kernel_namespace].extend(anonymous_gen(f))

        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        # Only create the list for a namespace that has not been seen yet.
        # Re-assigning the whole inner mapping here would throw away the
        # registrations already collected for other operator namespaces that
        # share this kernel namespace.
        registrations[kernel_namespace].setdefault(namespace, []).extend(reg_gen(f))

    for kernel_namespace in ns_definitions:
        if not ns_definitions[kernel_namespace]:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace in registrations[kernel_namespace]:
            if not registrations[kernel_namespace][namespace]:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions


# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
+def get_namespaced_declaration( + *, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + dispatch_key: DispatchKey, + backend_idx: BackendIndex, + selector: SelectiveBuilder, + rocm: bool, + symint: bool, +) -> List[str]: + declarations: List[str] = [] + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) + newline = "\n" + func = dest.RegisterDispatchKey( + backend_idx, + Target.NAMESPACED_DECLARATION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=False, + symint=symint, + ) + for f in grouped_native_functions: + namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( + "native", dispatch_key.lower() + ) + + ns_grouped_kernels[namespace].extend( + func(f), + ) + + for namespace, kernels in ns_grouped_kernels.items(): + if len(kernels) == 0: + continue + ns_helper = NamespaceHelper( + namespace_str=namespace, entity_name="", max_level=3 + ) + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + +# Return native function schema registration code for aten and other namespaces. +def get_native_function_schema_registrations( + *, + native_functions: Sequence[NativeFunction], + schema_selector: SelectiveBuilder, +) -> Tuple[List[str], str]: + ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) + for native_function in native_functions: + ns_native_functions[native_function.namespace].append(native_function) + schema_registrations = "" + aten_schema_registrations = [] + custom_namespace = None + for namespace, funcs in ns_native_functions.items(): + schema_registrations_body = list( + mapMaybe(RegisterSchema(schema_selector), funcs) + ) + # NB: we have to separate aten namespace registration from other namespaces, + # because in the template we hardcoded an operator for ATen already. 
+ if namespace == "aten": + aten_schema_registrations = schema_registrations_body + else: + custom_namespace = namespace + tab = "\t" + # if the namespace is predefined, we should use define a library fragment + # instead of a new library + torch_library_macro = ( + "TORCH_LIBRARY_FRAGMENT" + if namespace in FRAGMENT_NAMESPACES + else "TORCH_LIBRARY" + ) + schema_registrations += f""" +{torch_library_macro}({custom_namespace}, m) {{ + {tab.join(schema_registrations_body)} +}};""" + return (aten_schema_registrations, schema_registrations) + + +def gen_aggregated_headers( + *, + native_functions: Sequence[NativeFunction], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], + static_dispatch_idx: List[BackendIndex], + selector: SelectiveBuilder, + backend_indices: Dict[DispatchKey, BackendIndex], + cpu_fm: FileManager, + cuda_fm: FileManager, + functions_keys: Set[DispatchKey], + dispatch_keys: Sequence[DispatchKey], + rocm: bool, +) -> None: + # Buck doesn't support dynamic output files, so we aggregate all operator + # headers into a single file + cpu_fm.write( + "NativeMetaFunctions.h", + lambda: { + "NativeMetaFunctions_includes": [], + "NativeMetaFunctions_declarations": list( + mapMaybe(compute_meta_function_declaration, structured_native_functions) + ), + }, + ) + method_native_functions = [ + fn for fn in native_functions if Variant.method in fn.variants + ] + non_method_native_functions = [ + fn for fn in native_functions if fn not in method_native_functions + ] + cpu_fm.write( + "MethodOperators.h", + lambda: { + "MethodOperators_includes": [], + "MethodOperators_declarations": list( + mapMaybe( + ComputeOperators( + Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + method_native_functions, + ) + ), + }, + ) + cpu_fm.write( + "Operators.h", + lambda: { + "Operators_includes": ["#include "], + "Operators_declarations": list( + 
mapMaybe( + ComputeOperators( + Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + non_method_native_functions, + ) + ), + }, + ) + cpu_fm.write( + "Functions.h", + lambda: { + "static_dispatch_extra_headers": static_dispatch_extra_headers( + static_dispatch_idx + ), + "Functions_includes": ["#include "], + "Functions_declarations": list( + mapMaybe( + ComputeFunction(), + native_functions, + ) + ), + }, + ) + declarations = get_native_function_declarations( + grouped_native_functions=grouped_native_functions, + backend_indices=backend_indices, + ) + cpu_fm.write( + "NativeFunctions.h", + lambda: { + "NativeFunctions_includes": ["#include "], + "NativeFunctions_declarations": declarations, + }, + ) + + for dispatch_key in dispatch_keys: + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + if dispatch_key in functions_keys: + inl_headers = f"#include " + + fm.write_with_template( + f"{dispatch_key}Functions.h", + "DispatchKeyFunctions.h", + lambda: { + "dispatch_key": str(dispatch_key), + "inline_headers": inl_headers, + }, + ) + fm.write_with_template( + f"{dispatch_key}Functions_inl.h", + "DispatchKeyFunctions_inl.h", + lambda: { + "DispatchKeyFunctions_inl_includes": [], + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_declarations": get_namespaced_declaration( + grouped_native_functions=grouped_native_functions, + dispatch_key=dispatch_key, + backend_idx=backend_indices[dispatch_key], + selector=selector, + rocm=rocm, + symint=True, + ), + }, + ) + + del fm + + +def gen_per_operator_headers( + *, + native_functions: Sequence[NativeFunction], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + static_dispatch_idx: List[BackendIndex], + selector: SelectiveBuilder, + backend_indices: Dict[DispatchKey, BackendIndex], + cpu_fm: FileManager, + cuda_fm: FileManager, + ops_fm: FileManager, + functions_keys: Set[DispatchKey], + dispatch_keys: Sequence[DispatchKey], + rocm: 
bool, +) -> None: + # For CMake builds, split operator declarations into separate headers in + # the ATen/ops folder to split up header dependencies + functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list) + for fn in native_functions: + functions_by_root_name[fn.root_name].append(fn) + + grouped_functions_by_root_name: Dict[ + str, List[Union[NativeFunction, NativeFunctionsGroup]] + ] = defaultdict(list) + for group in grouped_native_functions: + name = group.root_name + grouped_functions_by_root_name[name].append(group) + + for name, functions in functions_by_root_name.items(): + ops_fm.write_with_template( + f"{name}_ops.h", + "Operator.h", + lambda: { + "declarations": list( + mapMaybe( + ComputeOperators( + Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + functions, + ) + ), + }, + ) + + ops_fm.write_with_template( + f"{name}.h", + "Function.h", + lambda: { + "static_dispatch_ops_headers": list( + mapMaybe( + lambda fn: static_dispatch_ops_header( + fn, backend_index=static_dispatch_idx + ), + functions, + ) + ), + "operator_includes": f"#include ", + "function_definitions": list( + mapMaybe( + ComputeFunction(), + functions, + ) + ), + }, + ) + + grouped_functions = grouped_functions_by_root_name.get(name, []) + structured_functions = [ + fn + for fn in grouped_functions + if isinstance(fn, NativeFunctionsGroup) and fn.structured + ] + is_structured = len(structured_functions) > 0 + + if is_structured: + ops_fm.write_with_template( + f"{name}_meta.h", + "NativeMetaFunction.h", + lambda: { + "meta_function_declarations": list( + mapMaybe( + compute_meta_function_declaration, structured_functions + ) + ), + }, + ) + declarations = get_native_function_declarations( + grouped_native_functions=grouped_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ) + ops_fm.write_with_template( + f"{name}_native.h", + "NativeFunction.h", + lambda: { + 
"extra_includes": ( + f"#include " if is_structured else [] + ), + "native_function_declarations": declarations, + }, + ) + + for category, suffix in [ + ("Functions", ""), + ("Operators", "_ops"), + ("NativeMetaFunctions", "_meta"), + ("NativeFunctions", "_native"), + ]: + cpu_fm.write( + f"{category}.h", + lambda: { + f"{category}_includes": [ + f"#include " + for name in sorted(functions_by_root_name.keys()) + ], + f"{category}_declarations": [], + }, + ) + + for dispatch_key in dispatch_keys: + if dispatch_key not in functions_keys: + continue + + dispatch_namespace = dispatch_key.lower() + dispatch_names = [] + + for name, functions in functions_by_root_name.items(): + grouped_functions = grouped_functions_by_root_name.get(name, []) + declarations = list( + concatMap( + dest.RegisterDispatchKey( + backend_indices[dispatch_key], + Target.NAMESPACED_DECLARATION, + selector, + rocm=rocm, + symint=True, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + grouped_functions, + ) + ) + + if len(declarations) == 0: + continue + + dispatch_names.append(name) + ops_fm.write_with_template( + f"{name}_{dispatch_namespace}_dispatch.h", + "DispatchKeyFunction.h", + lambda: { + "dispatch_namespace": dispatch_namespace, + "dispatch_namespaced_declarations": declarations, + }, + ) + + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + inl_headers = f"#include " + + fm.write_with_template( + f"{dispatch_key}Functions.h", + "DispatchKeyFunctions.h", + lambda: { + "dispatch_key": str(dispatch_key), + "inline_headers": inl_headers, + }, + ) + fm.write_with_template( + f"{dispatch_key}Functions_inl.h", + "DispatchKeyFunctions_inl.h", + lambda: { + "dispatch_namespace": dispatch_namespace, + "DispatchKeyFunctions_inl_includes": [ + f"#include " + for name in sorted(dispatch_names) + ], + "dispatch_namespaced_declarations": [], + }, + ) + del fm + + cpu_fm.write( + "MethodOperators.h", + lambda: { + "MethodOperators_includes": sorted( + f"#include 
" + for name, functions in functions_by_root_name.items() + if any(Variant.method in fn.variants for fn in functions) + ), + "MethodOperators_declarations": [], + }, + ) + + +def gen_headers( + *, + native_functions: Sequence[NativeFunction], + valid_tags: Set[str], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], + static_dispatch_idx: List[BackendIndex], + selector: SelectiveBuilder, + backend_indices: Dict[DispatchKey, BackendIndex], + core_fm: FileManager, + cpu_fm: FileManager, + cuda_fm: FileManager, + ops_fm: FileManager, + dispatch_keys: Sequence[DispatchKey], + functions_keys: Set[DispatchKey], + rocm: bool, + per_operator_headers: bool, +) -> None: + if per_operator_headers: + gen_per_operator_headers( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + ops_fm=ops_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=rocm, + ) + else: + gen_aggregated_headers( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=rocm, + ) + + core_fm.write( + "TensorBody.h", + lambda: { + "tensor_method_declarations": list( + mapMaybe( + ComputeTensorMethod( + target=Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + native_functions, + ) + ), + "tensor_method_definitions": list( + mapMaybe( + ComputeTensorMethod( + target=Target.DEFINITION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + native_functions, + ) + ), + }, + ) + + cpu_fm.write( + 
"RedispatchFunctions.h", + lambda: { + "function_redispatch_definitions": list( + mapMaybe(ComputeRedispatchFunction(), native_functions) + ), + }, + ) + + cpu_fm.write( + "RegistrationDeclarations.h", + lambda: { + "registration_declarations": [ + compute_registration_declarations(f, backend_indices) + for f in native_functions + ], + }, + ) + + cpu_fm.write( + "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions) + ) + + def gen_aten_interned_strings() -> Dict[str, str]: + attrs = set() # All function argument names + names = set() # All ATen function names + for func in native_functions: + names.add(str(func.func.name.name)) + # Some operators don't have a functional variant but we still create a + # symbol without the underscore + names.add(func.func.name.name.base) + + for arg in func.func.schema_order_arguments(): + attrs.add(arg.name) + + # These are keywords in C++, so aren't valid symbol names + # https://en.cppreference.com/w/cpp/language/operator_alternative + names -= { + "and", + "and_eq", + "bitand", + "bitor", + "compl", + "not", + "not_eq", + "or", + "or_eq", + "xor", + "xor_eq", + } + + return { + "aten_symbols": " \\\n".join( + [f"_(aten, {name})" for name in sorted(names)] + ), + "attr_symbols": " \\\n".join( + [f"_(attr, {name})" for name in sorted(attrs)] + ), + } + + core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) + + def gen_tags_enum() -> Dict[str, str]: + return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} + + core_fm.write("enum_tag.h", gen_tags_enum) + + +def gen_source_files( + *, + native_functions: Sequence[NativeFunction], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + structured_native_functions: Sequence[NativeFunctionsGroup], + view_groups: Sequence[NativeFunctionsViewGroup], + selector: SelectiveBuilder, + static_dispatch_idx: List[BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], + aoti_fm: FileManager, + core_fm: 
FileManager, + cpu_fm: FileManager, + cpu_vec_fm: FileManager, + cuda_fm: FileManager, + dispatch_keys: Sequence[DispatchKey], + functions_keys: Set[DispatchKey], + rocm: bool, + force_schema_registration: bool, + per_operator_headers: bool, + skip_dispatcher_op_registration: bool, +) -> None: + extra_cuda_headers = """\ +#include +#include +#include +#include """ + if rocm: + extra_cuda_headers = """\ +#include +#include +#include +#include """ + + for dispatch_key in dispatch_keys: + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + + if per_operator_headers: + + def operator_headers() -> List[str]: + headers = [] + for g in grouped_native_functions: + is_registered = False + if backend_index.has_kernel(g): + is_registered = True + # The above has_kernel test on a group will only test for + # the existence of out dispatch, because that's how + # structured kernels work. But sometimes functions can be + # grouped but not be structured, and then you need to check + # each individual piece, as they may have manual dispatch + # entries. + elif isinstance(g, NativeFunctionsGroup) and any( + backend_index.has_kernel(fn) for fn in g.functions() + ): + is_registered = True + # TODO: this condition is a bit questionable + # (It has to do with the fact that structured kernels get generated kernels + # to the Meta + CompositeExplicitAutogradNonFunctional keys). 
+ elif g.structured and dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ): + is_registered = True + if not is_registered: + continue + + headers.append(f"#include ") + if ( + dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + headers.append(f"#include ") + if dispatch_key in functions_keys: + headers.append( + f"#include " + ) + + return sorted(set(headers)) + + else: + + def operator_headers() -> List[str]: + headers = ["#include "] + if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: + headers.append("#include ") + if dispatch_key in functions_keys: + headers.append(f"#include ") + return headers + + backend_index = backend_indices[dispatch_key] + ns_grouped_native_functions = defaultdict(list) + for grouped_native_function in grouped_native_functions: + namespace = ( + grouped_native_function.namespace + if isinstance(grouped_native_function, NativeFunction) + else grouped_native_function.functional.namespace + ) + ns_grouped_native_functions[namespace].append(grouped_native_function) + + dispatch_namespace = str(dispatch_key).lower() + + # CompositeImplicitAutogradNestdTensor does not currently user the helpers generated + # compilation will fail when `-Werror=unused-function` flag is set + gen_dispatch_helpers: bool = ( + dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor + ) + + dispatch_definitions = get_native_function_definitions( + fm=fm, + grouped_native_functions=grouped_native_functions, + dispatch_key=dispatch_key, + backend_idx=backend_index, + selector=selector, + rocm=rocm, + symint=True, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + gen_dispatch_helpers=gen_dispatch_helpers, + ) + fm.write_with_template( + f"Register{dispatch_key}.cpp", + "RegisterDispatchKey.cpp", + lambda: { + "extra_cuda_headers": extra_cuda_headers + if is_cuda_dispatch_key(dispatch_key) + else "", + "external_backend_headers": "", + 
"dispatch_headers": dest.gen_registration_headers( + backend_index, per_operator_headers, rocm + ), + "ops_headers": operator_headers(), + "dispatch_helpers": "", + "dispatch_definitions": dispatch_definitions, + }, + ) + + for g in structured_native_functions: + if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key): + continue + name = g.functional.func.name.name + if dispatch_key is DispatchKey.CPU: + assert fm is cpu_fm + fm.write_with_template( + f"UfuncCPU_{name}.cpp", + "UfuncCPU.cpp", + lambda: { + "meta_declaration": compute_meta_function_declaration(g), + "native_declaration": dest.compute_native_function_declaration( + g, backend_indices[dispatch_key] + ), + "native_definitions": dest.compute_ufunc_cpu(g), + }, + ) + cpu_vec_fm.write_with_template( + f"UfuncCPUKernel_{name}.cpp", + "UfuncCPUKernel.cpp", + lambda: { + "name": name, + "native_definitions": dest.compute_ufunc_cpu_kernel(g), + }, + ) + elif dispatch_key is DispatchKey.CUDA: + cuda_headers = "#include " + if rocm: + cuda_headers = "#include " + fm.write_with_template( + f"UfuncCUDA_{name}.cu", + "UfuncCUDA.cu", + lambda: { + "name": name, + "cuda_headers": cuda_headers, + "meta_declaration": compute_meta_function_declaration(g), + "native_declaration": dest.compute_native_function_declaration( + g, backend_indices[dispatch_key] + ), + "native_definitions": dest.compute_ufunc_cuda(g), + }, + ) + else: + raise AssertionError(f"unrecognized {dispatch_key} for ufunc") + + if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA): + + def get_header( + f: NativeFunction, + ) -> Optional[str]: + backend_index = get_backend_index_for_aoti( + f, dispatch_key, backend_indices + ) + return ( + None + if backend_index is None + else f"#include " + ) + + def headers_for_aoti() -> str: + headers = [] + for g in grouped_native_functions: + if isinstance(g, NativeFunctionsGroup): + for f in g.functions(): + # some variants are registered in the backend, but some are registered as 
CompositeExplicitAutograd + header = get_header(f) + if header is not None: + headers.append(header) + else: + header = get_header(g) + if header is not None: + headers.append(header) + return "\n".join(sorted(set(headers))) + + extra_headers = ( + extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" + ) + + aoti_fm.write( + f"c_shim_{dispatch_key.lower()}.h", + lambda: gen_aoti_c_shim( + native_functions, + dispatch_key, + backend_indices, + header=True, + includes="", + ), + ) + aoti_fm.write( + f"c_shim_{dispatch_key.lower()}.cpp", + lambda: gen_aoti_c_shim( + native_functions, + dispatch_key, + backend_indices, + header=False, + includes=headers_for_aoti() + "\n" + extra_headers, + ), + ) + + del fm + + # BackendSelect is generated specially + def gen_backend_select() -> Dict[str, List[str]]: + relevant_fns = [ + fn for fn in native_functions if needs_backend_select(fn, selector) + ] + return { + "ops_headers": [ + f"#include " for fn in relevant_fns + ], + "backend_select_method_definitions": list( + mapMaybe( + ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns + ) + ), + "backend_select_function_registrations": list( + mapMaybe( + ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns + ) + ), + } + + cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select) + + schema_selector = selector + if force_schema_registration: + schema_selector = SelectiveBuilder.get_nop_selector() + + ( + aten_schema_registrations, + schema_registrations, + ) = get_native_function_schema_registrations( + native_functions=native_functions, schema_selector=schema_selector + ) + cpu_fm.write( + "RegisterSchema.cpp", + lambda: { + "aten_schema_registrations": [] + if skip_dispatcher_op_registration + else aten_schema_registrations, + "schema_registrations": [] + if skip_dispatcher_op_registration + else schema_registrations, + }, + ) + + def key_func( + fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ) -> str: + 
return fn.root_name + + cpu_fm.write_sharded( + "Operators.cpp", + native_functions, + key_fn=key_func, + env_callable=lambda fn: { + "operator_headers": [f"#include "], + "definitions": [ + ComputeOperators( + Target.DEFINITION, + static_dispatch_backend_indices=static_dispatch_idx, + )(fn) + ], + }, + base_env={ + "static_dispatch_extra_headers": static_dispatch_extra_headers( + static_dispatch_idx + ), + }, + num_shards=5, + sharded_keys={ + "operator_headers", + "definitions", + "static_dispatch_extra_headers", + }, + ) + + cpu_fm.write("Functions.cpp", dict) + + core_fm.write("TensorMethods.cpp", dict) + + core_fm.write( + "ATenOpList.cpp", + lambda: { + "aten_ops": list(mapMaybe(compute_aten_op, native_functions)), + }, + ) + + def functionalization_env_callable( + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ) -> Dict[str, List[str]]: + def gen_op_headers( + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ) -> List[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + + return { + "ops_headers": gen_op_headers(g), + "func_definitions": gen_functionalization_definition( + selector, + g, + ), + "func_registrations": gen_functionalization_registration( + selector, + g, + backend_indices[DispatchKey.CompositeImplicitAutograd], + ), + } + + all_groups: List[ + Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ] = 
list(structured_native_functions) + list( + view_groups # type: ignore[assignment, arg-type, operator] + ) + # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly. + # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because: + # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) + # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. + # Although this could go away long-term if we add a dedicated dispatch key for decompositions. + structured_map: Dict[OperatorName, NativeFunction] = { + f.func.name: f + for f in concatMap(lambda g: list(g.functions()), structured_native_functions) + } + view_map: Dict[OperatorName, NativeFunction] = { + f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) + } + for f in native_functions: + if f.func.name not in structured_map and f.func.name not in view_map: + all_groups.append(f) + + cpu_fm.write_sharded( + "RegisterFunctionalization.cpp", + all_groups, + key_fn=key_func, + env_callable=functionalization_env_callable, + num_shards=4, + sharded_keys={ + "ops_headers", + "func_definitions", + "func_registrations", + "func_add_back_views_definitions", + "func_add_back_views_registrations", + }, + ) + + cpu_fm.write( + "FunctionalInverses.h", + lambda: { + "view_inverse_declarations": list( + mapMaybe( + lambda g: gen_functionalization_view_inverse_declaration( + selector, g + ), + view_groups, + ) + ) + }, + ) + + # Note [view_copy NativeFunctions] + # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd + # needs to have a corresponding non-aliasing {view}_copy variant. + # Backends that use functionalization and don't know how to handle aliasing ops + # are expected to implement kernels for these {view}_copy kernels instead. 
+ # The code for {view}_copy operators in core is pretty boilerplate-heavy however, + # so we codegen the following: + # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator. + # These are never explicitly invoked by the functionalization pass, + # but they could theoretically be called from user code (I added these kernels for completeness, + # since the ops are part of the public API). + # (2) A derivative formula for every {view}_copy operator + # {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts, + # so rather than stamping all of the entries out in derivatives.yaml, + # we codegen them in. + # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry. + cpu_fm.write( + "CompositeViewCopyKernels.cpp", + lambda: { + "ops_headers": [ + "\n".join( + f"#include \n" + # NB: this include is important as it ensures we + # set the visibility on generated view_copy kernels + # correctly + f"#include " + for f in ( + [g.view] if g.view_copy is None else [g.view, g.view_copy] + ) + ) + for g in view_groups + ] + + [ + "\n".join( + f"#include " + for f in [g.inplace, g.mutable, g.functional] + if f is not None and "generated" not in f.tags + ) + for g in structured_native_functions + ], + "CompositeViewCopyKernel_Definitions": list( + mapMaybe( + GenCompositeViewCopyKernel( + backend_indices[ + DispatchKey.CompositeExplicitAutogradNonFunctional + ] + ), + view_groups, + ) + ), + "GeneratedCompositeFunctional_Definitions": list( + mapMaybe( + gen_composite_functional_kernel, + structured_native_functions, + ) + ), + "GeneratedCompositeOut_Definitions": list( + mapMaybe( + gen_composite_out_kernel, + structured_native_functions, + ) + ), + }, + ) + + +def gen_declarations_yaml( + cpu_fm: FileManager, native_functions: Sequence[NativeFunction] +) -> None: + cpu_fm.write( + "Declarations.yaml", + lambda: format_yaml([compute_declaration_yaml(f) for f in 
native_functions]), + ) + + +def get_torchgen_root() -> pathlib.Path: + """ + If you're depending on torchgen out-of-tree, you can use the root to figure + out the path to native_functions.yaml + """ + return pathlib.Path(__file__).parent.resolve() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate ATen source files") + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for ATen", + default="aten/src/ATen", + ) + parser.add_argument( + "-o", + "--output-dependencies", + help="output a list of dependencies into the given file and exit", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="run without writing any files (still updates outputs)", + ) + parser.add_argument( + "--per-operator-headers", + action="store_true", + help="generate separate headers per operator in ATen/ops", + ) + parser.add_argument( + "-d", + "--install-dir", + "--install_dir", + help="output directory", + default="build/aten/src/ATen", + ) + parser.add_argument( + "--rocm", + action="store_true", + help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", + ) + parser.add_argument( + "--mps", + action="store_true", + help="Generate MPS registration code when set", + ) + # TODO: --op-registration-whitelist will be removed when all call-sites + # for gen.py are moved over to using the operator YAML file for mobile + # custom build. + parser.add_argument( + "--op-registration-whitelist", + "--op_registration_whitelist", + nargs="*", + help="filter op registrations by the whitelist (if set); " + "each item is `namespace`::`operator name` without overload name; " + "e.g.: aten::empty aten::conv2d ...", + ) + parser.add_argument( + "--op-selection-yaml-path", + "--op_selection_yaml_path", + help="Provide a path to the operator selection (for custom build) YAML " + "that contains the information about the set of selected operators " + "and their categories (training, ...). 
Each operator is either a " + "full operator name with overload or just a bare operator name. " + "The operator names also contain the namespace prefix (e.g. aten::)", + ) + parser.add_argument( + "--backend-whitelist", + "--backend_whitelist", + nargs="*", + help="filter dispatch backend by the whitelist (if set), " + "e.g.: CPU CUDA QuantizedCPU ...", + ) + parser.add_argument( + "--static-dispatch-backend", + "--static_dispatch_backend", + nargs="*", + help="generate static dispatch code for the specific backend (if set)", + ) + parser.add_argument( + "--skip-dispatcher-op-registration", + "--skip_dispatcher_op_registration", + action="store_true", + help="Avoid registering operators into the dispatcher.", + ) + parser.add_argument( + "--force-schema-registration", + "--force_schema_registration", + action="store_true", + help="force it to generate schema-only registrations for all ops, including" + "those that are not listed on --op-registration-whitelist", + ) + parser.add_argument( + "--generate", + type=str, + nargs="*", + choices=["headers", "sources", "declarations_yaml"], + default=["headers", "sources", "declarations_yaml"], + help="Generate only a subset of files", + ) + + options = parser.parse_args() + + selector = get_custom_build_selector( + options.op_registration_whitelist, + options.op_selection_yaml_path, + ) + + native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") + + from torchgen.model import dispatch_keys + + # TODO: stop generating CUDA kernels for non-CUDA builds + ignore_keys = set() + if not options.mps: + ignore_keys.add(DispatchKey.MPS) + + if DispatchKey.MPS in dispatch_keys: + del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] + + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) + valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] + native_functions, backend_indices = ( + 
parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + + grouped_native_functions = get_grouped_native_functions(native_functions) + + structured_native_functions = [ + g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup) + ] + native_functions_with_view_groups = get_grouped_by_view_native_functions( + native_functions + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + # NB: It is mandatory to NOT use os.path.join here, as the install directory + # will eventually be ingested by cmake, which does not respect Windows style + # path slashes. If you switch this to use os.path.join, you'll get an error + # like: + # + # Syntax error in cmake code when parsing string + # + # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h + # + # Invalid character escape '\c'. + core_install_dir = f"{options.install_dir}/core" + pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True) + ops_install_dir = f"{options.install_dir}/ops" + pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) + + core_fm = make_file_manager(options=options, install_dir=core_install_dir) + cpu_fm = make_file_manager(options=options) + cpu_vec_fm = make_file_manager(options=options) + cuda_fm = make_file_manager(options=options) + ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) + aoti_fm = make_file_manager( + options=options, install_dir="torch/csrc/inductor/aoti_torch/generated" + ) + + # Only a limited set of dispatch keys get CPUFunctions.h headers generated + # for them; this is the set + functions_keys = { + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.Meta, + } + if options.mps: + 
functions_keys.add(DispatchKey.MPS) + + if options.backend_whitelist: + dispatch_keys = [ + k + for k in dispatch_keys + if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist + ] + + static_dispatch_idx: List[BackendIndex] = [] + if options.static_dispatch_backend: + static_dispatch_idx = [ + backend_indices[DispatchKey.parse(key)] + for key in options.static_dispatch_backend + ] + for key in options.static_dispatch_backend: + dp_key = DispatchKey.parse(key) + if dp_key not in functions_keys: + functions_keys.add(dp_key) + + if "sources" in options.generate: + gen_source_files( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + view_groups=view_groups, + selector=selector, + static_dispatch_idx=static_dispatch_idx, + backend_indices=backend_indices, + aoti_fm=aoti_fm, + core_fm=core_fm, + cpu_fm=cpu_fm, + cpu_vec_fm=cpu_vec_fm, + cuda_fm=cuda_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=options.rocm, + force_schema_registration=options.force_schema_registration, + per_operator_headers=options.per_operator_headers, + skip_dispatcher_op_registration=options.skip_dispatcher_op_registration, + ) + + if "headers" in options.generate: + gen_headers( + native_functions=native_functions, + valid_tags=valid_tags, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + core_fm=core_fm, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + ops_fm=ops_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=options.rocm, + per_operator_headers=options.per_operator_headers, + ) + + if "declarations_yaml" in options.generate: + gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm) + + if options.output_dependencies: + depfile_path = 
pathlib.Path(options.output_dependencies).resolve() + depfile_name = depfile_path.name + depfile_stem = depfile_path.stem + + for fm, prefix in [ + (cpu_fm, ""), + (cpu_vec_fm, "cpu_vec_"), + (core_fm, "core_"), + (cuda_fm, "cuda_"), + (ops_fm, "ops_"), + ]: + varname = prefix + depfile_stem + path = depfile_path.parent / (prefix + depfile_name) + fm.write_outputs(varname, str(path)) + + +if __name__ == "__main__": + main() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66a5928660abcae00d42056f72b3b6100eb653e0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57bd0125f5cbdbb7a523e93be4f880b61e0bfcf7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb8d8c4620a0fd5e406464f41408c8ffee897dbd Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05962bc2482914a32f04ce2b45bfda9579550bdd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp @@ -0,0 +1,103 @@ +#include + +#include +#include +#include + +namespace at { + +Tensor TensorMaker::make_tensor() { + AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. + tracer::impl::NoTracerDispatchMode tracer_guard{}; + + check_size_nonnegative(sizes_); + + TORCH_CHECK_VALUE( + !deleter_ || !ctx_, + "The deleter and context arguments are mutually exclusive."); + + if (device_ == nullopt) { + device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); + } + + if (opts_.device().has_index()) { + // clang-format off + TORCH_CHECK_VALUE( + opts_.device() == *device_, + "Specified device ", opts_.device(), " does not match device of data ", *device_); + // clang-format on + } + + std::size_t size_bytes = computeStorageSize(); + + DataPtr data_ptr{}; + if (deleter_) { + data_ptr = makeDataPtrFromDeleter(); + } else { + data_ptr = makeDataPtrFromContext(); + } + + TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); + Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_}; + + Tensor tensor = detail::make_tensor( + std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); + + TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); + if (strides_) { + 
tensor_impl->set_sizes_and_strides(sizes_, *strides_); + } else { + tensor_impl->set_sizes_contiguous(sizes_); + } + if (storage_offset_) { + tensor_impl->set_storage_offset(*storage_offset_); + } + + return tensor; + } + + std::size_t TensorMaker::computeStorageSize() const noexcept { + std::size_t itemsize = opts_.dtype().itemsize(); + + if (strides_) { + auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; + } + + std::size_t size = 1; + for (std::int64_t s : sizes_) { + size *= static_cast(s); + } + auto storage_size = size * itemsize; + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; + } + + inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept { + return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_); + } + + inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept { + return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_}; + } + + IntArrayRef TensorMaker::makeTempSizes() const noexcept { + static std::int64_t zeros[5] = {0, 0, 0, 0, 0}; + if (opts_.has_memory_format()) { + MemoryFormat format = *opts_.memory_format_opt(); + if (format == MemoryFormat::ChannelsLast) { + return IntArrayRef(zeros, 4); + } + if (format == MemoryFormat::ChannelsLast3d) { + return IntArrayRef(zeros, 5); + } + } + return IntArrayRef(zeros, 1); + } + +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h new file mode 100644 index 0000000000000000000000000000000000000000..1ee90e66cc6cedc616baa725c2fd562a7fcfdda2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h @@ -0,0 +1,19 @@ +#pragma once + +// This 
file contains autogenerated LazyTensor IR nodes +${lazy_ir_sysinc} +${lazy_ir_inc} + +${namespace_prologue} +using at::operator<<; + +// kNullValue is used to contribute a static hash value any time +// a node has an Optional input that is nullopt. It is important +// to differentiate between HASH(nullopt, something) and HASH(something, nullopt), +// and using kNullValue in the hash function in the order of arguments +// serves this purpose. +static const torch::lazy::Value kNullValue = torch::lazy::Value(); + +${ir_declarations} + +${namespace_epilogue} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..d6d7205b5793ba91c720757cdac7168a4a16dbc0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h @@ -0,0 +1,33 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +${NativeFunctions_includes} + +${NativeFunctions_declarations} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Operator.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Operator.h new file mode 100644 index 0000000000000000000000000000000000000000..8b3989b66debc86e3782169c29a6f83fea222ac6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Operator.h @@ -0,0 +1,18 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + +${declarations} + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a1584d505f5a3c42861fde0ea5ee4da67485a32 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp @@ -0,0 +1,54 @@ +// required for old g++ to compile PRId64 macros, see +// https://github.com/pytorch/pytorch/issues/3571 +// for context +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +// an external backend might generate file within its code tree +// and check all the source files within the tree with clang-format. +// so, disable it since the backend might have a different config. 
+// clang-format off + +// NOTE: This condition is true for all PyTorch internal libraries, it +// just excludes external projects such as torch_xla which +// re-use some of the PyTorch codegen machinery. +#if defined(CAFFE2_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ + defined(TORCH_HIP_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#endif + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +$extra_cuda_headers +$external_backend_headers +$dispatch_headers +$ops_headers + +// See template file RegisterDispatchDefinitions.ini +$dispatch_definitions diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4eb587ab468d291d8d3bbe5297c1046e248c98a3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp @@ -0,0 +1,110 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +// needed for the meta tensor calls to get stride info in functionalization +#include +// needed for special handling of copy_(). 
+// See Note [functionalizating copy_() and not preserving strides] +#include +#include + +$ops_headers +#endif + +namespace at { +namespace functionalization { + +// This keyset is used by functionalization when it calls into meta kernels +// to accurately propagate stride metadata. +// Exclude any modes: the purpose of calling into meta kernels is only as an implementation +// detail to perform shape inference, and we don't want any modal keys to run. +// Specifically, we want to prevent functionalization and Python modes from running. +constexpr auto exclude_keys_for_meta_dispatch = + c10::functorch_transforms_ks | + c10::DispatchKeySet({ + c10::DispatchKey::FuncTorchDynamicLayerBackMode, + c10::DispatchKey::FuncTorchDynamicLayerFrontMode, + c10::DispatchKey::Python, + c10::DispatchKey::PreDispatch, + + }); + +// Helper around at::has_internal_overlap. +// The ATen util is used in hot-path eager mode: it's always fast, +// but might return TOO_HARD sometimes. +// During functionalization, we're ok taking a bit longer +// to detect memory overlap. 
+inline bool has_internal_overlap_helper(const at::Tensor t) { + auto has_overlap = at::has_internal_overlap(t); + if (has_overlap == at::MemOverlap::Yes) return true; + if (has_overlap == at::MemOverlap::No) return false; + return false; +} + + +inline Tensor to_meta(const Tensor& t) { + if (!t.defined()) return t; + return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(), +/*dtype=*/c10::make_optional(t.scalar_type()), /*layout=*/c10::make_optional(t.layout()), +/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt); +} + +inline c10::optional to_meta(const c10::optional& t) { + if (t.has_value()) { + return c10::make_optional(to_meta(*t)); + } + return c10::nullopt; +} + +inline std::vector to_meta(at::ITensorListRef t_list) { + std::vector outputs; + outputs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outputs.push_back(to_meta(tensor)); + } + return outputs; +} + +inline c10::List to_meta(const c10::List& t_list) { + c10::List outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + +inline c10::List> to_meta(const c10::List>& t_list) { + c10::List> outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + + +${func_definitions} + +} // namespace functionalization + +namespace { + +TORCH_LIBRARY_IMPL(aten, Functionalize, m) { + ${func_registrations}; +} + +} // namespace + +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp new file mode 100644 index 0000000000000000000000000000000000000000..029796d3e575b2bde85cfd44af9e6fcbb56466cd --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp @@ -0,0 +1,13 @@ +// ${generated_comment} +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +namespace at { +TORCH_LIBRARY(aten, m) { + ${aten_schema_registrations}; + // Distributed Ops + // Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp + m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); +} +${schema_registrations} +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0cac55664d6125287bdee0bd94c150462b81d5b9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp @@ -0,0 +1,14 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +${native_definitions} +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..326d4622334a776f4f1f94fb49a70f2c53c7e6eb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h @@ -0,0 +1,22 @@ +#pragma once + +// ${generated_comment} + +#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ 
+ is changed or added. Consider if including for \ + the c10::Symbol class would be sufficient, or if your change would be \ + better placed in another file. +#endif + +// ATen symbols correspond exactly to operators defined in ATen. Every +// symbol here corresponds exactly to an ATen operation defined in +// native_functions.yaml; attributes are in one-to-one correspondence +// with their ATen name. + +#define FORALL_ATEN_BASE_SYMBOLS(_) \ +${aten_symbols} + +#define FORALL_ATTR_BASE_SYMBOLS(_) \ +${attr_symbols} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4989d24e894993f5b7b4054246318b405fef7991 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c722b9678e1405227bfce79f0589ed93f3fd489c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be910d5fbc747899d2dcb14ea61b5b5911eb0b28 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..feb4f08bb822eb1be99c67ad4415041f3648b67a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + + +# This class holds information about a single operator used to determine +# the outcome of a selective/custom PyTorch build that doesn't include +# registration code for all the supported operators. This is done to +# reduce the size of the generated binary so that it can be deployed in +# situations where binary size comes at a premium. +# +@dataclass(frozen=True) +class SelectiveBuildOperator: + # The name of the operator. This includes the aten::, etc... prefix + # The operator name may or may not have the overload name. If this + # operator name does not specify an overload name, the way to determine + # if this entry refers to the family of operators with this base name + # or just the operator with this name is to look at the value of the + # 'include_all_overloads' flag in this class. + name: str + + # True if this is a root operator (i.e. called directly from a + # TorchScript model, etc...). 
An operator is considered to be a + # root operator if it is called directly from any one of the models + # that this instance of the pytorch library was built for. Hence, it + # may not be a root operator in all of the models that are used in + # this instance of the pytorch library. + is_root_operator: bool + + # Is this operator used for on-device training? If True, then we need to + # use the information to generate code in VariableType_N.cpp for registration + # of training related operators. Again, this is True if this operator + # is used for training in one or more models used by this instance of the + # pytorch library. + is_used_for_training: bool + + # If True, it indicates that this operator instance (object) refers to an + # operator without the overload name and should apply to all overloads + # which have this operator name as the base name. This flag is applicable + # only for objects that have operator names without a DOT (period) character + # in them. + # + # Note: This flag is a temporary workaround to grandfather in the current + # static selective (custom) build mechanism, which largely ignores overload + # names when determining whether to select operators for registration + # purposes. 
+ include_all_overloads: bool + + # Debug Information at the operator level + _debug_info: Optional[Tuple[str, ...]] + + @staticmethod + def from_yaml_dict( + op_name: str, op_info: Dict[str, object] + ) -> "SelectiveBuildOperator": + allowed_keys = { + "name", + "is_root_operator", + "is_used_for_training", + "include_all_overloads", + "debug_info", + } + + if len(set(op_info.keys()) - allowed_keys) > 0: + raise Exception( + "Got unexpected top level keys: {}".format( + ",".join(set(op_info.keys()) - allowed_keys), + ) + ) + + if "name" in op_info: + assert op_name == op_info["name"] + + is_root_operator = op_info.get("is_root_operator", True) + assert isinstance(is_root_operator, bool) + + is_used_for_training = op_info.get("is_used_for_training", True) + assert isinstance(is_used_for_training, bool) + + include_all_overloads = op_info.get("include_all_overloads", True) + assert isinstance(include_all_overloads, bool) + + debug_info: Optional[Tuple[str, ...]] = None + if "debug_info" in op_info: + di_list = op_info["debug_info"] + assert isinstance(di_list, list) + debug_info = tuple(str(x) for x in di_list) + + return SelectiveBuildOperator( + name=op_name, + is_root_operator=is_root_operator, + is_used_for_training=is_used_for_training, + include_all_overloads=include_all_overloads, + _debug_info=debug_info, + ) + + @staticmethod + def from_legacy_operator_name_without_overload( + name: str, + ) -> "SelectiveBuildOperator": + return SelectiveBuildOperator( + name=name, + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=True, + _debug_info=None, + ) + + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { + "is_root_operator": self.is_root_operator, + "is_used_for_training": self.is_used_for_training, + "include_all_overloads": self.include_all_overloads, + } + if self._debug_info is not None: + ret["debug_info"] = self._debug_info + + return ret + + +def merge_debug_info( + lhs: Optional[Tuple[str, ...]], + rhs: 
Optional[Tuple[str, ...]], +) -> Optional[Tuple[str, ...]]: + # Ensure that when merging, each entry shows up just once. + if lhs is None and rhs is None: + return None + + return tuple(set((lhs or ()) + (rhs or ()))) + + +def combine_operators( + lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator" +) -> "SelectiveBuildOperator": + if str(lhs.name) != str(rhs.name): + raise Exception( + f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" + ) + + return SelectiveBuildOperator( + name=lhs.name, + # Consider this operator to be a root operator if it is a + # root operator in any of the models used in this instance of + # the pytorch library. + is_root_operator=lhs.is_root_operator or rhs.is_root_operator, + # Consider this operator to be a training operator if it is + # an operator used for training in any of the models used + # in this instance of the pytorch library. + is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training, + include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads, + _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info), + ) + + +def merge_operator_dicts( + lhs: Dict[str, SelectiveBuildOperator], + rhs: Dict[str, SelectiveBuildOperator], +) -> Dict[str, SelectiveBuildOperator]: + operators: Dict[str, SelectiveBuildOperator] = {} + for op_name, op in list(lhs.items()) + list(rhs.items()): + new_op = op + if op_name in operators: + new_op = combine_operators(operators[op_name], op) + + operators[op_name] = new_op + + return operators + + +def strip_operator_overload_name(op_name: str) -> str: + return op_name.split(".")[0] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdc513534444d83e58d267ae4a0d7fed0d5b190 
--- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py @@ -0,0 +1,347 @@ +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple + +import yaml + +from torchgen.model import NativeFunction +from torchgen.selective_build.operator import ( + merge_debug_info, + merge_operator_dicts, + SelectiveBuildOperator, + strip_operator_overload_name, +) + + +# A SelectiveBuilder holds information extracted from the selective build +# YAML specification. +# +# It includes information about the build's selectivity, the debug_info +# associated with this selective build (opaque string), and the set of +# operators that should be included in the build. +# +@dataclass(frozen=True) +class SelectiveBuilder: + # If true, then the build is not selective, and includes all + # operators. + include_all_operators: bool + + # Debug Information at the selective/custom build level. + _debug_info: Optional[Tuple[str, ...]] + + # A dictionary of operator -> operator metadata. + operators: Dict[str, SelectiveBuildOperator] + + # A dictionary of selected kernel tags and dtypes. Typically a + # PyTorch Operator Kernel (function) may have many code paths + # that are specialized for many many Tensor dtypes, so it's not + # one per kernel function, but there could be many per kernel + # function. The tag isn't a kernel function name, but some fragment + # of the kernel function implementation itself. + kernel_metadata: Dict[str, List[str]] + + # ExecuTorch only. A dictionary of kernel tag -> list of (list of input + # dtypes for tensor-like input args). 
+ # This is from selective.yaml + et_kernel_metadata: Dict[str, List[str]] + + # A set of all the custom torch bind classes used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + custom_classes: Set[str] + + # A set of all the build features used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + build_features: Set[str] + + # If true, then fragments for all dtypes for all kernel functions + # are included as well as all custom classes. This is typically set when any one of the + # operator lists is generated from a mechanism other than + # tracing based selective build. + include_all_non_op_selectives: bool + + @staticmethod + def get_nop_selector() -> "SelectiveBuilder": + return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) + + @staticmethod + def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder": + valid_top_level_keys = { + "include_all_non_op_selectives", + "include_all_operators", + "debug_info", + "operators", + "kernel_metadata", + "et_kernel_metadata", + "custom_classes", + "build_features", + } + top_level_keys = set(data.keys()) + if len(top_level_keys - valid_top_level_keys) > 0: + raise Exception( + "Got unexpected top level keys: {}".format( + ",".join(top_level_keys - valid_top_level_keys), + ) + ) + include_all_operators = data.get("include_all_operators", False) + assert isinstance(include_all_operators, bool) + + debug_info = None + if "debug_info" in data: + di_list = data["debug_info"] + assert isinstance(di_list, list) + + debug_info = tuple(str(x) for x in di_list) + + operators = {} + operators_dict = data.get("operators", {}) + assert isinstance(operators_dict, dict) + + for k, v in operators_dict.items(): + operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v) + + kernel_metadata = {} + kernel_metadata_dict = data.get("kernel_metadata", {}) + assert 
isinstance(kernel_metadata_dict, dict) + + for k, v in kernel_metadata_dict.items(): + kernel_metadata[str(k)] = [str(dtype) for dtype in v] + + et_kernel_metadata = data.get("et_kernel_metadata", {}) + assert isinstance(et_kernel_metadata, dict) + + custom_classes = data.get("custom_classes", []) + assert isinstance(custom_classes, Iterable) + custom_classes = set(custom_classes) + + build_features = data.get("build_features", []) + assert isinstance(build_features, Iterable) + build_features = set(build_features) + + include_all_non_op_selectives = data.get("include_all_non_op_selectives", False) + assert isinstance(include_all_non_op_selectives, bool) + + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, # type: ignore[arg-type] + build_features, # type: ignore[arg-type] + include_all_non_op_selectives, + ) + + @staticmethod + def from_yaml_str(config_contents: str) -> "SelectiveBuilder": + contents = yaml.safe_load(config_contents) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_yaml_path(config_path: str) -> "SelectiveBuilder": + with open(config_path) as f: + contents = yaml.safe_load(f) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_legacy_op_registration_allow_list( + allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool + ) -> "SelectiveBuilder": + operators = {} + for op in allow_list: + operators[op] = { + "name": op, + "is_root_operator": is_root_operator, + "is_used_for_training": is_used_for_training, + "include_all_overloads": True, + } + return SelectiveBuilder.from_yaml_dict( + { + "operators": operators, + "include_all_non_op_selectives": True, + } + ) + + def is_operator_selected(self, name: str) -> bool: + if self.include_all_operators: + return True + + if name in self.operators: + return True + name = strip_operator_overload_name(name) + return name in self.operators and 
self.operators[name].include_all_overloads + + def is_native_function_selected(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected(op_name) + + def is_operator_selected_for_training(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + not_training_op = SelectiveBuildOperator( + name="", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op = not_training_op + if name in self.operators: + op = self.operators[name] + + name = strip_operator_overload_name(name) + base_op = not_training_op + if name in self.operators: + base_op = self.operators[name] + + return op.is_used_for_training or ( + base_op.include_all_overloads and base_op.is_used_for_training + ) + + def is_native_function_selected_for_training(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected_for_training(op_name) + + def is_root_operator(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + if name in self.operators: + op: SelectiveBuildOperator = self.operators[name] + return op.is_root_operator + name = strip_operator_overload_name(name) + if name not in self.operators: + return False + base_op: SelectiveBuildOperator = self.operators[name] + return base_op.include_all_overloads and base_op.is_root_operator + + def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool: + if self.include_all_operators or self.include_all_non_op_selectives: + return True + + return ( + kernel_tag in self.kernel_metadata + and dtype in self.kernel_metadata[kernel_tag] + ) + + def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]: + """ + Return a list of kernel keys that cover the used ops + """ + # If no kernel metadata, either 
it's implied by include_all_operators=True or the op is not used. + if op_name not in self.et_kernel_metadata: + return kernel_key if self.include_all_operators else [] + # Otherwise, only return the specific kernel keys. + + result_set = set() + + for model_kernel_keys in self.et_kernel_metadata[op_name]: + key_found = False + for key in kernel_key: + # Don't compare the version for now + if ( + key != "default" + and key.split("/")[1] == model_kernel_keys.split("/")[1] + ): + result_set.add(key) + key_found = True + break + if not key_found: + if "default" not in kernel_key: + raise Exception("Missing kernel for the model") + else: + result_set.add("default") + + return list(result_set) + + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { + "include_all_non_op_selectives": self.include_all_non_op_selectives, + "include_all_operators": self.include_all_operators, + } + operators = {} + for op_name, op in self.operators.items(): + operators[op_name] = op.to_dict() + ret["operators"] = operators + + if self._debug_info is not None: + ret["debug_info"] = sorted(self._debug_info) + + ret["kernel_metadata"] = { + k: sorted(v) for (k, v) in self.kernel_metadata.items() + } + + ret["et_kernel_metadata"] = self.et_kernel_metadata + + ret["custom_classes"] = sorted(self.custom_classes) + + ret["build_features"] = sorted(self.build_features) + + return ret + + +def merge_kernel_metadata( + lhs: Dict[str, List[str]], + rhs: Dict[str, List[str]], +) -> Dict[str, List[str]]: + kernel_metadata: Dict[str, List[str]] = {} + for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): + dtypes_copy = set(dtypes) + if tag_name in kernel_metadata: + dtypes_copy |= set(kernel_metadata[tag_name]) + + kernel_metadata[tag_name] = list(dtypes_copy) + + return kernel_metadata + + +def merge_et_kernel_metadata( + lhs: Dict[str, List[str]], + rhs: Dict[str, List[str]], +) -> Dict[str, List[str]]: + merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set) + for 
op in list(lhs.keys()) + list(rhs.keys()): + merge_et_kernel_metadata[op].update(lhs.get(op, [])) + merge_et_kernel_metadata[op].update(rhs.get(op, [])) + + return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()} + + +def combine_selective_builders( + lhs: SelectiveBuilder, rhs: SelectiveBuilder +) -> SelectiveBuilder: + include_all_operators = lhs.include_all_operators or rhs.include_all_operators + debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info) + operators = merge_operator_dicts(lhs.operators, rhs.operators) + kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata) + et_kernel_metadata = merge_et_kernel_metadata( + lhs.et_kernel_metadata, rhs.et_kernel_metadata + ) + include_all_non_op_selectives = ( + lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives + ) + custom_classes = lhs.custom_classes.union(rhs.custom_classes) + build_features = lhs.build_features.union(rhs.build_features) + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, + build_features, + include_all_non_op_selectives, + ) + + +def op_name_from_native_function(f: NativeFunction) -> str: + # This was originally read from the 'operator_name_with_overload' field in the + # declaration dict, which was the part before the first '(' in 'schema_string'. 
+ return f"{f.namespace}::{f.func.name}" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d667fa39c5d7452c7701cb199fc1910d6df148a8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..eb91a4985f0f114386f912da275a876c409472fb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py @@ -0,0 +1,796 @@ +import json +import logging + +import math +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torchgen.api.cpp as cpp +from torchgen.context import native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + BaseTy, + BaseType, + FunctionSchema, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.static_runtime import config + +logger: logging.Logger = logging.getLogger() + + +def has_alias( + arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]] +) -> bool: + for arg in arguments: + annotation = getattr(arg, "annotation", None) + if not annotation: + continue + alias_set = getattr(annotation, "alias_set", ()) + if alias_set: + return True + return False + + +BLOCKED_OPS = frozenset( + ( + # non cpu ops + "sparse_sampled_addmm", + "hspmm", + "linalg_svdvals", + # sparse ops + "sspaddmm", + "coalesce", + 
"_indices", + "indices", + "_values", + "values", + "crow_indices", + "col_indices", + # deprecated ops + "floor_divide", + "ger", + # buggy ops + "conj_physical", # P495807361 + "binary_cross_entropy", # P496394764 + "arccosh", + # uncommon ops + "cholesky", + "lu_solve", + "linalg_cholesky", + "linalg_householder_product", + "linalg_ldl_solve", + "_compute_linear_combination", + # training related ops + "_make_dual", + # cannot call directly + "_fw_primal", + # no documentation + "_index_reduce", + # TODO: these ones got added recently and need manual inspection + "_new_zeros_with_same_feature_meta", + "_conj_physical", + "binary_cross_entropy_with_logits", + "bincount", + "conv_tbc", + "copy", + "_copy_from", + "_copy_from_and_resize", + "count_nonzero", + "cudnn_affine_grid_generator", + "cudnn_affine_grid_generator_backward", + "cudnn_grid_sampler", + "diag_embed", + "embedding", + "embedding_dense_backward", + "_embedding_bag_dense_backward", + "_embedding_bag_per_sample_weights_backward", + "grid_sampler_2d", + "_grid_sampler_2d_cpu_fallback", + "grid_sampler_3d", + "isnan", + "mkldnn_linear", + "median", + "nanmedian", + "_sparse_sparse_matmul", + "batch_norm_backward_elemt", + "_euclidean_dist", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "_reshape_nested_backward", + "relu", + "prelu", + "celu", + "slice_scatter", + "select_scatter", + "diagonal_scatter", + "sum", + "_mkldnn_transpose", + "_nested_tensor_from_mask", + "_nested_from_padded", + "_nested_tensor_size", + "_nested_from_padded_and_nested_example", + "_standard_gamma_grad", + "_dirichlet_grad", + "native_norm", + "_sparse_softmax", + "_sparse_softmax_backward_data", + "_sparse_log_softmax", + "_sparse_log_softmax_backward_data", + "zero", + "_sparse_addmm", + "sparse_mask", + "_sparse_mask_projection", + "_to_dense", + "_coalesce", + "_coalesced", + "copy_sparse_to_sparse", + "to_sparse", + "to_sparse_csr", + "to_sparse_csc", + "to_mkldnn", + "quantize_per_tensor_dynamic", + 
"quantize_per_channel", + "q_per_channel_scales", + "q_per_channel_zero_points", + "int_repr", + "_make_per_channel_quantized_tensor", + "set", + "lift", + "lift_fresh", + "lift_fresh_copy", + "masked_scatter", + "_masked_softmax", + "_masked_softmax_backward", + "put", + "index_reduce", + "trace", + "_cholesky_solve_helper", + "dist", + "max", + "_torch_cuda_cu_linker_symbol_op", + "glu_jvp", + "glu_backward_jvp", + "hardswish_backward", + "rrelu_with_noise_backward", + "mkldnn_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool3d_backward", + "isinf", + "linalg_lu_solve", + "linalg_vecdot", + "linalg_matrix_exp", + "linalg_eigvalsh", + "_test_warn_in_autograd", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_segment_reduce", + "_segment_reduce_backward", + "_fw_primal_copy", + "_make_dual_copy", + "view_as_real_copy", + "view_as_complex_copy", + "_conj_copy", + "_neg_view_copy", + "diagonal_copy", + "detach_copy", + "squeeze_copy", + "t_copy", + "unsqueeze_copy", + "_indices_copy", + "_values_copy", + "indices_copy", + "values_copy", + "crow_indices_copy", + "col_indices_copy", + "ccol_indices", + "ccol_indices_copy", + "row_indices", + "row_indices_copy", + "unfold_copy", + "alias_copy", + "_triton_multi_head_attention", + "special_airy_ai", + "special_bessel_j0", + "special_bessel_j1", + "special_bessel_y0", + "special_bessel_y1", + "special_chebyshev_polynomial_t", + "special_chebyshev_polynomial_u", + "special_chebyshev_polynomial_v", + "special_chebyshev_polynomial_w", + "special_hermite_polynomial_h", + "special_hermite_polynomial_he", + "special_laguerre_polynomial_l", + "special_legendre_polynomial_p", + "special_modified_bessel_i0", + "special_modified_bessel_i1", + "special_modified_bessel_k0", + "special_modified_bessel_k1", + "special_scaled_modified_bessel_k0", + "special_scaled_modified_bessel_k1", + "special_shifted_chebyshev_polynomial_t", + 
"special_shifted_chebyshev_polynomial_u", + "special_shifted_chebyshev_polynomial_v", + "special_shifted_chebyshev_polynomial_w", + "special_spherical_bessel_j0", + "_foobar", + "_nested_tensor_strides", + ) +) + + +def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: + base_op_name = "" + func = None + if isinstance(g, NativeFunctionsViewGroup): + base_op_name = g.view.root_name + func = g.view.func + else: + base_op_name = g.out.func.name.name.base + func = g.out.func + if config.is_hand_written(g): + logger.info("HAND WRITTEN: %s", base_op_name) + return False + if base_op_name in BLOCKED_OPS: + logger.info("BLOCKED: %s", base_op_name) + return False + for arg in func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func) + return False + + if isinstance(g, NativeFunctionsViewGroup): + # TODO: stop doing type tests by converting to C++ and then testing + # the string, just test the dang thing directly + if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): + # Returns a non-Tensor value. + logger.info("NON-TENSOR RET TYPE: %s", str(func)) + return False + return True + + # For out variant ops, we need to check the arguments of its functional func. + for arg in g.functional.func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func) + return False + + if not g.structured: + # In case of unstructured op, we check if it has out variant implementation. + # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last + # parameter. + if ( + not hasattr(g, "out") + or not str(func).endswith("Tensor(a!) 
out) -> Tensor(a!)") + or not str(func.name).endswith(".out") + ): + return False + # TODO: stop type testing by converting to C++ + if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type(): + logger.info("NON_TENSOR RET TYPE: %s", func) + return False + if has_alias(func.arguments.non_out): + # This op may create an alias of inputs. + logger.info("INPUTS ALIAS: %s", base_op_name) + return False + return True + + +def ivalue_type_conversion_method( + arg_type: Union[BaseType, OptionalType, Type] +) -> Optional[Tuple[bool, str]]: + """ + Return the method call expression of `c10::ivalue' to convert its contained value to + the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, + this function returns ".toTensor()", so that it can be appended to the ivalue's + variable name to get the value of the expected type. + """ + type_conversion_methods = { + BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional()")), + BaseTy.int: ((False, "toInt()"), (False, "toOptional()")), + BaseTy.bool: ((False, "toBool()"), (False, "toOptional()")), + BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional()")), + BaseTy.ScalarType: ( + (False, "toScalarType()"), + (False, "toOptional()"), + ), + BaseTy.str: ( + (False, "toStringView()"), + (False, "toOptional()"), + ), + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + elif isinstance(arg_type, OptionalType): + if not isinstance(arg_type.elem, BaseType): + # ListType is currently unsupported. 
+ return None + base_ty_object = arg_type.elem.name + else: + return None + + if base_ty_object not in type_conversion_methods: + return None + methods = type_conversion_methods[base_ty_object] + if isinstance(arg_type, BaseType): + return methods[0] + return methods[1] + + +should_use_int_tensor_ops_ = frozenset( + ( + "bitwise_not", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "gcd", + "lcm", + "scatter", + "gather", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + ) +) +should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj")) + + +def should_use_int_tensor(op_name: str) -> bool: + return op_name in should_use_int_tensor_ops_ + + +def should_use_complex_tensor(op_name: str) -> bool: + return op_name in should_use_complex_tensor_ops_ + + +test_tensor_dim_ops_1_ = frozenset( + ( + "addmv", + "index_add", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + "nll_loss_backward", + "dot", + "vdot", + "outer", + "ger", + ) +) +test_tensor_dim_ops_2_ = frozenset( + ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t") +) + + +def test_tensor_dim(op_name: str) -> int: + if op_name in test_tensor_dim_ops_1_: + return 1 + if op_name in test_tensor_dim_ops_2_: + return 2 + return 3 + + +test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' +test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string) + + +def test_tensor_shape(op_name: str) -> str: + if op_name in test_tensor_shape_json: + return test_tensor_shape_json[op_name] + else: + return "" + + +def test_value_expression( + arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str +) -> str: + tensor_size_ex = test_tensor_shape(op_name) + if tensor_size_ex == "": + num_tensors = 16 if index == 0 else 64 + num_dim = test_tensor_dim(op_name) + size_per_dim = math.ceil(num_tensors / float(num_dim)) + size_per_dim += size_per_dim % 
2 + tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim)) + if should_use_int_tensor(op_name): + tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)" + elif should_use_complex_tensor(op_name): + tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)" + else: + tensor_expression = f"at::rand({tensor_size_ex})" + + value_expressions = { + BaseTy.Tensor: tensor_expression, + BaseTy.int: "1", + BaseTy.bool: "false", + BaseTy.Scalar: "2", + BaseTy.ScalarType: "at::ScalarType::Float", + BaseTy.str: '"floor"', + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + else: + assert isinstance(arg_type, OptionalType) and isinstance( + arg_type.elem, BaseType + ) + base_ty_object = arg_type.elem.name + assert base_ty_object in value_expressions, "not expected type" + value_expression = value_expressions[base_ty_object] + return value_expression + + +def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + schema_name = schema.name.name.base + arg_map = {} + for arg in schema.schema_order_arguments(): + test_value_exp = test_value_expression(arg.type, index, schema_name) + arg_map[arg.name] = test_value_exp + config.override_test_values(arg_map, schema_name, index) + arg_populations = [] + for arg_name, arg_value in arg_map.items(): + arg_populations.append(f"auto {arg_name}{index} = {arg_value}") + return ";\n ".join(arg_populations) + ";" + + +def generate_test_value_names(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments()) + + +generate_test_ir_arguments_base_ty_to_type_str_ = { + BaseTy.Tensor: "Tensor", + BaseTy.int: "int", + BaseTy.float: "float", + BaseTy.str: "str", + BaseTy.Scalar: "int", + BaseTy.ScalarType: "int", + BaseTy.bool: "bool", +} + + +def generate_test_ir_arguments( + schema: FunctionSchema, +) -> 
List[Tuple[str, Optional[str]]]: + def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]: + t = arg.type + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + assert isinstance(t, BaseType) + type_str = None + if t.name in generate_test_ir_arguments_base_ty_to_type_str_: + type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name] + if type_str and add_optional: + type_str = f"{type_str}?" + return ("%" + arg.name, type_str) + + return [ir_argument(arg) for arg in schema.schema_order_arguments()] + + +def generate_arg_extraction(schema: FunctionSchema) -> str: + arg_populations = [] + for i, arg in enumerate(schema.schema_order_arguments()): + maybe_method = ivalue_type_conversion_method(arg.type) + assert maybe_method + is_reference, type_conversion_method = maybe_method + reference = "&" if is_reference else "" + arg_populations.append( + f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}" + ) + return ";\n ".join(arg_populations) + ";" + + +def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.functional) + if g.structured or kernel is None: + return cpp.name(g.functional.func) + return kernel.kernel + + +def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.out) + if g.structured or kernel is None: + return cpp.name(g.out.func) + return kernel.kernel + + +def generate_non_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.functional.func + assert not schema.is_out_fn() + kernel_name = get_kernel_name(g, backend_index) + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "cpu" if g.structured else "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_call_to_view_ops( + g: NativeFunctionsViewGroup, backend_index: 
BackendIndex +) -> str: + schema = g.view.func + kernel_name = cpp.name(schema) + kernel = backend_index.get_kernel(g.view) + if kernel: + kernel_name = kernel.kernel + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.out.func + assert schema.is_out_fn() + arg_names = [] + kernel_name = get_out_kernel_name(g, backend_index) + if g.structured: + # structured op starts with the output tensor argument. + arg_names = [out_arg.name for out_arg in schema.arguments.out] + else: + arg_names = [] + for arg in schema.arguments.non_out: + if isinstance(arg, SelfArgument): + arg_names.append(arg.argument.name) + else: + assert isinstance(arg, Argument) + arg_names.append(arg.name) + if not g.structured: + assert len(schema.arguments.out) == 1 + arg_names.append(schema.arguments.out[0].name) + cpp_arg_names = ",".join(arg_names) + namespace_name = "cpu" if g.structured else "native" + return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" + + +no_memory_resize_ops = frozenset( + ( + "isin.Scalar_Tensor", + "index_add", + "dot", + "vdot", + "nuclear_norm", + "histc", + "l1_loss", + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + "prod", + ) +) + + +def should_check_resize(schema: FunctionSchema) -> bool: + schema_str = str(schema) + type_variant_op_name = schema_str[: schema_str.find("(")] + return type_variant_op_name not in no_memory_resize_ops + + +def op_name_from_group(g: NativeFunctionsGroup) -> str: + return g.functional.func.name.name.base + + +class GenOpDispatcher: + def out_variant( + self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert 
is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = op_name_from_group(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def view( + self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = config.func_name_base_str(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def out_variant_op_generator( + self, g: NativeFunctionsGroup, backend_index: BackendIndex + ) -> str: + functional = g.functional + schema = str(functional.func) + populated_argument = generate_arg_extraction(g.functional.func) + functional_variant_call = generate_non_out_variant_call(g, backend_index) + assert len(g.out.func.arguments.out) == 1 + out_variable_name = str(g.out.func.arguments.out[0].name) + out_variant_call = generate_out_variant_call(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + if (p_node->Output(0).isNone()) {{ + p_node->Output(0) = {functional_variant_call}; + return; + }} + auto& {out_variable_name} = p_node->Output(0).toTensor(); + 
fastResizeToZero({out_variable_name}); + {out_variant_call}; + }}; + }}""" + return generated + + def view_op_generator( + self, g: NativeFunctionsViewGroup, backend_index: BackendIndex + ) -> str: + schema = str(g.view.func) + populated_argument = generate_arg_extraction(g.view.func) + functional_variant_call = generate_call_to_view_ops(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + p_node->Output(0) = {functional_variant_call}; + }}; + }}""" + return generated + + +class GenOpTestCase: + def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str: + schema = g.functional.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = op_name_from_group(g) + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", 
".join((arg_name for arg_name, _ in arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + test_value_definitions2 = generate_test_value_definitions(schema, 1) + test_value_names2 = generate_test_value_names(schema, 1) + check_resize = "true" if should_check_resize(schema) else "false" + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + + {test_value_definitions2} + std::vector args2{{{test_value_names2}}}; + testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + +}} +""" + return generated + + def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str: + schema = g.view.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = g.view.root_name + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = 
generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args); +}} +""" + + return generated