koichi12 commited on
Commit
1c6d8d5
·
verified ·
1 Parent(s): e0a38d8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py +119 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py +142 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py +295 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py +56 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py +749 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so +3 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py +1 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz +3 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py +23 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py +18 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc +0 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py +19 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py +707 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py +48 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py +64 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py +989 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py +545 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -38,3 +38,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/algorith
38
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Symtab.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
39
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/Scanners.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
40
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/DFA.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 
 
38
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Symtab.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
39
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/Scanners.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
40
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/DFA.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
41
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import difflib
2
+ import glob
3
+ import gzip
4
+ import os
5
+ import sys
6
+ import tempfile
7
+ import unittest
8
+
9
+ import Cython.Build.Dependencies
10
+ import Cython.Utils
11
+ from Cython.TestUtils import CythonTest
12
+
13
+
14
class TestCyCache(CythonTest):
    """Tests for cythonize()'s on-disk result cache (the ``cache=`` option).

    Each test works in a fresh temporary source/cache directory pair and
    checks how many entries accumulate in the cache via glob patterns.
    """

    def setUp(self):
        # Place temp dirs under TEST_TMP when the test runner provides it,
        # otherwise fall back to the system default temp location.
        CythonTest.setUp(self)
        self.temp_dir = tempfile.mkdtemp(
            prefix='cycache-test',
            dir='TEST_TMP' if os.path.isdir('TEST_TMP') else None)
        self.src_dir = tempfile.mkdtemp(prefix='src', dir=self.temp_dir)
        self.cache_dir = tempfile.mkdtemp(prefix='cache', dir=self.temp_dir)

    def cache_files(self, file_glob):
        """Return the cache entries matching *file_glob* (a glob pattern)."""
        return glob.glob(os.path.join(self.cache_dir, file_glob))

    def fresh_cythonize(self, *args, **kwargs):
        """Run cythonize() with all in-process caches cleared first, so only
        the on-disk cache under test can satisfy repeated compilations."""
        Cython.Utils.clear_function_caches()
        Cython.Build.Dependencies._dep_tree = None  # discard method caches
        Cython.Build.Dependencies.cythonize(*args, **kwargs)

    def test_cycache_switch(self):
        """Switching a source back and forth between two contents must create
        exactly one cache entry per distinct content, and restoring old
        content must reproduce the previously generated C file."""
        content1 = 'value = 1\n'
        content2 = 'value = 2\n'
        a_pyx = os.path.join(self.src_dir, 'a.pyx')
        a_c = a_pyx[:-4] + '.c'

        with open(a_pyx, 'w') as f:
            f.write(content1)
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        # Two runs over identical input: still a single cache entry.
        self.assertEqual(1, len(self.cache_files('a.c*')))
        with open(a_c) as f:
            a_contents1 = f.read()
        os.unlink(a_c)

        with open(a_pyx, 'w') as f:
            f.write(content2)
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        with open(a_c) as f:
            a_contents2 = f.read()
        os.unlink(a_c)

        self.assertNotEqual(a_contents1, a_contents2, 'C file not changed!')
        self.assertEqual(2, len(self.cache_files('a.c*')))

        # Going back to the first content must hit the cache (no third entry)
        # and regenerate the exact same C file as before.
        with open(a_pyx, 'w') as f:
            f.write(content1)
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        self.assertEqual(2, len(self.cache_files('a.c*')))
        with open(a_c) as f:
            a_contents = f.read()
        self.assertEqual(
            a_contents, a_contents1,
            msg='\n'.join(list(difflib.unified_diff(
                a_contents.split('\n'), a_contents1.split('\n')))[:10]))

    def test_cycache_uses_cache(self):
        """A repeated run must restore the output from the cache entry: we
        overwrite the (gzipped) cache entry with a marker payload and expect
        that payload to reappear as the regenerated .c file."""
        a_pyx = os.path.join(self.src_dir, 'a.pyx')
        a_c = a_pyx[:-4] + '.c'
        with open(a_pyx, 'w') as f:
            f.write('pass')
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        a_cache = os.path.join(self.cache_dir, os.listdir(self.cache_dir)[0])
        with gzip.GzipFile(a_cache, 'wb') as gzipfile:
            gzipfile.write('fake stuff'.encode('ascii'))
        os.unlink(a_c)
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        with open(a_c) as f:
            a_contents = f.read()
        self.assertEqual(a_contents, 'fake stuff',
                         'Unexpected contents: %s...' % a_contents[:100])

    def test_multi_file_output(self):
        """All generated files (.c, .h, _api.h for a public api function)
        must be restored together on a cache hit."""
        a_pyx = os.path.join(self.src_dir, 'a.pyx')
        a_c = a_pyx[:-4] + '.c'
        a_h = a_pyx[:-4] + '.h'
        a_api_h = a_pyx[:-4] + '_api.h'
        with open(a_pyx, 'w') as f:
            f.write('cdef public api int foo(int x): return x\n')
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        expected = [a_c, a_h, a_api_h]
        for output in expected:
            self.assertTrue(os.path.exists(output), output)
            os.unlink(output)
        # Second run: every output must come back from the cache.
        self.fresh_cythonize(a_pyx, cache=self.cache_dir)
        for output in expected:
            self.assertTrue(os.path.exists(output), output)

    def test_options_invalidation(self):
        """Options that affect the generated code (cplus) must invalidate the
        cache; options that don't (show_version) must not."""
        hash_pyx = os.path.join(self.src_dir, 'options.pyx')
        hash_c = hash_pyx[:-len('.pyx')] + '.c'

        with open(hash_pyx, 'w') as f:
            f.write('pass')
        self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False)
        self.assertEqual(1, len(self.cache_files('options.c*')))

        os.unlink(hash_c)
        # cplus changes the output language -> a new cache entry.
        self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=True)
        self.assertEqual(2, len(self.cache_files('options.c*')))

        os.unlink(hash_c)
        # show_version does not affect the output -> entry count unchanged.
        self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=False)
        self.assertEqual(2, len(self.cache_files('options.c*')))

        os.unlink(hash_c)
        self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=True)
        self.assertEqual(2, len(self.cache_files('options.c*')))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import os.path
3
+ import sys
4
+ import tempfile
5
+ import unittest
6
+ from io import open
7
+ from os.path import join as pjoin
8
+
9
+ from ..Dependencies import extended_iglob
10
+
11
+
12
@contextlib.contextmanager
def writable_file(dir_path, filename):
    """Open *filename* inside *dir_path* for writing (UTF-8) and yield the
    handle; the file is closed when the context exits."""
    handle = open(pjoin(dir_path, filename), "w", encoding="utf8")
    try:
        yield handle
    finally:
        handle.close()
16
+
17
+
18
class TestGlobbing(unittest.TestCase):
    """Tests for extended_iglob()'s pattern syntax (``{a,b}`` alternatives
    and ``**`` recursion) against a fixed directory tree.

    The fixture builds, inside a temp dir that becomes the CWD:
    a, ax, b, bx, c, cx, d, dx — each containing x/y/z subdirs; every
    subdir holds file2_pyx.pyx + file2_py.py, every top dir holds
    file1_pyx.pyx + file1_py.py.
    """

    @classmethod
    def setUpClass(cls):
        cls._orig_dir = os.getcwd()
        # Py2 has no TemporaryDirectory; keep a plain path there and an
        # object (cleaned up in tearDownClass) on Py3.
        if sys.version_info[0] < 3:
            temp_path = cls._tmpdir = tempfile.mkdtemp()
        else:
            cls._tmpdir = tempfile.TemporaryDirectory()
            temp_path = cls._tmpdir.name
        os.chdir(temp_path)

        for dir1 in "abcd":
            for dir1x in [dir1, dir1 + 'x']:
                for dir2 in "xyz":
                    dir_path = pjoin(dir1x, dir2)
                    os.makedirs(dir_path)
                    with writable_file(dir_path, "file2_pyx.pyx") as f:
                        f.write(u'""" PYX """')
                    with writable_file(dir_path, "file2_py.py") as f:
                        f.write(u'""" PY """')

                with writable_file(dir1x, "file1_pyx.pyx") as f:
                    f.write(u'""" PYX """')
                with writable_file(dir1x, "file1_py.py") as f:
                    f.write(u'""" PY """')

    @classmethod
    def tearDownClass(cls):
        # Restore the CWD before removing the tree we were standing in.
        os.chdir(cls._orig_dir)
        if sys.version_info[0] < 3:
            import shutil
            shutil.rmtree(cls._tmpdir)
        else:
            cls._tmpdir.cleanup()

    def files_equal(self, pattern, expected_files):
        """Assert that *pattern* matches exactly *expected_files* (order-insensitive)."""
        expected_files = sorted(expected_files)
        # It's the users's choice whether '/' will appear on Windows.
        matched_files = sorted(path.replace('/', os.sep) for path in extended_iglob(pattern))
        self.assertListEqual(matched_files, expected_files)  # /

        # Special case for Windows: also support '\' in patterns.
        if os.sep == '\\' and '/' in pattern:
            matched_files = sorted(extended_iglob(pattern.replace('/', '\\')))
            self.assertListEqual(matched_files, expected_files)  # \

    def test_extended_iglob_simple(self):
        # One fixed directory; exercise the {a,b} alternative syntax.
        ax_files = [pjoin("a", "x", "file2_pyx.pyx"), pjoin("a", "x", "file2_py.py")]
        self.files_equal("a/x/*", ax_files)
        self.files_equal("a/x/*.c12", [])
        self.files_equal("a/x/*.{py,pyx,c12}", ax_files)
        self.files_equal("a/x/*.{py,pyx}", ax_files)
        self.files_equal("a/x/*.{pyx}", ax_files[:1])
        self.files_equal("a/x/*.pyx", ax_files[:1])
        self.files_equal("a/x/*.{py}", ax_files[1:])
        self.files_equal("a/x/*.py", ax_files[1:])

    def test_extended_iglob_simple_star(self):
        # Single-level '*' over the x/y/z subdirs of two base dirs.
        # files alternates .pyx/.py, so [::2] is the pyx half, [1::2] the py half.
        for basedir in "ad":
            files = [
                pjoin(basedir, dirname, filename)
                for dirname in "xyz"
                for filename in ["file2_pyx.pyx", "file2_py.py"]
            ]
            self.files_equal(basedir + "/*/*", files)
            self.files_equal(basedir + "/*/*.c12", [])
            self.files_equal(basedir + "/*/*.{py,pyx,c12}", files)
            self.files_equal(basedir + "/*/*.{py,pyx}", files)
            self.files_equal(basedir + "/*/*.{pyx}", files[::2])
            self.files_equal(basedir + "/*/*.pyx", files[::2])
            self.files_equal(basedir + "/*/*.{py}", files[1::2])
            self.files_equal(basedir + "/*/*.py", files[1::2])

            # '*' as the middle component must behave like listing each subdir.
            for subdir in "xy*":
                files = [
                    pjoin(basedir, dirname, filename)
                    for dirname in "xyz"
                    if subdir in ('*', dirname)
                    for filename in ["file2_pyx.pyx", "file2_py.py"]
                ]
                path = basedir + '/' + subdir + '/'
                self.files_equal(path + "*", files)
                self.files_equal(path + "*.{py,pyx}", files)
                self.files_equal(path + "*.{pyx}", files[::2])
                self.files_equal(path + "*.pyx", files[::2])
                self.files_equal(path + "*.{py}", files[1::2])
                self.files_equal(path + "*.py", files[1::2])

    def test_extended_iglob_double_star(self):
        # Recursive '**' patterns: 'files' are the second-level files only,
        # 'all_files' additionally include the first-level file1_* files.
        basedirs = os.listdir(".")
        files = [
            pjoin(basedir, dirname, filename)
            for basedir in basedirs
            for dirname in "xyz"
            for filename in ["file2_pyx.pyx", "file2_py.py"]
        ]
        all_files = [
            pjoin(basedir, filename)
            for basedir in basedirs
            for filename in ["file1_pyx.pyx", "file1_py.py"]
        ] + files
        self.files_equal("*/*/*", files)
        self.files_equal("*/*/**/*", files)
        self.files_equal("*/**/*.*", all_files)
        self.files_equal("**/*.*", all_files)
        self.files_equal("*/**/*.c12", [])
        self.files_equal("**/*.c12", [])
        self.files_equal("*/*/*.{py,pyx,c12}", files)
        self.files_equal("*/*/**/*.{py,pyx,c12}", files)
        self.files_equal("*/**/*/*.{py,pyx,c12}", files)
        self.files_equal("**/*/*/*.{py,pyx,c12}", files)
        self.files_equal("**/*.{py,pyx,c12}", all_files)
        self.files_equal("*/*/*.{py,pyx}", files)
        self.files_equal("**/*/*/*.{py,pyx}", files)
        self.files_equal("*/**/*/*.{py,pyx}", files)
        self.files_equal("**/*.{py,pyx}", all_files)
        self.files_equal("*/*/*.{pyx}", files[::2])
        self.files_equal("**/*.{pyx}", all_files[::2])
        self.files_equal("*/**/*/*.pyx", files[::2])
        self.files_equal("*/*/*.pyx", files[::2])
        self.files_equal("**/*.pyx", all_files[::2])
        self.files_equal("*/*/*.{py}", files[1::2])
        self.files_equal("**/*.{py}", all_files[1::2])
        self.files_equal("*/*/*.py", files[1::2])
        self.files_equal("**/*.py", all_files[1::2])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # tag: ipython
3
+
4
+ """Tests for the Cython magics extension."""
5
+
6
+ from __future__ import absolute_import
7
+
8
+ import os
9
+ import io
10
+ import sys
11
+ from contextlib import contextmanager
12
+ from unittest import skipIf
13
+
14
+ from Cython.Build import IpythonMagic
15
+ from Cython.TestUtils import CythonTest
16
+ from Cython.Compiler.Annotate import AnnotationCCodeWriter
17
+
18
try:
    import IPython.testing.globalipapp
except ImportError:
    # Disable tests and fake helpers for initialisation below.
    # Returning None makes the decorated test class disappear entirely.
    def skip_if_not_installed(_):
        return None
else:
    # IPython is available: the decorator passes the class through unchanged.
    def skip_if_not_installed(c):
        return c

# not using IPython's decorators here because they depend on "nose"
skip_win32 = skipIf(sys.platform == 'win32', "Skip on Windows")
skip_py27 = skipIf(sys.version_info[:2] == (2,7), "Disabled in Py2.7")

try:
    # disable IPython history thread before it gets started to avoid having to clean it up
    from IPython.core.history import HistoryManager
    HistoryManager.enabled = False
except ImportError:
    pass
38
+
39
+
40
@contextmanager
def capture_output():
    """Temporarily replace sys.stdout/sys.stderr with in-memory streams.

    Yields a list that, after the context exits, contains two strings:
    everything written to stdout and everything written to stderr.
    """
    saved_streams = (sys.stdout, sys.stderr)
    try:
        fake_streams = [
            io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding),
            io.TextIOWrapper(io.BytesIO(), encoding=sys.stderr.encoding),
        ]
        sys.stdout, sys.stderr = fake_streams
        captured = []
        yield captured
    finally:
        # Restore the real streams first, then harvest what was written.
        sys.stdout, sys.stderr = saved_streams
        for stream in fake_streams:
            stream.seek(0)
            captured.append(stream.read())
            stream.close()
57
+
58
+
59
# Cell bodies fed to the %%cython magics in the tests below.
# NOTE: these are runtime fixtures — their exact text matters.

# Minimal Python-level cell: f() doubles its argument.
code = u"""\
def f(x):
    return 2*x
"""

# Division behaviour differs between language levels 2 and 3, which the
# language-level tests rely on (2 / x: true vs. floor division).
cython3_code = u"""\
def f(int x):
    return 2 / x

def call(x):
    return f(*(x,))
"""

# Adds a main() workload so profile-guided optimisation has something to run.
pgo_cython3_code = cython3_code + u"""\
def main():
    for _ in range(100): call(5)
main()
"""

# Intentionally invalid C in a verbatim extern block -> C compile error.
compile_error_code = u'''\
cdef extern from *:
    """
    xxx a=1;
    """
    int a;
def doit():
    return a
'''

# Valid C that emits a "CWarning" pragma message during compilation.
compile_warning_code = u'''\
cdef extern from *:
    """
    #pragma message ( "CWarning" )
    int a = 42;
    """
    int a;
def doit():
    return a
'''
98
+
99
+
100
@skip_if_not_installed
class TestIPythonMagic(CythonTest):
    """Tests for the %%cython / %%cython_inline / %%cython_pyximport magics.

    The whole class is dropped (decorator returns None) when IPython is not
    importable. All tests share one global IPython instance (``cls._ip``).
    """

    @classmethod
    def setUpClass(cls):
        CythonTest.setUpClass()
        cls._ip = IPython.testing.globalipapp.get_ipython()

    def setUp(self):
        CythonTest.setUp(self)
        self._ip.extension_manager.load_extension('cython')

    def test_cython_inline(self):
        # %%cython_inline evaluates the cell with the user namespace in scope.
        ip = self._ip
        ip.ex('a=10; b=20')
        result = ip.run_cell_magic('cython_inline', '', 'return a+b')
        self.assertEqual(result, 30)

    @skip_win32
    def test_cython_pyximport(self):
        # Running the magic twice must rebuild/reload the module.
        ip = self._ip
        module_name = '_test_cython_pyximport'
        ip.run_cell_magic('cython_pyximport', module_name, code)
        ip.ex('g = f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)
        ip.run_cell_magic('cython_pyximport', module_name, code)
        ip.ex('h = f(-10)')
        self.assertEqual(ip.user_ns['h'], -20.0)
        try:
            os.remove(module_name + '.pyx')
        except OSError:
            pass

    def test_cython(self):
        ip = self._ip
        ip.run_cell_magic('cython', '', code)
        ip.ex('g = f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)

    def test_cython_name(self):
        # The Cython module named 'mymodule' defines the function f.
        ip = self._ip
        ip.run_cell_magic('cython', '--name=mymodule', code)
        # This module can now be imported in the interactive namespace.
        ip.ex('import mymodule; g = mymodule.f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)

    def test_cython_language_level(self):
        # The Cython cell defines the functions f() and call().
        # Default language level follows the running Python major version.
        ip = self._ip
        ip.run_cell_magic('cython', '', cython3_code)
        ip.ex('g = f(10); h = call(10)')
        if sys.version_info[0] < 3:
            self.assertEqual(ip.user_ns['g'], 2 // 10)
            self.assertEqual(ip.user_ns['h'], 2 // 10)
        else:
            self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
            self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)

    def test_cython3(self):
        # The Cython cell defines the functions f() and call().
        # '-3' forces language level 3: '/' is true division.
        ip = self._ip
        ip.run_cell_magic('cython', '-3', cython3_code)
        ip.ex('g = f(10); h = call(10)')
        self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
        self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)

    def test_cython2(self):
        # The Cython cell defines the functions f() and call().
        # '-2' forces language level 2: '/' on ints floors.
        ip = self._ip
        ip.run_cell_magic('cython', '-2', cython3_code)
        ip.ex('g = f(10); h = call(10)')
        self.assertEqual(ip.user_ns['g'], 2 // 10)
        self.assertEqual(ip.user_ns['h'], 2 // 10)

    def test_cython_compile_error_shown(self):
        ip = self._ip
        with capture_output() as out:
            ip.run_cell_magic('cython', '-3', compile_error_code)
        captured_out, captured_err = out

        # it could be that c-level output is captured by distutil-extension
        # (and not by us) and is printed to stdout:
        captured_all = captured_out + "\n" + captured_err
        self.assertTrue("error" in captured_all, msg="error in " + captured_all)

    def test_cython_link_error_shown(self):
        ip = self._ip
        with capture_output() as out:
            ip.run_cell_magic('cython', '-3 -l=xxxxxxxx', code)
        captured_out, captured_err = out

        # it could be that c-level output is captured by distutil-extension
        # (and not by us) and is printed to stdout:
        captured_all = captured_out + "\n!" + captured_err
        self.assertTrue("error" in captured_all, msg="error in " + captured_all)

    def test_cython_warning_shown(self):
        ip = self._ip
        with capture_output() as out:
            # force rebuild, otherwise no warning as after the first success
            # no build step is performed
            ip.run_cell_magic('cython', '-3 -f', compile_warning_code)
        captured_out, captured_err = out

        # check that warning was printed to stdout even if build hasn't failed
        self.assertTrue("CWarning" in captured_out)

    @skip_py27  # Not strictly broken in Py2.7 but currently fails in CI due to C compiler issues.
    @skip_win32
    def test_cython3_pgo(self):
        # The Cython cell defines the functions f() and call().
        ip = self._ip
        ip.run_cell_magic('cython', '-3 --pgo', pgo_cython3_code)
        ip.ex('g = f(10); h = call(10); main()')
        self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
        self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)

    @skip_win32
    def test_extlibs(self):
        ip = self._ip
        # NOTE: local 'code' deliberately shadows the module-level constant.
        code = u"""
from libc.math cimport sin
x = sin(0.0)
        """
        ip.user_ns['x'] = 1
        ip.run_cell_magic('cython', '-l m', code)
        self.assertEqual(ip.user_ns['x'], 0)

    def test_cython_verbose(self):
        ip = self._ip
        ip.run_cell_magic('cython', '--verbose', code)
        ip.ex('g = f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)

    def test_cython_verbose_thresholds(self):
        @contextmanager
        def mock_distutils():
            # Records every set_threshold() call so we can assert the
            # DEBUG/INFO transitions performed by the magic.
            class MockLog:
                DEBUG = 1
                INFO = 2
                thresholds = [INFO]

                def set_threshold(self, val):
                    self.thresholds.append(val)
                    return self.thresholds[-2]

            new_log = MockLog()
            old_log = IpythonMagic.distutils.log
            try:
                IpythonMagic.distutils.log = new_log
                yield new_log
            finally:
                IpythonMagic.distutils.log = old_log

        ip = self._ip
        with mock_distutils() as verbose_log:
            ip.run_cell_magic('cython', '--verbose', code)
            ip.ex('g = f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)
        self.assertEqual([verbose_log.INFO, verbose_log.DEBUG, verbose_log.INFO],
                         verbose_log.thresholds)

        with mock_distutils() as normal_log:
            ip.run_cell_magic('cython', '', code)
            ip.ex('g = f(10)')
        self.assertEqual(ip.user_ns['g'], 20.0)
        self.assertEqual([normal_log.INFO], normal_log.thresholds)

    def test_cython_no_annotate(self):
        ip = self._ip
        html = ip.run_cell_magic('cython', '', code)
        self.assertTrue(html is None)

    def test_cython_annotate(self):
        ip = self._ip
        html = ip.run_cell_magic('cython', '--annotate', code)
        # somewhat brittle way to differentiate between annotated htmls
        # with/without complete source code:
        self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data)

    def test_cython_annotate_default(self):
        ip = self._ip
        html = ip.run_cell_magic('cython', '-a', code)
        # somewhat brittle way to differentiate between annotated htmls
        # with/without complete source code:
        self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data)

    def test_cython_annotate_complete_c_code(self):
        ip = self._ip
        html = ip.run_cell_magic('cython', '--annotate-fullc', code)
        # somewhat brittle way to differentiate between annotated htmls
        # with/without complete source code:
        self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE in html.data)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Cython.Build.Dependencies import strip_string_literals
2
+
3
+ from Cython.TestUtils import CythonTest
4
+
5
class TestStripLiterals(CythonTest):
    """Tests for ``strip_string_literals``.

    Every check asserts two properties: the stripped text equals the
    expected placeholder form, and substituting each placeholder back
    reproduces the original input exactly (round-trip).
    """

    def t(self, before, expected):
        # Strip literals and compare against the expected placeholder form.
        stripped, mapping = strip_string_literals(before, prefix="_L")
        self.assertEqual(expected, stripped)
        # Reinsert every literal body and verify we recover the input.
        restored = stripped
        for label, literal in mapping.items():
            restored = restored.replace(label, literal)
        self.assertEqual(before, restored)

    def test_empty(self):
        self.t("", "")

    def test_single_quote(self):
        self.t("'x'", "'_L1_'")

    def test_double_quote(self):
        self.t('"x"', '"_L1_"')

    def test_nested_quotes(self):
        self.t(""" '"' "'" """, """ '_L1_' "_L2_" """)

    def test_triple_quote(self):
        self.t(" '''a\n''' ", " '''_L1_''' ")

    def test_backslash(self):
        self.t(r"'a\'b'", "'_L1_'")
        self.t(r"'a\\'", "'_L1_'")
        self.t(r"'a\\\'b'", "'_L1_'")

    def test_unicode(self):
        self.t("u'abc'", "u'_L1_'")

    def test_raw(self):
        self.t(r"r'abc\\'", "r'_L1_'")

    def test_raw_unicode(self):
        self.t(r"ru'abc\\'", "ru'_L1_'")

    def test_comment(self):
        self.t("abc # foo", "abc #_L1_")

    def test_comment_and_quote(self):
        self.t("abc # 'x'", "abc #_L1_")
        self.t("'abc#'", "'_L1_'")

    def test_include(self):
        self.t("include 'a.pxi' # something here",
               "include '_L1_' #_L2_")

    def test_extern(self):
        self.t("cdef extern from 'a.h': # comment",
               "cdef extern from '_L1_': #_L2_")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc ADDED
Binary file (41.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc ADDED
Binary file (7.92 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc ADDED
Binary file (4.57 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ from .Visitor import CythonTransform
4
+ from .ModuleNode import ModuleNode
5
+ from .Errors import CompileError
6
+ from .UtilityCode import CythonUtilityCode
7
+ from .Code import UtilityCode, TempitaUtilityCode
8
+
9
+ from . import Options
10
+ from . import Interpreter
11
+ from . import PyrexTypes
12
+ from . import Naming
13
+ from . import Symtab
14
+
15
def dedent(text, reindent=0):
    """Strip common leading whitespace from *text*, then optionally
    re-indent every line (including blank ones) by *reindent* spaces.
    """
    import textwrap
    result = textwrap.dedent(text)
    if reindent > 0:
        pad = " " * reindent
        result = "\n".join(pad + line for line in result.split("\n"))
    return result
22
+
23
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """Declare the auxiliary C variables that every buffer-typed variable
    needs (a per-variable buffer-info struct plus a refcounted Py_buffer
    wrapper) and record whether the module uses buffers at all.
    """

    #
    # Entry point
    #

    # Set to True as soon as any visited scope contains a buffer or a
    # memoryview slice; gates the global support code emitted in __call__.
    buffers_exists = False
    # True when the Cython 'memoryview' utility type is in use in a scope.
    using_memoryview = False

    def __call__(self, node):
        # Transform entry point: walk the module tree and, if any buffers
        # were found, pull in the module-level buffer support code.
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_bufstruct_declare_code(node.scope)
            use_py2_buffer_functions(node.scope)

        return result


    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        scope_items = scope.entries.items()
        bufvars = [entry for name, entry in scope_items if entry.type.is_buffer]
        if len(bufvars) > 0:
            # Sort by name for deterministic generated code.
            bufvars.sort(key=lambda entry: entry.name)
            self.buffers_exists = True

        memviewslicevars = [entry for name, entry in scope_items if entry.type.is_memoryviewslice]
        if len(memviewslicevars) > 0:
            self.buffers_exists = True


        for (name, entry) in scope_items:
            if name == 'memoryview' and isinstance(entry.utility_code_definition, CythonUtilityCode):
                self.using_memoryview = True
                break
        del scope_items

        if isinstance(node, ModuleNode) and len(bufvars) > 0:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            if entry.type.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")

            name = entry.name
            buftype = entry.type
            if buftype.ndim > Options.buffer_max_dims:
                raise CompileError(node.pos,
                        "Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims)
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            def decvar(type, prefix):
                # Declare one mangled auxiliary variable in the scope.
                cname = scope.mangle(prefix, name)
                aux_var = scope.declare_var(name=None, cname=cname,
                                            type=type, pos=node.pos)
                if entry.is_arg:
                    aux_var.used = True  # otherwise, NameNode will mark whether it is used

                return aux_var

            auxvars = ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix),
                       (PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix))
            pybuffernd, rcbuffer = [decvar(type, prefix) for (type, prefix) in auxvars]

            entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
110
+
111
+ #
112
+ # Analysis
113
+ #
114
# Buffer type option handling.  `buffer_options` is ordered: positional
# arguments in a buffer type declaration map onto it left to right.
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast")  # ordered!
# Values used for options not given explicitly in a declaration.
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
buffer_positional_options_count = 1  # anything beyond this needs keyword argument

# Error message templates raised as CompileError by analyse_buffer_options().
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
126
+
127
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    Raises CompileError for unknown, duplicate or invalid options and,
    when need_complete is true, for options missing from both the
    arguments and *defaults*.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(
        posargs, dictargs, type_env=env, type_args=(0, 'dtype'))

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.items():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name] = value

    # Positional args map onto buffer_options in declaration order.  The
    # zip() guarantees every positional name is a valid option, so only a
    # duplicate (option given both positionally and by keyword) can fail
    # here; the original unknown-option re-check was unreachable.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        # Compile-time boolean options must be real bools, not truthy values.
        x = options.get(name)
        if not isinstance(x, bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
190
+
191
+
192
+ #
193
+ # Code generation
194
+ #
195
+
196
class BufferEntry(object):
    """Wraps a buffer-typed symbol-table entry and builds the C expressions
    (per-dimension shape/strides/suboffsets, element pointer lookup) used
    when generating code for buffer access."""

    def __init__(self, entry):
        self.entry = entry
        self.type = entry.type
        # cname of the local per-variable buffer-info auxiliary struct.
        self.cname = entry.buffer_aux.buflocal_nd_var.cname
        # C expression for the raw data pointer of the acquired Py_buffer.
        self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
        self.buf_ptr_type = entry.type.buffer_ptr_type
        self.init_attributes()

    def init_attributes(self):
        # Pre-compute the per-dimension C expression lists once.
        self.shape = self.get_buf_shapevars()
        self.strides = self.get_buf_stridevars()
        self.suboffsets = self.get_buf_suboffsetvars()

    def get_buf_suboffsetvars(self):
        # One C expression per dimension reading the suboffset value.
        return self._for_all_ndim("%s.diminfo[%d].suboffsets")

    def get_buf_stridevars(self):
        # One C expression per dimension reading the stride value.
        return self._for_all_ndim("%s.diminfo[%d].strides")

    def get_buf_shapevars(self):
        # One C expression per dimension reading the shape value.
        return self._for_all_ndim("%s.diminfo[%d].shape")

    def _for_all_ndim(self, s):
        # Instantiate format string *s* with (cname, dim) for every dimension.
        return [s % (self.cname, i) for i in range(self.type.ndim)]

    def generate_buffer_lookup_code(self, code, index_cnames):
        # Create buffer lookup and return it
        # This is done via utility macros/inline functions, which vary
        # according to the access mode used.
        params = []
        nd = self.type.ndim
        mode = self.type.mode
        if mode == 'full':
            # Full mode passes index, stride and suboffset per dimension.
            for i, s, o in zip(index_cnames,
                               self.get_buf_stridevars(),
                               self.get_buf_suboffsetvars()):
                params.append(i)
                params.append(s)
                params.append(o)
            funcname = "__Pyx_BufPtrFull%dd" % nd
            funcgen = buf_lookup_full_code
        else:
            # The other modes pass only index and stride per dimension.
            if mode == 'strided':
                funcname = "__Pyx_BufPtrStrided%dd" % nd
                funcgen = buf_lookup_strided_code
            elif mode == 'c':
                funcname = "__Pyx_BufPtrCContig%dd" % nd
                funcgen = buf_lookup_c_code
            elif mode == 'fortran':
                funcname = "__Pyx_BufPtrFortranContig%dd" % nd
                funcgen = buf_lookup_fortran_code
            else:
                assert False
            for i, s in zip(index_cnames, self.get_buf_stridevars()):
                params.append(i)
                params.append(s)

        # Make sure the utility code is available
        if funcname not in code.globalstate.utility_codes:
            code.globalstate.utility_codes.add(funcname)
            protocode = code.globalstate['utility_code_proto']
            defcode = code.globalstate['utility_code_def']
            funcgen(protocode, defcode, name=funcname, nd=nd)

        buf_ptr_type_code = self.buf_ptr_type.empty_declaration_code()
        ptrcode = "%s(%s, %s, %s)" % (funcname, buf_ptr_type_code, self.buf_ptr,
                                      ", ".join(params))
        return ptrcode
265
+
266
+
267
def get_flags(buffer_aux, buffer_type):
    """Return the C expression for the PyBUF_* flags to request when
    acquiring a buffer of the given access mode (plus PyBUF_WRITABLE
    when any write to the buffer was seen)."""
    mode_flags = {
        'full': '| PyBUF_INDIRECT',
        'strided': '| PyBUF_STRIDES',
        'c': '| PyBUF_C_CONTIGUOUS',
        'fortran': '| PyBUF_F_CONTIGUOUS',
    }
    assert buffer_type.mode in mode_flags
    flags = 'PyBUF_FORMAT' + mode_flags[buffer_type.mode]
    if buffer_aux.writable_needed:
        flags += "| PyBUF_WRITABLE"
    return flags
282
+
283
def used_buffer_aux_vars(entry):
    """Mark both auxiliary buffer variables of *entry* as used so that
    their declarations are emitted."""
    aux = entry.buffer_aux
    for aux_var in (aux.buflocal_nd_var, aux.rcbuf_var):
        aux_var.used = True
287
+
288
def put_unpack_buffer_aux_into_scope(buf_entry, code):
    """Emit one C line copying strides/shape (and, in 'full' mode,
    suboffsets) from the acquired Py_buffer into the diminfo array of the
    local buffer-info struct, one assignment per dimension and field."""
    struct_cname = buf_entry.buffer_aux.buflocal_nd_var.cname
    if buf_entry.type.mode == 'full':
        fields = ('strides', 'shape', 'suboffsets')
    else:
        fields = ('strides', 'shape')

    assignments = [
        "%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];" % (
            struct_cname, dim, field, struct_cname, field, dim)
        for dim in range(buf_entry.type.ndim)
        for field in fields
    ]
    code.putln(' '.join(assignments))
306
+
307
def put_init_vars(entry, code):
    """Emit C code initialising the two auxiliary buffer structs of
    *entry* to an empty state (NULL buffer, zero refcount) and wiring the
    buffer-info struct to its refcounted Py_buffer wrapper."""
    aux = entry.buffer_aux
    nd_cname = aux.buflocal_nd_var.cname
    rc_cname = aux.rcbuf_var.cname
    init_lines = (
        # init the refcounted Py_buffer wrapper
        "%s.pybuffer.buf = NULL;" % rc_cname,
        "%s.refcount = 0;" % rc_cname,
        # init the buffer-info struct and point it at the wrapper
        "%s.data = NULL;" % nd_cname,
        "%s.rcbuffer = &%s;" % (nd_cname, rc_cname),
    )
    for line in init_lines:
        code.putln(line)
319
+
320
+
321
def put_acquire_arg_buffer(entry, code, pos):
    """Emit code acquiring the buffer of a buffer-typed function argument
    on function entry; acquisition failure jumps to the error label."""
    buffer_aux = entry.buffer_aux
    getbuffer = get_getbuffer_call(code, entry.cname, buffer_aux, entry.type)

    # Acquire any new buffer
    code.putln("{")
    # Stack storage for the buffer-format parser; size depends on how
    # deeply the dtype nests structs.
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
    code.putln(code.error_goto_if("%s == -1" % getbuffer, pos))
    code.putln("}")
    # An exception raised in arg parsing cannot be caught, so no
    # need to care about the buffer then.
    put_unpack_buffer_aux_into_scope(entry, code)
333
+
334
+
335
def put_release_buffer_code(code, entry):
    """Emit a safe (NULL-tolerant) release of the buffer held by *entry*."""
    code.globalstate.use_utility_code(acquire_utility_code)
    code.putln("__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);" % entry.buffer_aux.buflocal_nd_var.cname)
338
+
339
+
340
def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
    """Return the C call expression that acquires and validates a buffer
    from the object *obj_cname*; it evaluates to -1 on failure.  The call
    site must have a __pyx_stack array in scope."""
    ndim = buffer_type.ndim
    cast = int(buffer_type.cast)
    flags = get_flags(buffer_aux, buffer_type)
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    dtype_typeinfo = get_type_information_cname(code, buffer_type.dtype)

    code.globalstate.use_utility_code(acquire_utility_code)
    # locals() supplies the %(name)s substitutions assembled above.
    return ("__Pyx_GetBufferAndValidate(&%(pybuffernd_struct)s.rcbuffer->pybuffer, "
            "(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
            "%(cast)d, __pyx_stack)" % locals())
352
+
353
+
354
def put_assign_to_buffer(lhs_cname, rhs_cname, buf_entry,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).
    """

    buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    code.putln("{")  # Set up necessary stack for getbuffer
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())

    getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type)  # fill in object below

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        exc_temps = tuple(code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False)
                          for _ in range(3))
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % exc_temps)
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % exc_temps)  # Do not refnanny these!
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        code.putln('} else {')
        code.putln('PyErr_Restore(%s, %s, %s);' % exc_temps)
        code.putln('}')
        code.putln('%s = %s = %s = 0;' % exc_temps)
        for t in exc_temps:
            code.funcstate.release_temp(t)
        code.putln('}')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; __Pyx_INCREF(Py_None); %s.rcbuffer->pybuffer.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    pybuffernd_struct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln('}')

    code.putln("}")  # Release stack
424
+
425
+
426
def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
                           pos, code, negative_indices, in_nogil_context):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to an inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    entry is a BufferEntry
    """
    # Negative-index wrapping only happens under the 'wraparound' directive.
    negative_indices = directives['wraparound'] and negative_indices

    if directives['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the index dimension the
        # error is occurring at.
        failed_dim_temp = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = -1;" % failed_dim_temp)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames, entry.get_buf_shapevars())):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    code.putln("%s += %s;" % (cname, shape))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname),
                        failed_dim_temp, dim))
                else:
                    code.putln("%s = %d;" % (failed_dim_temp, dim))
                code.put("} else ")
            # check bounds in positive direction
            if signed != 0:
                cast = ""
            else:
                cast = "(size_t)"
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s%s" % (cname, cast, shape)),
                failed_dim_temp, dim))

        # Choose the raising helper suitable for the GIL state at this point.
        if in_nogil_context:
            code.globalstate.use_utility_code(raise_indexerror_nogil)
            func = '__Pyx_RaiseBufferIndexErrorNogil'
        else:
            code.globalstate.use_utility_code(raise_indexerror_code)
            func = '__Pyx_RaiseBufferIndexError'

        code.putln("if (%s) {" % code.unlikely("%s != -1" % failed_dim_temp))
        code.putln('%s(%s);' % (func, failed_dim_temp))
        code.putln(code.error_goto(pos))
        code.putln('}')
        code.funcstate.release_temp(failed_dim_temp)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames, entry.get_buf_shapevars()):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape))

    return entry.generate_buffer_lookup_code(code, index_cnames)
490
+
491
+
492
def use_bufstruct_declare_code(env):
    """Pull the declarations of the auxiliary buffer structs into *env*."""
    env.use_utility_code(buffer_struct_declare_code)
494
+
495
+
496
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static CYTHON_INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
    # The implementation advances one dimension per step; a non-negative
    # suboffset at a level means an extra pointer indirection there.
    defin.putln(dedent("""
        static CYTHON_INLINE void* %s_imp(void* buf, %s) {
          char* ptr = (char*)buf;
        """) % (name, funcargs) + "".join([dedent("""\
           ptr += s%d * i%d;
           if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """) % (i, i, i, i) for i in range(nd)]
        ) + "\nreturn ptr;\n}")
515
+
516
+
517
def buf_lookup_strided_code(proto, defin, name, nd):
    """Emit the element-lookup macro for a strided buffer: the pointer is
    the base plus the full sum of index*stride over every dimension.
    (*defin* is unused; strided lookup is a macro only.)"""
    arg_list = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
    offset_expr = " + ".join("i%d * s%d" % (dim, dim) for dim in range(nd))
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (name, arg_list, offset_expr))
526
+
527
+
528
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Similar to strided lookup, but the stride multiplication for the last
    dimension is skipped (C-contiguous layout).  The macro keeps the same
    signature, so the unused last stride is still accepted.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    arg_list = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
    offset_expr = " + ".join("i%d * s%d" % (dim, dim) for dim in range(nd - 1))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
        name, arg_list, offset_expr, nd - 1))
540
+
541
+
542
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Like the C-contiguous lookup, but the FIRST dimension's stride
    multiplication is skipped instead (Fortran-contiguous layout).
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    arg_list = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
    offset_expr = " + ".join("i%d * s%d" % (dim, dim) for dim in range(1, nd))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
        name, arg_list, offset_expr, 0))
552
+
553
+
554
def use_py2_buffer_functions(env):
    """Pull the Python 2 emulation of PyObject_GetBuffer/PyBuffer_Release
    into *env*."""
    env.use_utility_code(GetAndReleaseBufferUtilityCode())
556
+
557
+
558
class GetAndReleaseBufferUtilityCode(object):
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which has the right tp_flags set, but emulation otherwise.

    requires = None
    is_cython_utility = False

    def __init__(self):
        pass

    def __eq__(self, other):
        # All instances are interchangeable; type-based equality keeps the
        # utility-code registry from inserting duplicates.
        return isinstance(other, GetAndReleaseBufferUtilityCode)

    def __hash__(self):
        # Arbitrary constant, consistent with the class-wide __eq__ above.
        return 24342342

    def get_tree(self, **kwargs): pass

    def put_code(self, output):
        """Emit the GetAndReleaseBuffer utility code, specialised with the
        __getbuffer__/__releasebuffer__ slots of every extension type found
        in this module's scope and its cimported modules."""
        code = output['utility_code_def']
        proto_code = output['utility_code_proto']
        env = output.module_node.scope
        cython_scope = env.context.cython_scope

        # Search all types for __getbuffer__ overloads
        types = []
        visited_scopes = set()
        def find_buffer_types(scope):
            # Depth-first over cimported modules; visited_scopes guards
            # against cycles and repeated work.
            if scope in visited_scopes:
                return
            visited_scopes.add(scope)
            for m in scope.cimported_modules:
                find_buffer_types(m)
            for e in scope.type_entries:
                if isinstance(e.utility_code_definition, CythonUtilityCode):
                    continue
                t = e.type
                if t.is_extension_type:
                    if scope is cython_scope and not e.used:
                        continue
                    release = get = None
                    for x in t.scope.pyfunc_entries:
                        if x.name == u"__getbuffer__": get = x.func_cname
                        elif x.name == u"__releasebuffer__": release = x.func_cname
                    if get:
                        # __releasebuffer__ is optional; 'release' may stay None.
                        types.append((t.typeptr_cname, get, release))

        find_buffer_types(env)

        util_code = TempitaUtilityCode.load(
            "GetAndReleaseBuffer", from_file="Buffer.c",
            context=dict(types=types))

        proto = util_code.format_code(util_code.proto)
        impl = util_code.format_code(
            util_code.inject_string_constants(util_code.impl, output)[1])

        proto_code.putln(proto)
        code.putln(impl)
618
+
619
+
620
def mangle_dtype_name(dtype):
    """Return an identifier-safe name for *dtype*, used to build the
    cnames of the generated type-info structs.

    Use prefixes to separate user defined types from builtins
    (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = "nn_" if (dtype.is_typedef or dtype.is_struct_or_union) else ""
    return prefix + dtype.specialization_name()
633
+
634
def get_type_information_cname(code, dtype, maxdepth=None):
    """
    Output the run-time type information (__Pyx_TypeInfo) for given dtype,
    and return the name of the type info struct.

    Structs with two floats of the same size are encoded as complex numbers.
    One can separate between complex numbers declared as struct or with native
    encoding by inspecting to see if the fields field of the type is
    filled in.
    """
    namesuffix = mangle_dtype_name(dtype)
    name = "__Pyx_TypeInfo_%s" % namesuffix
    structinfo_name = "__Pyx_StructFields_%s" % namesuffix

    if dtype.is_error: return "<error>"

    # It's critical that walking the type info doesn't use more stack
    # depth than dtype.struct_nesting_depth() returns, so use an assertion for this
    if maxdepth is None: maxdepth = dtype.struct_nesting_depth()
    if maxdepth <= 0:
        assert False

    # Emit each typeinfo struct at most once per module.
    if name not in code.globalstate.utility_codes:
        code.globalstate.utility_codes.add(name)
        typecode = code.globalstate['typeinfo']

        arraysizes = []
        if dtype.is_array:
            # Peel nested array dimensions down to the element type.
            while dtype.is_array:
                arraysizes.append(dtype.size)
                dtype = dtype.base_type

        complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()

        declcode = dtype.empty_declaration_code()
        if dtype.is_simple_buffer_dtype():
            structinfo_name = "NULL"
        elif dtype.is_struct:
            struct_scope = dtype.scope
            if dtype.is_cv_qualified:
                struct_scope = struct_scope.base_type_scope
            # Must pre-call all used types in order not to recurse during utility code writing.
            fields = struct_scope.var_entries
            assert len(fields) > 0
            types = [get_type_information_cname(code, f.type, maxdepth - 1)
                     for f in fields]
            typecode.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)

            if dtype.is_cv_qualified:
                # roughly speaking, remove "const" from struct_type
                struct_type = dtype.cv_base_type.empty_declaration_code()
            else:
                struct_type = dtype.empty_declaration_code()

            for f, typeinfo in zip(fields, types):
                typecode.putln('  {&%s, "%s", offsetof(%s, %s)},' %
                               (typeinfo, f.name, struct_type, f.cname), safe=True)

            # NULL-terminated field table.
            typecode.putln('  {NULL, NULL, 0}', safe=True)
            typecode.putln("};", safe=True)
        else:
            assert False

        rep = str(dtype)

        flags = "0"
        is_unsigned = "0"
        # Classify the dtype into a one-character type group for the
        # runtime format checker.
        if dtype is PyrexTypes.c_char_type:
            is_unsigned = "__PYX_IS_UNSIGNED(%s)" % declcode
            typegroup = "'H'"
        elif dtype.is_int:
            is_unsigned = "__PYX_IS_UNSIGNED(%s)" % declcode
            typegroup = "%s ? 'U' : 'I'" % is_unsigned
        elif complex_possible or dtype.is_complex:
            typegroup = "'C'"
        elif dtype.is_float:
            typegroup = "'R'"
        elif dtype.is_struct:
            typegroup = "'S'"
            if dtype.packed:
                flags = "__PYX_BUF_FLAGS_PACKED_STRUCT"
        elif dtype.is_pyobject:
            typegroup = "'O'"
        else:
            assert False, dtype

        typeinfo = ('static __Pyx_TypeInfo %s = '
                    '{ "%s", %s, sizeof(%s), { %s }, %s, %s, %s, %s };')
        tup = (name, rep, structinfo_name, declcode,
               ', '.join([str(x) for x in arraysizes]) or '0', len(arraysizes),
               typegroup, is_unsigned, flags)
        typecode.putln(typeinfo % tup, safe=True)

    return name
728
+
729
def load_buffer_utility(util_code_name, context=None, **kwargs):
    """Load a named utility-code section from Buffer.c; Tempita-templated
    when a *context* dict is given, a plain UtilityCode otherwise."""
    if context is None:
        return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)
    else:
        return TempitaUtilityCode.load(util_code_name, "Buffer.c", context=context, **kwargs)
734
+
735
# Shared Tempita context: the templated utility code is specialised on the
# maximum supported number of buffer dimensions.
context = dict(max_dims=Options.buffer_max_dims)
buffer_struct_declare_code = load_buffer_utility("BufferStructDeclare", context=context)
buffer_formats_declare_code = load_buffer_utility("BufferFormatStructs")

# Utility function to set the right exception
# The caller should immediately goto_error
raise_indexerror_code = load_buffer_utility("BufferIndexError")
raise_indexerror_nogil = load_buffer_utility("BufferIndexErrorNogil")
raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError")

# Acquisition/validation helpers used by get_getbuffer_call() and
# put_release_buffer_code().
acquire_utility_code = load_buffer_utility("BufferGetAndValidate", context=context)
buffer_format_check_code = load_buffer_utility("BufferFormatCheck", context=context)

# See utility code BufferFormatFromTypeInfo
_typeinfo_to_format_code = load_buffer_utility("TypeInfoToFormat")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02498cb7e330a1a7ccae2b142938c3c3c01d80751d06f9ade63bac39f2ab681a
3
+ size 517064
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # empty file
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc ADDED
Binary file (14.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc ADDED
Binary file (7.78 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc ADDED
Binary file (24.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc ADDED
Binary file (29.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73fc416df0164923607751cb759f4ae81deb5f6550bf25be59c86de3b747e41d
3
+ size 8887
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrappers to call pyproject.toml-based build backend hooks.
2
+ """
3
+
4
+ from ._impl import (
5
+ BackendInvalid,
6
+ BackendUnavailable,
7
+ BuildBackendHookCaller,
8
+ HookMissing,
9
+ UnsupportedOperation,
10
+ default_subprocess_runner,
11
+ quiet_subprocess_runner,
12
+ )
13
+
14
+ __version__ = '1.0.0'
15
+ __all__ = [
16
+ 'BackendUnavailable',
17
+ 'BackendInvalid',
18
+ 'HookMissing',
19
+ 'UnsupportedOperation',
20
+ 'default_subprocess_runner',
21
+ 'quiet_subprocess_runner',
22
+ 'BuildBackendHookCaller',
23
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc ADDED
Binary file (425 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is a subpackage because the directory is on sys.path for _in_process.py
2
+
3
+ The subpackage should stay as empty as possible to avoid shadowing modules that
4
+ the backend might import.
5
+ """
6
+
7
+ import importlib.resources as resources
8
+
9
+ try:
10
+ resources.files
11
+ except AttributeError:
12
+ # Python 3.8 compatibility
13
+ def _in_proc_script_path():
14
+ return resources.path(__package__, '_in_process.py')
15
+ else:
16
+ def _in_proc_script_path():
17
+ return resources.as_file(
18
+ resources.files(__package__).joinpath('_in_process.py'))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.19 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc ADDED
Binary file (5.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc ADDED
Binary file (7.73 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc ADDED
Binary file (27 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc ADDED
Binary file (46.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc ADDED
Binary file (41.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc ADDED
Binary file (22.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc ADDED
Binary file (2.18 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc ADDED
Binary file (1.65 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .lazy_ir import (
2
+ generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
3
+ GenLazyIR as GenLazyIR,
4
+ GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
5
+ GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
6
+ )
7
+ from .native_functions import (
8
+ compute_native_function_declaration as compute_native_function_declaration,
9
+ )
10
+ from .register_dispatch_key import (
11
+ gen_registration_headers as gen_registration_headers,
12
+ gen_registration_helpers as gen_registration_helpers,
13
+ RegisterDispatchKey as RegisterDispatchKey,
14
+ )
15
+ from .ufunc import (
16
+ compute_ufunc_cpu as compute_ufunc_cpu,
17
+ compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
18
+ compute_ufunc_cuda as compute_ufunc_cuda,
19
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc ADDED
Binary file (3.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc ADDED
Binary file (28.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from abc import ABC
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torchgen.api.dispatcher as dispatcher
7
+ from torchgen.api.lazy import (
8
+ getValueT,
9
+ isValueType,
10
+ LazyArgument,
11
+ LazyIrProperties,
12
+ LazyIrSchema,
13
+ tensorListValueT,
14
+ )
15
+ from torchgen.api.translate import translate
16
+ from torchgen.api.types import (
17
+ BaseCType,
18
+ Binding,
19
+ deviceT,
20
+ DispatcherSignature,
21
+ kernel_signature,
22
+ NativeSignature,
23
+ OptionalCType,
24
+ VectorCType,
25
+ )
26
+ from torchgen.context import method_with_native_function
27
+ from torchgen.dest.lazy_ts_lowering import ts_lowering_body
28
+ from torchgen.model import (
29
+ Argument,
30
+ BackendIndex,
31
+ BackendMetadata,
32
+ BaseTy,
33
+ BaseType,
34
+ FunctionSchema,
35
+ ListType,
36
+ NativeFunction,
37
+ NativeFunctionsGroup,
38
+ )
39
+
40
+
41
+ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
42
+ """
43
+ Given a LazyArgument,
44
+ generate a c++ string for materializing an rvalue of that arg for passing into
45
+ a lazy Node constructor.
46
+ """
47
+
48
+ # TODO: Matching on CType seems wrong; should be matching on Type
49
+ if isValueType(arg.lazy_type):
50
+ if isinstance(arg.lazy_type, BaseCType):
51
+ if arg.is_wrapped_scalar:
52
+ return f"node_{arg.name}"
53
+ elif arg.lazy_type.type is tensorListValueT:
54
+ return f"lazy_{arg.name}_tensorlist"
55
+ elif arg.is_symint_or_list:
56
+ return f"GetSymIntValue({arg.name})"
57
+ return f"lazy_{arg.name}->GetIrValue()"
58
+ elif isinstance(arg.lazy_type, OptionalCType):
59
+ if arg.is_symint_or_list:
60
+ # TODO: I don't understand when you should put lazy_ in the name
61
+ # or not
62
+ return f"{arg.name} ? c10::make_optional(GetSymIntValue(*{arg.name})) : c10::nullopt"
63
+ elif arg.is_wrapped_scalar:
64
+ return f"node_{arg.name}"
65
+ return (
66
+ f"lazy_{arg.name} ? "
67
+ f"c10::make_optional(lazy_{arg.name}->GetIrValue()) : "
68
+ "c10::nullopt"
69
+ )
70
+ else:
71
+ raise AssertionError(
72
+ f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
73
+ )
74
+ else:
75
+ # NB: this is here because right now we aren't treating SymInt[] as a
76
+ # value type; when we do this needs to move above
77
+ # NB: we cannot test arg.lazy_type as we've already specified it is an
78
+ # int64_t and so we cannot distinguish between SymInt and int64_t
79
+ if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
80
+ BaseTy.SymInt
81
+ ):
82
+ if arg.symint:
83
+ return f"GetSymIntArrayRefValue({arg.name})"
84
+ else:
85
+ return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
86
+ elif isinstance(arg.lazy_type, VectorCType) and isinstance(
87
+ arg.lazy_type.elem, BaseCType
88
+ ):
89
+ return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
90
+ elif (
91
+ isinstance(arg.lazy_type, OptionalCType)
92
+ and isinstance(arg.lazy_type.elem, VectorCType)
93
+ and isinstance(arg.lazy_type.elem.elem, BaseCType)
94
+ ):
95
+ return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
96
+ else:
97
+ return f"{arg.name}"
98
+
99
+
100
+ def node_ctor_inputs(schema: LazyIrSchema) -> str:
101
+ """
102
+ Produce a formatted string with the arguments as passed into the constructor of a node class.
103
+ """
104
+ node_ctor_values = [
105
+ node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
106
+ ]
107
+ return ", ".join(node_ctor_values)
108
+
109
+
110
+ def gen_fallback_code(
111
+ schema: LazyIrSchema,
112
+ sig: Union[DispatcherSignature, NativeSignature],
113
+ overload_name: str,
114
+ ) -> str:
115
+ """
116
+ Generate code that falls back to eager conditioned on a predicate
117
+ """
118
+ dispatcher_sig = DispatcherSignature.from_schema(schema.func)
119
+ exprs = translate(sig.arguments(), dispatcher_sig.arguments())
120
+ fallback_args = ",\n ".join([a.expr for a in exprs])
121
+ if len(overload_name):
122
+ aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
123
+ else:
124
+ aten_op_str = f"ATEN_OP({schema.aten_name})"
125
+ return f"""
126
+ if (force_eager_fallback({aten_symbol(schema)})) {{
127
+ return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
128
+ {fallback_args}
129
+ );
130
+ }}
131
+ """
132
+
133
+
134
+ def aten_symbol(schema: LazyIrSchema) -> str:
135
+ missing_interned_strings = {
136
+ "sigmoid_backward",
137
+ }
138
+ if schema.aten_name in missing_interned_strings:
139
+ return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
140
+
141
+ if not schema.aten_name.startswith("at::"):
142
+ return f"at::aten::{schema.aten_name}"
143
+ else:
144
+ return schema.aten_name
145
+
146
+
147
+ # converts all tensor-like arguments to meta tensors. Returns:
148
+ # (1) a string containing all of the logic that does the conversions.
149
+ # (2) a context, to be used by translate(), with all of the relevant bindings.
150
+ def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
151
+ context: List[Binding] = []
152
+ unwrapped_tensor_args: List[str] = []
153
+ for arg in sig.arguments():
154
+ if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
155
+ unwrapped_name = f"{arg.name}_meta"
156
+ unwrapped_tensor_args.append(
157
+ f"auto {unwrapped_name} = to_meta({arg.name});"
158
+ )
159
+ context.append(arg.with_name(unwrapped_name))
160
+ else:
161
+ context.append(arg)
162
+ unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
163
+ return unwrap_tensor_args_str, context
164
+
165
+
166
+ @dataclass(frozen=True)
167
+ class GenLazyIR(ABC):
168
+ backend_index: BackendIndex
169
+ backend_name: str
170
+ node_base: str
171
+ use_lazy_shape: bool
172
+
173
+ @method_with_native_function
174
+ def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
175
+ func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
176
+ metadata = self.backend_index.get_kernel(
177
+ f.functional if isinstance(f, NativeFunctionsGroup) else f
178
+ )
179
+ schema = LazyIrSchema(
180
+ func, symint=metadata is not None and metadata.supports_symint()
181
+ )
182
+ return self.gen(schema)
183
+
184
+ # there is no lowering functionality generated unless this IR base class is subclassed and
185
+ # implemented as a backend-specific node
186
+ def lowering_function(self, schema: LazyIrSchema) -> str:
187
+ return ""
188
+
189
+ def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
190
+ return ""
191
+
192
+ def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
193
+ return f"""bool CanBeReused({node_ctor_args}) const {{
194
+ return false;
195
+ }}"""
196
+
197
+ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
198
+ value_args = schema.filtered_args(values=True, scalars=False)
199
+ # backends can customize the way the node base class constructor is called,
200
+ # as long as all of its arguments can be generated from information available from the schema
201
+ base_ctor_value_args_list = []
202
+ for arg in value_args:
203
+ if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
204
+ base_ctor_value_args_list.append(f"{arg.name}")
205
+ elif isinstance(arg.lazy_type, OptionalCType):
206
+ base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
207
+ else:
208
+ raise AssertionError(
209
+ f"Unsupported type ({arg.lazy_type}) - add support if necessary"
210
+ )
211
+ base_ctor_value_args = ", ".join(base_ctor_value_args_list)
212
+
213
+ scalar_args = schema.filtered_args(values=False, scalars=True)
214
+
215
+ # Shape construction.
216
+ # Conditionally build shape depending on specified shape property
217
+ if schema.properties.ShapePrecompute:
218
+ shape_ctor_arg = "std::move(shapes),"
219
+ elif schema.properties.ShapeCompute:
220
+ shape_args = [a.name for a in value_args]
221
+ shape_args.extend(a.name for a in scalar_args)
222
+ shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
223
+ elif schema.properties.ShapeCache:
224
+ shape_args = [f"operand({i})" for i in range(len(value_args))]
225
+ shape_args.extend(a.name for a in scalar_args)
226
+ shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
227
+ else:
228
+ shape_ctor_arg = ""
229
+
230
+ scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
231
+
232
+ return f"""{self.node_base}(
233
+ {schema.node_name}::ClassOpKind(),
234
+ OpList{{{base_ctor_value_args}}},
235
+ {shape_ctor_arg}
236
+ /* num_outputs */ {len(schema.returns)},
237
+ torch::lazy::MHash({scalar_hashes}))"""
238
+
239
+ def gen(self, schema: LazyIrSchema) -> List[str]:
240
+ opkind = schema.opkind or aten_symbol(schema)
241
+
242
+ # for now, we just want one IR class decl and soon after also the method defs
243
+ # and we use the functional version not out/inplace.
244
+ all_args = schema.filtered_args()
245
+ value_args = schema.filtered_args(values=True, scalars=False)
246
+ scalar_args = schema.filtered_args(values=False, scalars=True)
247
+
248
+ ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
249
+ reuse_ctor_args = ", ".join(ctor_args)
250
+ if self.use_lazy_shape and schema.properties.ShapePrecompute:
251
+ ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
252
+ node_ctor_args = ", ".join(ctor_args)
253
+
254
+ scalar_initializers = ",\n ".join(
255
+ [
256
+ # This code is just special casing the mapping from string_view -> strings
257
+ f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
258
+ if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
259
+ else f"{a.name}({a.name})"
260
+ for a in scalar_args
261
+ ]
262
+ )
263
+ if len(scalar_initializers):
264
+ scalar_initializers = f",\n {scalar_initializers}"
265
+ scalar_decls = "\n ".join(
266
+ [
267
+ f"std::string {a.name};"
268
+ if a.lazy_type.cpp_type() == "c10::string_view"
269
+ else f"c10::optional<std::string> {a.name};"
270
+ if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
271
+ else f"{a.lazy_type.cpp_type()} {a.name};"
272
+ for a in scalar_args
273
+ ]
274
+ )
275
+ optional_values = [
276
+ arg.name
277
+ for arg in schema.filtered_args(values=True, scalars=False)
278
+ if isinstance(arg.lazy_type, OptionalCType)
279
+ ]
280
+ has_optional_decls = "\n ".join(
281
+ [f"bool has_{value}: 1;" for value in optional_values]
282
+ )
283
+ has_optional_defs = "\n ".join(
284
+ [f"has_{value} = !!{value};" for value in optional_values]
285
+ )
286
+ members_to_string = []
287
+ for arg in scalar_args:
288
+ if isinstance(arg.lazy_type, OptionalCType):
289
+ value = f"{arg.name}.value()"
290
+ if arg.is_generator:
291
+ value = '"torch.Generator()"'
292
+ members_to_string.append(
293
+ f"""if ({arg.name}.has_value()) {{
294
+ ss << ", {arg.name}=" << {value};
295
+ }} else {{
296
+ ss << ", {arg.name}=null";
297
+ }}"""
298
+ )
299
+ else:
300
+ members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
301
+ members_to_string_str = "\n ".join(members_to_string)
302
+
303
+ return [
304
+ f"""\
305
+ class {schema.node_name} : public {self.node_base} {{
306
+ public:
307
+ static torch::lazy::OpKind ClassOpKind() {{
308
+ return torch::lazy::OpKind({opkind});
309
+ }}
310
+
311
+ {schema.node_name}({node_ctor_args})
312
+ : {self.node_base_ctor_call(schema)}{scalar_initializers}
313
+ {{
314
+ {has_optional_defs}
315
+ }}
316
+
317
+ std::string ToString() const override {{
318
+ std::stringstream ss;
319
+ ss << {self.node_base}::ToString();
320
+ {members_to_string_str}
321
+ return ss.str();
322
+ }}
323
+
324
+ {self.create_function(schema, reuse_ctor_args)}
325
+
326
+ {self.can_be_reused_function(schema, reuse_ctor_args)}
327
+
328
+ {self.lowering_function(schema)}
329
+
330
+ {scalar_decls}
331
+ {has_optional_decls}
332
+
333
+ }};
334
+
335
+ """,
336
+ ]
337
+
338
+
339
+ @dataclass(frozen=True)
340
+ class GenTSLazyIR(GenLazyIR):
341
+ def lowering_function(self, schema: LazyIrSchema) -> str:
342
+ signature = """
343
+ torch::lazy::TSOpVector Lower(
344
+ std::shared_ptr<torch::jit::GraphFunction> function,
345
+ torch::lazy::TSLoweringContext* loctx) const override"""
346
+
347
+ if schema.properties.LowerDeclOnly:
348
+ return f"{signature};"
349
+ elif schema.properties.Lower:
350
+ return f"""{signature} {{
351
+ {ts_lowering_body(schema)}
352
+ }}
353
+ """
354
+ else:
355
+ return ""
356
+
357
+ def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
358
+ signature = f"static NodePtr Create({node_ctor_args})"
359
+ if schema.properties.CreateFnDeclOnly:
360
+ return f"{signature};"
361
+ elif not schema.properties.CreateFn:
362
+ return ""
363
+ return f"""{signature} {{
364
+ return ReuseOrMakeNode<{schema.node_name}>(data);
365
+ }}"""
366
+
367
+ def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
368
+ signature = f"bool CanBeReused({node_ctor_args}) const"
369
+ if schema.properties.CanBeReusedDeclOnly:
370
+ return f"{signature};"
371
+ elif not schema.properties.CanBeReused:
372
+ return ""
373
+ value_comparison = []
374
+ for arg in itertools.chain(schema.positional_values, schema.keyword_values):
375
+ if isinstance(arg.lazy_type, OptionalCType):
376
+ value_comparison.append(
377
+ f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
378
+ )
379
+ else:
380
+ value_comparison.append(f"operand(i++) == {arg.name}")
381
+ for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
382
+ if isinstance(arg.lazy_type, OptionalCType):
383
+ value_comparison.append(
384
+ f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
385
+ )
386
+ else:
387
+ value_comparison.append(f"this->{arg.name} == {arg.name}")
388
+ value_comparison_str = " &&\n ".join(value_comparison)
389
+
390
+ return f"""{signature} {{
391
+ size_t i = 0;
392
+ return ({value_comparison_str});
393
+ }}"""
394
+
395
+
396
+ @dataclass(frozen=True)
397
+ class GenLazyNativeFuncDefinition:
398
+ class_method_name: str
399
+ backend_index: BackendIndex
400
+ tensor_class: str
401
+ gen_forced_fallback_code: bool
402
+ backend_namespace: str
403
+ get_tensorlist: str
404
+ get_tensor_or_wrap_number: str
405
+ try_get_tensor: str
406
+ metrics_counter: str
407
+ create_tensor: str
408
+ create_from_first_tensor: bool
409
+ create_aten_from_ltc_tensor: str
410
+ tuple_aten_from_ltc_tensors: str
411
+ lazy_tensor_ptr: str
412
+ get_device_fn: str
413
+
414
+ def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
415
+ value_args = schema.filtered_args(values=True, scalars=False)
416
+ # Generates lazy_{name} variables for LazyTensors wrapping input tensors
417
+ lazy_tensor_decls: List[str] = []
418
+ for arg in value_args:
419
+ if arg.is_wrapped_scalar:
420
+ if isinstance(arg.lazy_type, OptionalCType):
421
+ lazy_tensor_decls.append(
422
+ f"""auto node_{arg.name} = {arg.name} ?
423
+ c10::make_optional(torch::lazy::LazyGraphExecutor::Get()->
424
+ GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
425
+ c10::nullopt;"""
426
+ )
427
+ else:
428
+ lazy_tensor_decls.append(
429
+ f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
430
+ GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
431
+ )
432
+ elif arg.is_symint_or_list:
433
+ continue # values are extracted in isValueType
434
+ elif isinstance(arg.lazy_type, BaseCType):
435
+ if arg.lazy_type.type is tensorListValueT:
436
+ lazy_tensor_decls.append(
437
+ f"auto lazy_{arg.name}_tensorlist = "
438
+ f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
439
+ )
440
+ else:
441
+ lazy_tensor_decls.append(
442
+ f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
443
+ f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
444
+ )
445
+ elif isinstance(arg.lazy_type, OptionalCType):
446
+ assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
447
+ # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
448
+ # until we encounter a real world example.
449
+ lazy_tensor_decls.append(
450
+ f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
451
+ f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
452
+ )
453
+ else:
454
+ raise AssertionError(
455
+ f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
456
+ )
457
+ return ("\n ").join(lazy_tensor_decls)
458
+
459
+ def force_eager_fallback(
460
+ self,
461
+ func: NativeFunction,
462
+ schema: LazyIrSchema,
463
+ metadata: BackendMetadata,
464
+ sig: Union[DispatcherSignature, NativeSignature],
465
+ ) -> str:
466
+ if self.gen_forced_fallback_code:
467
+ return gen_fallback_code(
468
+ schema, sig, overload_name=func.func.name.overload_name
469
+ )
470
+ return ""
471
+
472
+ def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
473
+ return f"{self.metrics_counter};"
474
+
475
+ def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
476
+ value_args = schema.filtered_args(values=True, scalars=False)
477
+ scalar_args = schema.filtered_args(values=False, scalars=True)
478
+ value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
479
+ optional_device = OptionalCType(BaseCType(deviceT))
480
+ optional_devices = [
481
+ a.name for a in scalar_args if a.lazy_type == optional_device
482
+ ]
483
+ assert (
484
+ len(value_types_names) > 0 or len(optional_devices) > 0
485
+ ), "Expected at least one Value or Device type"
486
+ get_device_str = (
487
+ f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
488
+ )
489
+ return f"""auto common_device = {get_device_str};
490
+ TORCH_INTERNAL_ASSERT(common_device);
491
+ """
492
+
493
+ def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
494
+ metadata = self.backend_index.get_kernel(func)
495
+ assert metadata is not None
496
+ all_args = schema.filtered_args()
497
+ returns_length = len(schema.returns)
498
+ # call the meta kernel if it exists, to compute output shape/dtype for our IR
499
+ # Note [Generated LTC Shape Functions]
500
+ # LTC uses meta tensors from core to do shape inference when possible, and otherwise
501
+ # we generate a shape function declaration that needs to be manually implemented.
502
+ # How do we detect which ops are eligible to use meta tensors?
503
+ # In general we should be able to use meta tensors not just on structured operators,
504
+ # but also on composite operators that are implemented in terms of structured kernels.
505
+ # We don't currently have a way of knowing at codegen time which ops are implemented that way.
506
+ # This is the case for all view and view_copy operators however, so we're going to
507
+ # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
508
+ is_view_copy_op = "view_copy" in func.tags
509
+ is_structured = func.structured or func.structured_delegate is not None
510
+ if is_structured or is_view_copy_op:
511
+ meta_out = """
512
+ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
513
+ if returns_length > 1:
514
+
515
+ def this_shape(i: int) -> str:
516
+ return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
517
+
518
+ shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
519
+ meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
520
+
521
+ # Convert tensor args to the meta device and call it.
522
+ # (We can't pass in the input tensors directly, because they are "functional wrappers".
523
+ # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
524
+ # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
525
+ dispatcher_sig = DispatcherSignature.from_schema(func.func)
526
+ meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
527
+ meta_call_args = [
528
+ e.expr
529
+ for e in translate(
530
+ meta_call_ctx, dispatcher_sig.arguments(), method=False
531
+ )
532
+ ]
533
+ if is_view_copy_op:
534
+ # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
535
+ assert func.has_composite_explicit_autograd_non_functional_kernel
536
+ dispatch_ns = "compositeexplicitautogradnonfunctional"
537
+ else:
538
+ dispatch_ns = "meta"
539
+ aten_name = schema.aten_name
540
+ # TODO: this is trolling
541
+ if func.func.has_symint() and metadata.supports_symint():
542
+ aten_name += "_symint"
543
+ shape_str = f"""\
544
+ {meta_conversion_str}
545
+ auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
546
+ {meta_out}"""
547
+ else:
548
+ shape_sig = ComputeShapeSignature(
549
+ metadata.kernel, func, symint=metadata.supports_symint()
550
+ )
551
+ shape_str = f"""
552
+ auto shapes = {shape_sig.shape_call};"""
553
+
554
+ shape_str += f"""
555
+ TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
556
+
557
+ # Calculating which dimensions are symbolic
558
+ func_schema_str = "aten::" + str(func.func)
559
+ shape_str += f"""
560
+ if(torch::lazy::symbolicShapeEnabled()){{
561
+ std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
562
+ const char* schema_str = "{func_schema_str}";
563
+ applySymbolicShapesOnLT(schema_str, inputs, shapes);
564
+ }}
565
+ """
566
+ return shape_str
567
+
568
+ def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
569
+ node_ctor_input_str = node_ctor_inputs(schema)
570
+ return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
571
+ if (!node) {{
572
+ {self.shape_inference(func, schema)}
573
+ node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
574
+ CacheNode(node);
575
+ }}
576
+ """
577
+
578
+ def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
579
+ # xla uses an instance method for tensor creation, for the time being
580
+ if self.create_from_first_tensor:
581
+ # TODO(whc) remove this if XLA switches to using static method for creation
582
+ assert (
583
+ first_tensor_name is not None
584
+ ), "Requires first tensor to create lazy tensor"
585
+ return f"{first_tensor_name}.{self.create_tensor}"
586
+ return f"{self.backend_namespace}::{self.create_tensor}"
587
+
588
+ def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
589
+ returns_length = len(schema.returns)
590
+ value_args = schema.filtered_args(values=True, scalars=False)
591
+ value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
592
+ first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
593
+ bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
594
+ {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
595
+
596
+ if returns_length > 1:
597
+ assert (
598
+ len(value_types_names) > 0
599
+ ), "Code below assumes there is at least one tensor arg"
600
+ bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
601
+ for (int i = 0; i < {returns_length}; i++) {{
602
+ lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
603
+ }}
604
+ auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
605
+
606
+ if schema.name.name.inplace or func.func.is_out_fn():
607
+ assert returns_length == 1, (
608
+ "We assumed there was no such case where an op is an in-place variant "
609
+ f"and has tuple outputs, but got tuple of len {returns_length}."
610
+ )
611
+ bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
612
+ auto& result = {first_tensor_name};"""
613
+
614
+ bridge_str += """
615
+ return result;"""
616
+ return bridge_str
617
+
618
+ @method_with_native_function
619
+ def __call__(self, func: NativeFunction) -> List[str]:
620
+ sig = kernel_signature(func, self.backend_index)
621
+ metadata = self.backend_index.get_kernel(func)
622
+ assert metadata is not None
623
+ schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
624
+ return [
625
+ f"""\
626
+ {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
627
+ {self.force_eager_fallback(func, schema, metadata, sig)}
628
+ {self.metrics(func, schema)}
629
+ {self.get_device(func, schema)}
630
+ {self.lazy_tensor_decls(func, schema)}
631
+ {self.build_ir_node(func, schema)}
632
+ {self.return_aten_tensor(func, schema)}
633
+ }}\n
634
+ """
635
+ ]
636
+
637
+
638
+ class ComputeShapeSignature:
639
+ """
640
+ Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
641
+ """
642
+
643
+ def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
644
+ self.__schema = LazyIrSchema(f.func, symint=symint)
645
+ self.__dispatch_args = ", ".join(
646
+ [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
647
+ )
648
+ self.__call_args = ", ".join(
649
+ [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
650
+ )
651
+ self.__kernel_name = kernel_name
652
+
653
+ def __decl_suffix(self) -> str:
654
+ return f"{self.__kernel_name}({self.__dispatch_args})"
655
+
656
+ def __call_suffix(self) -> str:
657
+ return f"{self.__kernel_name}({self.__call_args})"
658
+
659
+ @property
660
+ def shape_decl(self) -> str:
661
+ return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
662
+
663
+ @property
664
+ def shape_call(self) -> str:
665
+ return f"torch::lazy::compute_shape_{self.__call_suffix()}"
666
+
667
+
668
+ @dataclass(frozen=True)
669
+ class GenLazyShapeInferenceDefinition:
670
+ backend_index: BackendIndex
671
+ tensor_class: str
672
+
673
+ @method_with_native_function
674
+ def __call__(self, f: NativeFunction) -> List[str]:
675
+ sig = kernel_signature(f, self.backend_index)
676
+ metadata = self.backend_index.get_kernel(f)
677
+ assert metadata is not None
678
+
679
+ # See Note [Generated LTC Shape Functions]
680
+ is_view_copy_op = "view_copy" in f.tags
681
+ is_structured = f.structured or f.structured_delegate is not None
682
+ if is_structured or is_view_copy_op:
683
+ return []
684
+ else:
685
+ shape_sig = ComputeShapeSignature(
686
+ metadata.kernel, f, symint=metadata.supports_symint()
687
+ )
688
+ return ["\n".join([f"{shape_sig.shape_decl};"])]
689
+
690
+
691
+ def generate_non_native_lazy_ir_nodes(
692
+ non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
693
+ ) -> List[str]:
694
+ """Generate the non-native lazy IR node classes"""
695
+ nodes = []
696
+ for op in non_native:
697
+ # Set default properties for Non-Native IRs
698
+ properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
699
+ for p in op.get("properties", []):
700
+ setattr(properties, p, True)
701
+
702
+ # non-native is assumed to want symint bindings if you wrote symint
703
+ schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
704
+ schema.opkind = op.get("opkind")
705
+ nodes.append(gen_lazy_ir.gen(schema)[0])
706
+
707
+ return nodes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchgen.api.lazy import LazyArgument, LazyIrSchema
2
+ from torchgen.api.types import OptionalCType
3
+
4
+
5
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Render the C++ body of a TorchScript-backend ``Lower`` method.

    Builds the positional ``arguments`` and named ``kwarguments`` vectors
    from *schema*, forwards them to ``LowerTSBuiltin`` and returns the
    resulting op vector.
    """
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.

    def get_value(arg: LazyArgument) -> str:
        # Optional lazy values may be absent at runtime; guard with has_<name>.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    positional = []
    for arg in schema.positional_args:
        if arg.is_lazy_value:
            positional.append(get_value(arg))
        else:
            positional.append(f'"{arg.name}", {arg.name}')

    kwarg_values = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ]
    kwarg_scalars = [
        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
    ]
    all_kwargs = kwarg_values + kwarg_scalars

    arguments_block = "\n ".join(
        f"arguments.emplace_back({a});" for a in positional
    )
    kwarguments_block = "\n ".join(
        f"kwarguments.emplace_back({a});" for a in all_kwargs
    )
    return f"""\
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(positional)});
kwarguments.reserve({len(all_kwargs)});
size_t i = 0;
{arguments_block}
{kwarguments_block}
torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

return {schema.aten_name}_out;
"""
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ import torchgen.api.meta as meta
4
+ import torchgen.api.structured as structured
5
+ from torchgen.api.types import kernel_signature
6
+
7
+ from torchgen.context import with_native_function_and_index
8
+ from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
9
+ from torchgen.utils import mapMaybe
10
+
11
+
12
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    """Return the forward declaration for *f*'s unstructured kernel, or
    ``None`` when the backend registers no (non-legacy) kernel for it."""
    sig = kernel_signature(f, backend_index)
    metadata = backend_index.get_kernel(f)
    if metadata is None or "legacy::" in metadata.kernel:
        return None
    # External backends get file-local declarations; in-tree ones are exported.
    linkage = "static" if backend_index.external else "TORCH_API"
    return f"{linkage} {sig.decl(name=metadata.kernel)};"
23
+
24
+
25
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Return the struct declaration for *g*'s structured kernel class,
    or an empty list when the backend has no kernel for the group."""
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    # External backends do not export the generated struct.
    export_macro = "" if backend_index.external else "TORCH_API "
    impl_params = ", ".join(arg.decl() for arg in out_args)
    return [
        f"""\
struct {export_macro}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({impl_params});
}};
"""
    ]
40
+
41
+
42
# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Return the native-kernel forward declaration(s) for *g*."""
    metadata = backend_index.get_kernel(g)

    # Plain NativeFunction: at most one unstructured declaration.
    if not isinstance(g, NativeFunctionsGroup):
        decl = gen_unstructured(g, backend_index)
        return [decl] if decl is not None else []

    if metadata is not None and metadata.structured:
        if backend_index.external:
            # Structured hasn't been tested with external backends yet.
            raise AssertionError(
                "Structured external backend functions are not implemented yet."
            )
        return gen_structured(g, backend_index)

    # Unstructured group: declare each function that has a kernel.
    return list(
        mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py ADDED
@@ -0,0 +1,989 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import textwrap
3
+ from dataclasses import dataclass
4
+ from typing import List, Literal, Optional, Tuple, Union
5
+
6
+ import torchgen.api.cpp as cpp
7
+ import torchgen.api.meta as meta
8
+ import torchgen.api.structured as structured
9
+ from torchgen.api.translate import translate
10
+ from torchgen.api.types import (
11
+ BaseCType,
12
+ Binding,
13
+ ConstRefCType,
14
+ CppSignature,
15
+ CppSignatureGroup,
16
+ DispatcherSignature,
17
+ Expr,
18
+ kernel_signature,
19
+ MutRefCType,
20
+ NamedCType,
21
+ NativeSignature,
22
+ tensorT,
23
+ )
24
+
25
+ from torchgen.context import method_with_native_function, native_function_manager
26
+ from torchgen.model import (
27
+ Argument,
28
+ BackendIndex,
29
+ DeviceCheckType,
30
+ DispatchKey,
31
+ gets_generated_out_inplace_wrapper,
32
+ is_cuda_dispatch_key,
33
+ NativeFunction,
34
+ NativeFunctionsGroup,
35
+ SchemaKind,
36
+ TensorOptionsArguments,
37
+ )
38
+ from torchgen.selective_build.selector import SelectiveBuilder
39
+ from torchgen.utils import assert_never, mapMaybe, Target
40
+
41
+
42
def gen_registration_headers(
    backend_index: BackendIndex,
    per_operator_headers: bool,
    rocm: bool,
) -> List[str]:
    """Return the #include lines needed by Register<Dispatch>.cpp for this
    backend, honoring per-operator-header and ROCm builds."""
    key = backend_index.dispatch_key

    if per_operator_headers:
        headers = ["#include <ATen/ops/as_strided_native.h>"]
    else:
        headers = ["#include <ATen/NativeFunctions.h>"]

    if key in (DispatchKey.CPU, DispatchKey.Meta):
        extra = ["#include <ATen/EmptyTensor.h>"]
    elif key == DispatchKey.CUDA:
        # ROCm builds pull the hipified variant of the CUDA header.
        extra = (
            ["#include <ATen/hip/EmptyTensor.h>"]
            if rocm
            else ["#include <ATen/cuda/EmptyTensor.h>"]
        )
    elif key == DispatchKey.MPS:
        extra = ["#include <ATen/mps/EmptyTensor.h>"]
    elif per_operator_headers:
        extra = [
            "#include <ATen/ops/empty.h>",
            "#include <ATen/ops/empty_strided.h>",
            "#include <ATen/ops/_copy_from_and_resize.h>",
            "#include <ATen/ops/_copy_from.h>",
        ]
    else:
        extra = ["#include <ATen/Functions.h>"]

    return headers + extra
72
+
73
+
74
def gen_empty_impl_names(
    backend_index: BackendIndex,
) -> Tuple[Optional[str], Optional[str]]:
    """Return the (empty, empty_strided) factory-function names used by this
    backend, or (None, None) for backends with no known allocator."""
    key = backend_index.dispatch_key

    # These keys have dedicated at::detail factories named after the key.
    if key in (
        DispatchKey.Meta,
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.MPS,
    ):
        suffix = str(key).lower()
        return f"at::detail::empty_{suffix}", f"at::detail::empty_strided_{suffix}"

    # These fall back to the dispatched public factories.
    if key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
    ):
        return "at::empty", "at::empty_strided"

    return None, None
98
+
99
+
100
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
    """Return the create_out() helper definition for this backend, or []
    when no empty-tensor factory is known for it."""
    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_impl is None:
        return []

    # Meta kernels must allocate on the meta device regardless of options.
    if backend_index.dispatch_key == DispatchKey.Meta:
        empty_options = "options.device(at::kMeta)"
    else:
        empty_options = "options"

    return [
        f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
if (strides.empty()) {{
return {empty_impl}(sizes, {empty_options});
}} else {{
return {empty_strided_impl}(sizes, strides, {empty_options});
}}
}}
"""
    ]
121
+
122
+
123
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
    """Return the maybe_create_proxy() helper, or [] for backends without an
    empty_strided factory."""
    _, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_strided_impl is None:
        return []
    return [
        f"""
c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
if (out.strides() != strides) {{
return {empty_strided_impl}(sizes, strides, options);
}}
return c10::nullopt;
}}
"""
    ]
139
+
140
+
141
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
    """Return the resize_out() helper definition, or [] for keys that never
    emit out-kernel wrappers."""
    key = backend_index.dispatch_key
    if key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # The function isn't used by this key (since only functional ops have a kernel for this key),
        # so we need to not include it to avoid a defined-but-not-used error.
        return []
    helper = """
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
TORCH_CHECK(options.dtype() == out.dtype(),
"Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
TORCH_CHECK(options.device() == out.device(),
"Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
const bool resized = at::native::resize_output(out, sizes);
// Only restride if a resize occurred; otherwise we ignore the (advisory)
// strides from the meta function and directly use the output tensor's
// preexisting strides
if (resized) {
if (!strides.empty()) {
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
// TODO: avoid the redispatch here
out.as_strided_(sizes, strides);
} else if (options.memory_format_opt().has_value()) {
out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
}
}
}
"""
    return [helper]
169
+
170
+
171
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
    """Return the check_inplace() helper definition.

    The emitted C++ is identical for every backend; *backend_index* is
    accepted only for interface symmetry with the other helper generators.
    """
    helper = """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
// These checks are needed on those operators that:
// 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
// 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
// For other operators (e.g. 'add'), 'TensorIterator' already checks
// these things separately.
TORCH_CHECK(options.dtype() == self.dtype(),
"Bad in-place call: ",
"input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
TORCH_CHECK(options.device() == self.device(),
"Bad in-place call: ",
"input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
TORCH_CHECK(sizes == self.sizes(),
"Bad in-place call: ",
"input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    return [helper]
192
+
193
+
194
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
    """Concatenate every per-backend helper definition that is emitted at
    the top of Register<Dispatch>.cpp, in a fixed order."""
    helpers: List[str] = []
    for gen in (
        gen_create_out_helper,
        gen_resize_out_helper,
        gen_check_inplace_helper,
        gen_maybe_create_proxy_helper,
    ):
        helpers.extend(gen(backend_index))
    return helpers
201
+
202
+
203
# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
# - The primary function of this file is to register all of the
#   implementations for the given dispatch key to the dispatcher,
#   so they are available for use in PyTorch. If dispatch is
#   None, we generate schema (def) registrations and catchall
#   registrations.
# - The secondary function of this file is to generate a wrapper
#   around functions. In CPUType these wrappers do nothing
#   (and should be removed), but in other cases they handle
#   DeviceGuard. A small extra benefit of wrappers is they
#   are not overloaded, so they can be used in the registration
#   API without having to disambiguate which overload you want
#   (as would be the case if you directly registered native::
#   functions).
# - The tertiary function of this file is to generate *static*
#   cpp API bindings which can be used to bypass dispatcher
#   directly to kernels, but with user-friendly cpp-style API
@dataclass(frozen=True)
class RegisterDispatchKey:
    """Callable code generator: maps a NativeFunction(sGroup) to the C++
    snippets for one section (self.target) of Register<Dispatch>.cpp."""

    backend_index: BackendIndex

    # Which section of the output file is being generated.
    target: Literal[
        Target.ANONYMOUS_DEFINITION,
        Target.NAMESPACED_DEFINITION,
        Target.NAMESPACED_DECLARATION,
        Target.REGISTRATION,
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether or not to generate symint registrations or not. External users
    # of codegen who don't care about symints can set this to false to get
    # non-SymInt codegen
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: Optional[str]

    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
    # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
    skip_dispatcher_op_registration: bool

    @staticmethod
    def gen_device_check(
        type: DeviceCheckType, args: List[Argument], method_name: str
    ) -> str:
        """Emit C++ that checks all tensor-like *args* live on a common
        device (or a no-op comment when the op opts out of the check).

        NOTE(review): the parameter name ``type`` shadows the builtin; kept
        as-is since renaming would change the keyword-argument interface.
        """
        if type == DeviceCheckType.NoCheck:
            return " // No device check\n"

        device_check = "c10::optional<Device> common_device = nullopt;\n"
        device_check += "(void)common_device; // Suppress unused variable warning\n"
        for arg in args:
            # Only tensor like arguments are eligible
            if arg.type.is_tensor_like():
                device_check += f"""
c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
        return device_check

    @method_with_native_function
    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
        """Dispatch to structured or unstructured generation for *f*."""
        if isinstance(f, NativeFunctionsGroup):
            g: NativeFunctionsGroup = f
            # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
            # gen_structured() has special logic to handle auto-generated kernels.
            if g.structured:
                return self.gen_structured(g)
            else:
                return list(
                    mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
                )
        elif isinstance(f, NativeFunction):
            r = self.gen_unstructured(f)
            return [] if r is None else [r]
        else:
            assert_never(f)

    def wrapper_kernel_sig(
        self, f: NativeFunction
    ) -> Union[NativeSignature, DispatcherSignature]:
        """Signature of the generated wrapper function for *f*."""
        # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
        return DispatcherSignature.from_schema(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
            symint=self.symint,
        )

    def gen_out_inplace_wrapper(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
    ) -> Optional[str]:
        """Emit an out/inplace wrapper that calls the group's functional
        kernel and copies the result(s) back into the output tensor(s)."""
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = "at::_copy_from"
        elif k is SchemaKind.out:
            copy_op = "at::_copy_from_and_resize"
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        func_res = f"{name}_tmp"
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            # Multiple returns: copy each tuple element into its out tensor.
            updates = "\n ".join(
                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
                for i, ret_name in enumerate(return_names)
            )
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        elif len(return_names) == 1:
            ret_name = return_names[0]
            updates = f"{copy_op}({func_res}, {ret_name});"
            returns = ret_name
        else:
            # Void return: the single out argument is updated in place.
            assert len(f.func.arguments.out) == 1
            returns = ""
            out_arg = f.func.arguments.out[0]
            if out_arg.type.is_list_like():
                updates = f"""\
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
{copy_op}({func_res}[i], {out_arg.name}[i]);
}}"""
            else:
                updates = f"{copy_op}({func_res}, {out_arg.name});"

        functional_sig = self.wrapper_kernel_sig(g.functional)
        wrapper_name = sig.name()

        return f"""\
{sig.defn(name=wrapper_name)} {{
auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
{updates}
return {returns};
}}
"""

    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        """Generate code for a structured group, delegating the per-function
        work to StructuredRegisterDispatchKey."""
        metadata = self.backend_index.get_kernel(g)
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif metadata is None or not metadata.structured:
            # The group isn't actually structured for this backend; fall back
            # to unstructured generation per function.
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
        structured_gen = StructuredRegisterDispatchKey(
            self.backend_index,
            self.target,
            self.selector,
            self.rocm,
            self.symint,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))

    def gen_unstructured(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
    ) -> Optional[str]:
        """Generate the snippet for an unstructured function *f* for the
        current self.target, or None when nothing should be emitted."""
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    # Thin namespaced shim that forwards to the wrapper.
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result

            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
"Cannot inplace into non-meta tensor with meta tensor argument");
return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = " // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard. This describes the precedence order.
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{device_check}

{device_guard}
return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
558
+
559
+
560
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
561
+ #
562
+ # STRUCTURED
563
+ #
564
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
565
+
566
+
567
+ @dataclass(frozen=True)
568
+ class StructuredRegisterDispatchKey(RegisterDispatchKey):
569
+ g: NativeFunctionsGroup
570
+
571
    def gen_class_set_output_functions(
        self, k: SchemaKind, parent_class: str, generate_super: bool
    ) -> str:
        """Render both set_output overrides for the generated structured
        kernel class.

        Only the "strided" variant may substitute a proxy output tensor
        (maybe_create_proxy=True); "raw_strided" never does.
        """
        if generate_super:
            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""

        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
            # Shared template for the two overrides; the body comes from
            # gen_class_set_output_body and depends on the schema kind.
            return f"""
void set_output_{name}(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
if (!names.empty()) {{
namedinference::propagate_names(outputs_[output_idx], names);
}}
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
{textwrap.indent(set_output_super, " ")}
}}
"""

        return f"""
{gen_set_output_function("strided", maybe_create_proxy=True)}
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
"""
599
+
600
    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
        """Render the body of a set_output override for schema kind *k*:
        allocate (functional), validate in place (inplace), or resize (out)."""
        # Device-guarded backends must pin the guard to the output's device
        # before touching it; multi-device outputs are unsupported.
        if self.backend_index.dispatch_key in [
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
TORCH_INTERNAL_ASSERT(*current_device == options.device(),
"structured kernels don't support multi-device outputs");
} else {
guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            # Wrong-strided user outputs are replaced by a proxy tensor that
            # the kernel writes into; see maybe_create_proxy helper.
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutogradNonFunctional,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)
655
+
656
+ # returns the definition of a ctor, as well as how to construct
657
+ # this class to a variable named op
658
+ def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
659
+ if k is SchemaKind.functional:
660
+ return ""
661
+ elif k is SchemaKind.inplace:
662
+ # TODO: Make sure out argument is guaranteed to be self
663
+ return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
664
+ elif k is SchemaKind.out:
665
+ out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
666
+ out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
667
+ return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
668
+ elif k is SchemaKind.mutable or k is SchemaKind.scratch:
669
+ raise AssertionError(
670
+ f"{k} structured operators are currently not supported"
671
+ )
672
+ else:
673
+ assert_never(k)
674
+
675
    def gen_class(
        self,
        f: NativeFunction,
        k: SchemaKind,
        *,
        class_name: str,
        parent_class: str,
        generate_super: bool,
    ) -> str:
        """Render the whole structured-kernel helper class for *f*: ctor,
        set_output overrides, maybe_get_output, and the output/guard fields."""
        # NOTE(review): mutable/scratch kinds leave output_type/output_value/
        # proxy_field unbound; gen_class_ctor raises for those kinds first
        # (see the type: ignore markers below).
        if k is SchemaKind.functional:
            output_type = "Tensor"
            output_value = "outputs_[output_idx]"
            proxy_field = ""
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        elif k is SchemaKind.out:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"

        # Pick the device-guard member matching the backend.
        if self.backend_index.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
            else:
                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            guard_field = "c10::OptionalDeviceGuard guard_;"
        elif self.backend_index.dispatch_key == DispatchKey.MPS:
            # TODO: Move to OptionalMPSGuard.
            guard_field = "c10::OptionalDeviceGuard guard_;"
        else:
            guard_field = ""

        indent = " " * 4
        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
        lines = (
            f"struct {class_name} final : public {parent_class} {{",
            f"{textwrap.indent(class_ctor_str, indent)}",
            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
            " const Tensor& maybe_get_output(int64_t output_idx) override {",
            f" return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
            " }",
            f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(guard_field, indent)}",
            "};",
        )
        # Empty strings (e.g. no ctor, no guard) are dropped from the output.
        return "\n".join(line for line in lines if line)
728
+
729
+ @method_with_native_function
730
+ def gen_one(self, f: NativeFunction) -> Optional[str]:
731
+ assert not f.manual_kernel_registration
732
+
733
+ if (
734
+ self.target is Target.REGISTRATION
735
+ and not self.selector.is_native_function_selected(f)
736
+ ):
737
+ return None
738
+
739
+ # TODO: Now, there is something interesting going on here. In the code below,
740
+ # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
741
+ # based on the out implementation. But in fact, out is definable by
742
+ # functional too (just not very efficiently), and this is honestly the
743
+ # MORE likely situation for a backend implementor. How do we pick?
744
+ # Well, taking a page from Haskell type classes and default methods,
745
+ # we could conceivably register a circular definition (out in terms
746
+ # of functional, and functional in terms of out) and just require
747
+ # someone to implement one or the other. We'd have to do a little bit
748
+ # of work to not register one of these "weak" definitions unless there
749
+ # is a strong definition somewhere in the DAG! So it's not implemented yet.
750
+ if (
751
+ self.backend_index.dispatch_key
752
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
753
+ and f.func.kind() is SchemaKind.out
754
+ ):
755
+ # Never generate a default implementation for out, that's what you
756
+ # have to define as a backend implementor
757
+ return None
758
+
759
+ # Note [Direct dispatch bindings]
760
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
761
+ # Signature of the non-dispatched function we'll expose in a header
762
+ # (e.g., at::cpu::add). We don't generate methods (TODO: do this
763
+ # when CPUTensor class is a thing); nor do we generate fallback
764
+ # bindings for manual_cpp_binding functions.
765
+ cpp_sig_group = CppSignatureGroup.from_native_function(
766
+ f, method=False, fallback_binding=False
767
+ )
768
+
769
+ # Signature of the wrapper function we'll register to the dispatcher
770
+ kern = self.backend_index.get_kernel(f)
771
+ sig = NativeSignature(
772
+ f.func,
773
+ prefix=f"wrapper_{self.backend_index.dispatch_key}_",
774
+ symint=kern is not None and kern.supports_symint(),
775
+ )
776
+
777
+ if self.target is Target.NAMESPACED_DECLARATION:
778
+ result = ""
779
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
780
+ result += f"TORCH_API {cpp_sig.decl()};\n"
781
+ return result
782
+
783
+ elif self.target is Target.NAMESPACED_DEFINITION:
784
+
785
+ def generate_defn(cpp_sig: CppSignature) -> str:
786
+ return f"""
787
+ {cpp_sig.defn()} {{
788
+ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
789
+ }}
790
+ """
791
+
792
+ result = ""
793
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
794
+ result += generate_defn(cpp_sig)
795
+ return result
796
+
797
+ elif self.target is Target.ANONYMOUS_DEFINITION:
798
+ k = f.func.kind()
799
+
800
+ # Construct the body of the wrapper function with signature sig
801
+ sig_body = []
802
+ # We'll use context to keep track of any variables we've brought
803
+ # into scope while generating code
804
+ context: List[Union[Binding, Expr]] = list(sig.arguments())
805
+
806
+ # Initialize the class corresponding to this structured
807
+ # operator; feeding it the output argument(s) if it is known
808
+ if self.backend_index.dispatch_key is DispatchKey.Meta:
809
+ class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
810
+ parent_class = f"at::meta::structured_{meta.name(self.g)}"
811
+ elif (
812
+ self.backend_index.dispatch_key
813
+ is DispatchKey.CompositeExplicitAutogradNonFunctional
814
+ ):
815
+ # TODO: dedup this branch
816
+ class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
817
+ parent_class = f"at::meta::structured_{meta.name(self.g)}"
818
+ else:
819
+ metadata = self.backend_index.get_kernel(self.g)
820
+ assert metadata is not None
821
+ class_name = f"structured_{metadata.kernel}_{k.name}"
822
+ parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
823
+
824
+ if self.backend_index.device_guard:
825
+ device_check_args = itertools.chain(
826
+ f.func.arguments.out, f.func.arguments.flat_positional
827
+ )
828
+ sig_body.append(
829
+ RegisterDispatchKey.gen_device_check(
830
+ f.device_check, list(device_check_args), sig.name()
831
+ )
832
+ )
833
+
834
+ if k is SchemaKind.functional:
835
+ sig_body.append(f"{class_name} op;")
836
+ elif k is SchemaKind.inplace:
837
+ sig_body.append(f"{class_name} op(self);")
838
+ elif k is SchemaKind.out:
839
+ out_args_str = ", ".join(a.name for a in f.func.arguments.out)
840
+ sig_body.append(f"{class_name} op({out_args_str});")
841
+
842
+ # Translate the input native arguments into structured
843
+ # arguments for the meta call
844
+ meta_exprs = ", ".join(
845
+ e.expr
846
+ for e in translate(
847
+ context, structured.meta_arguments(self.g), method=False
848
+ )
849
+ )
850
+
851
+ if self.g.out.precomputed:
852
+ # If this function group has precomputed elements, the meta function
853
+ # returns a struct containing them which must be saved so that it
854
+ # can be unpacked when generating code to call the impl.
855
+ sig_body.append(f"auto precompute = op.meta({meta_exprs});")
856
+
857
+ # Put all of the contents of the precompute struct into the context
858
+ # so that translate will be able to return the correct args for the
859
+ # call to the impl.
860
+ precomputed_values = [
861
+ *self.g.out.precomputed.replace.values(),
862
+ self.g.out.precomputed.add,
863
+ ]
864
+ for precomputed_elems in precomputed_values:
865
+ for arg in precomputed_elems:
866
+ context.append(
867
+ Expr(
868
+ expr=f"precompute.{arg.name}",
869
+ type=structured.argument_type(arg, binds=arg.name),
870
+ )
871
+ )
872
+
873
+ # Add a use of the precompute struct so FB internal compilers don't
874
+ # complain that there is an unused variable.
875
+ sig_body.append("(void)precompute;")
876
+ else:
877
+ sig_body.append(f"op.meta({meta_exprs});")
878
+
879
+ # After running meta, op.outputs_ is guaranteed to be valid;
880
+ # add it to the context
881
+ out_args = structured.out_arguments(self.g)
882
+ for i, out_arg in enumerate(out_args):
883
+ assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
884
+
885
+ if k is SchemaKind.out:
886
+ expr = f"op.maybe_get_output({i})"
887
+ else:
888
+ expr = f"op.outputs_[{i}]"
889
+
890
+ context.append(
891
+ Expr(
892
+ expr=expr,
893
+ # TODO: Stop hardcoding that the output type is a Tensor. Note
894
+ # that for the codegen here this is fine because outputs_ is
895
+ # hardcoded to be tensor already
896
+ type=NamedCType(
897
+ out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
898
+ ),
899
+ )
900
+ )
901
+
902
+ # With the expanded context, do the impl call (if not a meta
903
+ # function)
904
+ if (
905
+ self.backend_index.dispatch_key
906
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
907
+ ):
908
+ # TODO: https://github.com/pytorch/pytorch/issues/53023
909
+ out_sig_group = CppSignatureGroup.from_native_function(
910
+ self.g.out, method=False, fallback_binding=f.manual_cpp_binding
911
+ )
912
+ out_sig = out_sig_group.most_faithful_signature()
913
+ api_name = out_sig.name()
914
+ out_exprs = ", ".join(
915
+ e.expr
916
+ for e in translate(context, out_sig.arguments(), method=False)
917
+ )
918
+ # TODO: I think this means structured won't work with method
919
+ # only functions (but maybe you're saved by faithful? iunno.)
920
+ # NB: Originally I wrote this as an at::redispatch call, but
921
+ # I got in trouble because that meant I needed a DispatchKeySet
922
+ # in the wrapper function, which meant I needed a DispatchKeySet
923
+ # in the DispatchKeyFunctions declarations, but the defined API
924
+ # there does NOT permit a dispatch key set. I think you can
925
+ # probably unwind this by calling some function to do the TLS
926
+ # fetch and get the DispatchKeySet when you don't have it, but
927
+ # I didn't do it for this version
928
+ sig_body.append(f"at::{api_name}({out_exprs});")
929
+ elif self.backend_index.dispatch_key != DispatchKey.Meta:
930
+ impl_exprs = ", ".join(
931
+ e.expr
932
+ for e in translate(
933
+ context, structured.impl_arguments(self.g), method=False
934
+ )
935
+ )
936
+ sig_body.append(f"op.impl({impl_exprs});")
937
+
938
+ # Go over each output, and check if there is a proxy created for it.
939
+ # If so, copy it over to the original output.
940
+ if k is SchemaKind.out or k is SchemaKind.inplace:
941
+ for i in range(len(f.func.returns)):
942
+ sig_body.append(
943
+ f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
944
+ )
945
+
946
+ # Destructively return the final tensors
947
+ # TODO: Do this in translate instead
948
+ if k is SchemaKind.functional:
949
+ if len(f.func.returns) == 1:
950
+ ret_expr = "std::move(op.outputs_[0])" # small optimization
951
+ else:
952
+ moved = ", ".join(
953
+ f"std::move(op.outputs_[{i}])"
954
+ for i in range(len(f.func.returns))
955
+ )
956
+ ret_expr = f"std::make_tuple({moved})"
957
+ elif k is SchemaKind.inplace:
958
+ ret_expr = "self"
959
+ elif k is SchemaKind.out:
960
+ if len(f.func.returns) == 1:
961
+ ret_expr = f.func.arguments.out[0].name
962
+ else:
963
+ refs = ", ".join(a.name for a in f.func.arguments.out)
964
+ ret_expr = f"std::forward_as_tuple({refs})"
965
+ sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
966
+
967
+ sig_body_str = "\n".join(sig_body)
968
+
969
+ # For an overview of what this template code looks like, see
970
+ # https://github.com/pytorch/rfcs/pull/9
971
+ return f"""\
972
+ {self.gen_class(
973
+ f, k,
974
+ class_name=class_name,
975
+ parent_class=parent_class,
976
+ generate_super=self.g.out.structured_inherits is not None
977
+ )}
978
+
979
+ {sig.defn()} {{
980
+ {sig_body_str}
981
+ }}
982
+ """
983
+
984
+ elif self.target is Target.REGISTRATION:
985
+ return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
986
+ else:
987
+ assert_never(self.target)
988
+ # Silence mypy's "Missing return statement" error
989
+ return None
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
3
+
4
+ import torchgen.api.ufunc as ufunc
5
+ from torchgen.api.translate import translate
6
+ from torchgen.api.types import (
7
+ BaseCType,
8
+ Binding,
9
+ CType,
10
+ Expr,
11
+ NamedCType,
12
+ opmath_t,
13
+ scalar_t,
14
+ StructuredImplSignature,
15
+ VectorizedCType,
16
+ )
17
+ from torchgen.api.ufunc import UfunctorBindings
18
+ from torchgen.context import with_native_function
19
+ from torchgen.model import (
20
+ Argument,
21
+ BaseTy,
22
+ BaseType,
23
+ DispatchKey,
24
+ NativeFunctionsGroup,
25
+ ScalarType,
26
+ UfuncKey,
27
+ )
28
+ from torchgen.utils import OrderedSet
29
+
30
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
31
+ #
32
+ # CUDA STUFF
33
+ #
34
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
35
+
36
+ # NB: not bothering to generate dispatch stub forward declaration in header,
37
+ # we can just paste it whereever necessary
38
+
39
+ # TODO: use BackendIndex
40
+ # dispatch_key: DispatchKey # only CPU/CUDA right now
41
+
42
+
43
+ # Represents functors for implementing CUDA ufuncs.
44
+ # Functors are templated by scalar_t because when USERS instantiate functors
45
+ # they are templated. A functor looks something like this:
46
+ #
47
+ # template <typename scalar_t>
48
+ # struct CUDAFunctorOnSelf_add {
49
+ # using opmath_t = at::opmath_type<scalar_t>;
50
+ # opmath_t other_;
51
+ # opmath_t alpha_;
52
+ # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
53
+ # : other_(other), alpha_(alpha) {}
54
+ # __device__ scalar_t operator()(scalar_t self) {
55
+ # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
56
+ # }
57
+ # };
58
+ #
59
+ @dataclass(frozen=True)
60
+ class UfunctorSignature:
61
+ g: NativeFunctionsGroup
62
+ scalar_tensor_idx: Optional[int]
63
+ name: str
64
+
65
+ def arguments(self) -> UfunctorBindings:
66
+ return ufunc.ufunctor_arguments(
67
+ self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
68
+ )
69
+
70
+ def fields(self) -> List[Binding]:
71
+ # fields are renamed to have a trailing underscore, as is conventional
72
+ return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
73
+
74
+ def returns_type(self) -> CType:
75
+ # TODO: don't hardcode; return type will be inferred based on tags on
76
+ # the native function
77
+ return BaseCType(scalar_t)
78
+
79
+ def decl_fields(self) -> str:
80
+ return "\n".join(f"{f.type} {f.name};" for f in self.fields())
81
+
82
+ def inline_defn_ctor(self) -> str:
83
+ args_str = ", ".join(a.decl() for a in self.arguments().ctor)
84
+ # NB: hypothetically could do this with translate but the
85
+ # transition here is very regular
86
+ init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
87
+ return f"{self.name}({args_str}) : {init_str} {{}}"
88
+
89
+ def decl_apply(self) -> str:
90
+ args_str = ", ".join(a.decl() for a in self.arguments().apply)
91
+ return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
92
+
93
+
94
+ @dataclass(frozen=True)
95
+ class UfuncSignature:
96
+ g: NativeFunctionsGroup
97
+ name: str
98
+ compute_t: CType
99
+
100
+ def arguments(self) -> List[Binding]:
101
+ return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
102
+
103
+ def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
104
+ return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
105
+
106
+
107
+ # steps:
108
+ # 1. take the functional signature
109
+ # 2. use api.ufunc to convert it to template signature. this establishes
110
+ # the type of the template function
111
+ # 3. use api.ufunc (II) to generate a split struct / operator() signature.
112
+ # this establish context in which we call the template signature
113
+ #
114
+ # StructuredImplSignature context
115
+ # ~> functor constructor sig
116
+ #
117
+ # Functor constructor context
118
+ # ~> functor fields sig
119
+ #
120
+ # Functor apply context (functor fields + functor apply sig)
121
+ # ~> template sig
122
+ #
123
+
124
+
125
+ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
126
+ num_tensors = sum(
127
+ 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
128
+ )
129
+ return num_tensors == 2
130
+
131
+
132
+ def compute_ufunc_cuda_functors(
133
+ g: NativeFunctionsGroup,
134
+ ) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
135
+ # First, build the functors.
136
+ ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
137
+ ufunctors: List[str] = []
138
+ loops = g.out.ufunc_inner_loop
139
+ scalar_tensor_idx_lookup = {
140
+ UfuncKey.CUDAFunctorOnSelf: 1,
141
+ UfuncKey.CUDAFunctorOnOther: 0,
142
+ UfuncKey.CUDAFunctor: None,
143
+ }
144
+ if eligible_for_binary_scalar_specialization(g):
145
+ keys = [
146
+ UfuncKey.CUDAFunctorOnSelf,
147
+ UfuncKey.CUDAFunctorOnOther,
148
+ UfuncKey.CUDAFunctor,
149
+ ]
150
+ else:
151
+ keys = [UfuncKey.CUDAFunctor]
152
+ for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
153
+ assert k not in loops, f"cannot use {k} on non-binary function"
154
+ for k in keys:
155
+ # If the key was directly defined, skip functor codegen; we assume the
156
+ # user already done it for us
157
+ if k in loops:
158
+ ufunctor_sig = UfunctorSignature(
159
+ g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
160
+ )
161
+ for dtype in loops[k].supported_dtypes:
162
+ ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
163
+ continue
164
+
165
+ # Note [ScalarOnly and Generic must match names for CUDA]
166
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
167
+ # Otherwise, look in ANY of the generic entries. For simplicity of
168
+ # codegen, both ScalarOnly and Generic are defined, the ufunc name
169
+ # must match (if they didn't match, we'd have to generate distinct
170
+ # functors per dtype, which is awful, so we're not going to do it unless
171
+ # someone really forces us to)
172
+ ufunc_name = None
173
+ supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
174
+ for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
175
+ if lk not in loops:
176
+ continue
177
+ if ufunc_name is None:
178
+ ufunc_name = loops[lk].name
179
+ else:
180
+ # See Note [ScalarOnly and Generic must match names for CUDA]
181
+ assert (
182
+ ufunc_name == loops[lk].name
183
+ ), "ScalarOnly and Generic must have same ufunc name"
184
+ supported_dtypes |= loops[lk].supported_dtypes
185
+ assert ufunc_name is not None
186
+
187
+ name = f"{k}_{ufunc_name}"
188
+ ufunctor_sig = UfunctorSignature(
189
+ g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
190
+ )
191
+ for dtype in supported_dtypes:
192
+ ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
193
+
194
+ ufunc_sig = UfuncSignature(
195
+ g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
196
+ )
197
+ apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
198
+ ufunctors.append(
199
+ f"""
200
+ template <typename scalar_t>
201
+ struct {ufunctor_sig.name} {{
202
+ using opmath_t = at::opmath_type<scalar_t>;
203
+ {ufunctor_sig.decl_fields()}
204
+ {ufunctor_sig.inline_defn_ctor()}
205
+ __device__ {ufunctor_sig.decl_apply()} {{
206
+ return {ufunc_sig.call(apply_ctx)};
207
+ }}
208
+ }};
209
+ """
210
+ )
211
+
212
+ return ufunctor_sigs, "\n".join(ufunctors)
213
+
214
+
215
+ @dataclass(frozen=True)
216
+ class BinaryScalarSpecializationConfig:
217
+ scalar_idx: int
218
+ ctor_tensor: str
219
+ ufunc_key: UfuncKey
220
+
221
+
222
+ BinaryScalarSpecializationConfigs = [
223
+ BinaryScalarSpecializationConfig(
224
+ scalar_idx=0,
225
+ ctor_tensor="self",
226
+ ufunc_key=UfuncKey.CUDAFunctorOnOther,
227
+ ),
228
+ BinaryScalarSpecializationConfig(
229
+ scalar_idx=1,
230
+ ctor_tensor="other",
231
+ ufunc_key=UfuncKey.CUDAFunctorOnSelf,
232
+ ),
233
+ ]
234
+
235
+
236
+ def compute_ufunc_cuda_dtype_body(
237
+ g: NativeFunctionsGroup,
238
+ dtype: ScalarType,
239
+ inner_loops: Dict[UfuncKey, UfunctorSignature],
240
+ parent_ctx: Sequence[Binding],
241
+ ) -> str:
242
+ body = "using opmath_t = at::opmath_type<scalar_t>;"
243
+ body += "if (false) {}\n" # for ease of codegen
244
+ for config in BinaryScalarSpecializationConfigs:
245
+ if config.ufunc_key not in inner_loops:
246
+ continue
247
+ ufunctor_sig = inner_loops[config.ufunc_key]
248
+ scalar_idx = config.scalar_idx + 1
249
+ # Make a copy and at the same time widen the type (not permissible
250
+ # without copy; we don't want to mutate the input argument anyway)
251
+ ctx: List[Union[Expr, Binding]] = list(parent_ctx)
252
+ ctx.append(
253
+ Expr(
254
+ expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
255
+ type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
256
+ )
257
+ )
258
+ ufunctor_ctor_exprs_str = ", ".join(
259
+ a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
260
+ )
261
+
262
+ # NB: ufunctor must be allocated before iter.remove_operand is called,
263
+ # as it relies on iter
264
+ body += f"""\
265
+ else if (iter.is_cpu_scalar({scalar_idx})) {{
266
+ {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
267
+ iter.remove_operand({scalar_idx});
268
+ gpu_kernel(iter, ufunctor);
269
+ }}"""
270
+
271
+ ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
272
+ ufunctor_ctor_exprs_str = ", ".join(
273
+ a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
274
+ )
275
+ body += f"""
276
+ else {{
277
+ gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
278
+ }}
279
+ """
280
+ return body
281
+
282
+
283
+ @with_native_function
284
+ def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
285
+ # First, build the functors, indexing them by dtype
286
+ ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
287
+
288
+ # Next, build the conditionals
289
+ sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
290
+ dtype_cases = []
291
+ for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
292
+ dtype_cases.append(
293
+ f"""
294
+ AT_DISPATCH_CASE(at::ScalarType::{dtype},
295
+ [&]() {{
296
+ {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
297
+ }}
298
+ )
299
+ """
300
+ )
301
+
302
+ dtype_cases_str = "\n".join(dtype_cases)
303
+
304
+ stub_sig = StubSignature(g)
305
+
306
+ return f"""
307
+ {ufunctors}
308
+
309
+ {stub_sig.type_defn()};
310
+ {stub_sig.dispatch_decl()};
311
+
312
+ {stub_sig.kernel_defn()} {{
313
+ AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
314
+ {dtype_cases_str}
315
+ );
316
+ }}
317
+ REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
318
+
319
+ {sig.defn()} {{
320
+ {stub_sig.direct_call(sig.arguments())};
321
+ }}
322
+ """
323
+
324
+
325
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
326
+ #
327
+ # CPU STUFF
328
+ #
329
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
330
+
331
+
332
+ @dataclass(frozen=True)
333
+ class StubSignature:
334
+ g: NativeFunctionsGroup
335
+
336
+ @property
337
+ def name(self) -> str:
338
+ return f"{str(self.g.functional.func.name.name)}_stub"
339
+
340
+ @property
341
+ def kernel_name(self) -> str:
342
+ return f"{str(self.g.functional.func.name.name)}_kernel"
343
+
344
+ @property
345
+ def type_name(self) -> str:
346
+ return f"{str(self.g.functional.func.name.name)}_fn"
347
+
348
+ def arguments(self) -> List[Binding]:
349
+ return ufunc.stub_arguments(self.g)
350
+
351
+ def type(self) -> str:
352
+ cpp_args = self.arguments()
353
+ return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
354
+
355
+ def dispatch_decl(self) -> str:
356
+ return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
357
+
358
+ def dispatch_defn(self) -> str:
359
+ return f"DEFINE_DISPATCH({self.name})"
360
+
361
+ def kernel_defn(self) -> str:
362
+ return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
363
+
364
+ def type_defn(self) -> str:
365
+ return f"using {self.type_name} = {self.type()}"
366
+
367
+ # must be called from context where this is TensorIteratorBase*
368
+ def call(self, ctx: Sequence[Binding]) -> str:
369
+ return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
370
+
371
+ # used in CUDA to skip the unnecessary dynamic dispatch
372
+ def direct_call(self, ctx: Sequence[Binding]) -> str:
373
+ return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
374
+
375
+
376
+ @with_native_function
377
+ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
378
+ stub_sig = StubSignature(g)
379
+ sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
380
+
381
+ return f"""
382
+ {stub_sig.type_defn()};
383
+ {stub_sig.dispatch_decl()};
384
+ {stub_sig.dispatch_defn()};
385
+
386
+ {sig.defn()} {{
387
+ {stub_sig.call(sig.arguments())};
388
+ }}
389
+ """
390
+
391
+
392
+ def compute_ufunc_cpu_dtype_body(
393
+ g: NativeFunctionsGroup,
394
+ dtype: ScalarType,
395
+ inner_loops: Dict[UfuncKey, UfuncSignature],
396
+ parent_ctx: Sequence[Binding],
397
+ ) -> str:
398
+ assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
399
+ assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
400
+ scalar_loop = inner_loops[UfuncKey.CPUScalar]
401
+ vec_loop = None
402
+ if UfuncKey.CPUVector in inner_loops:
403
+ vec_loop = inner_loops[UfuncKey.CPUVector]
404
+
405
+ # NB: We DON'T use translate here, because translate is
406
+ # incapable of CSE'ing the scalar accesses in case it is also
407
+ # used by Vectorized; also, the unpacking here is very simple
408
+ # and only affects Scalar; everything else is implicitly captured
409
+ # by the lambda
410
+
411
+ # Setup scalar in scope
412
+ body = []
413
+ ctx = []
414
+ for b in parent_ctx:
415
+ if isinstance(b.argument, Argument) and b.argument.type != BaseType(
416
+ BaseTy.Scalar
417
+ ):
418
+ continue
419
+ body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
420
+ ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
421
+ if vec_loop is not None:
422
+ for b in parent_ctx:
423
+ if isinstance(b.argument, Argument) and b.argument.type != BaseType(
424
+ BaseTy.Scalar
425
+ ):
426
+ continue
427
+ body.append(
428
+ f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
429
+ )
430
+ ctx.append(
431
+ Expr(
432
+ f"_v_{b.name}",
433
+ NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
434
+ )
435
+ )
436
+
437
+ # Setup lambda signature
438
+ # NB: simplified version of ufunctor_arguments
439
+ scalar_bindings = []
440
+ vec_bindings = []
441
+ for a in g.functional.func.arguments.flat_non_out:
442
+ if not a.type.is_tensor_like():
443
+ continue
444
+ assert a.type == BaseType(BaseTy.Tensor)
445
+ scalar_bindings.append(
446
+ Binding(
447
+ name=a.name,
448
+ nctype=NamedCType(a.name, BaseCType(scalar_t)),
449
+ argument=a,
450
+ )
451
+ )
452
+ if vec_loop is not None:
453
+ vec_bindings.append(
454
+ Binding(
455
+ name=a.name,
456
+ nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
457
+ argument=a,
458
+ )
459
+ )
460
+
461
+ def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
462
+ r: List[Union[Expr, Binding]] = []
463
+ r.extend(ctx)
464
+ r.extend(b)
465
+ return r
466
+
467
+ body_str = "\n".join(body)
468
+ if vec_loop is not None:
469
+ return f"""
470
+ {body_str}
471
+ cpu_kernel_vec(iter,
472
+ [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
473
+ [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
474
+ );
475
+ """
476
+ else:
477
+ return f"""
478
+ {body_str}
479
+ cpu_kernel(iter,
480
+ [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
481
+ );
482
+ """
483
+
484
+
485
+ @with_native_function
486
+ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
487
+ stub_sig = StubSignature(g)
488
+
489
+ # Reindex the ufunc by dtypes; processing generic/scalaronly as well
490
+ loops = g.out.ufunc_inner_loop
491
+ ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
492
+ for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
493
+ lks = []
494
+ # ORDER MATTERS: this specifies overriding precedence
495
+ if k in loops: # should happen rarely
496
+ lks.append(k)
497
+ if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
498
+ lks.append(UfuncKey.ScalarOnly)
499
+ if UfuncKey.Generic in loops:
500
+ lks.append(UfuncKey.Generic)
501
+ # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
502
+ for lk in lks:
503
+ for dtype in loops[lk].supported_dtypes:
504
+ compute_t: CType
505
+ if k is UfuncKey.CPUScalar:
506
+ compute_t = BaseCType(scalar_t)
507
+ elif k is UfuncKey.CPUVector:
508
+ compute_t = VectorizedCType(BaseCType(scalar_t))
509
+ else:
510
+ raise AssertionError()
511
+ inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
512
+ if k not in inner_ufunc_sigs:
513
+ inner_ufunc_sigs[k] = UfuncSignature(
514
+ g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
515
+ )
516
+
517
+ # Build the conditionals
518
+ dtype_cases = []
519
+ for dtype, inner_ufunc_sigs in ufunc_sigs.items():
520
+ dtype_cases.append(
521
+ f"""
522
+ AT_DISPATCH_CASE(at::ScalarType::{dtype},
523
+ [&]() {{
524
+ {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
525
+ }}
526
+ )
527
+ """
528
+ )
529
+
530
+ dtype_cases_str = "\n".join(dtype_cases)
531
+ return f"""
532
+ namespace {{
533
+
534
+ {stub_sig.kernel_defn()} {{
535
+ AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
536
+ {dtype_cases_str}
537
+ );
538
+ }}
539
+
540
+ }} // anonymous namespace
541
+
542
+ {stub_sig.type_defn()};
543
+ {stub_sig.dispatch_decl()};
544
+ REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
545
+ """
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (220 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc ADDED
Binary file (7.11 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes). View file