Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py +119 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py +142 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py +295 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py +56 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py +749 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py +19 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py +707 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py +48 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py +64 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py +989 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py +545 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -38,3 +38,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/algorith
|
|
| 38 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Symtab.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 39 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/Scanners.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/DFA.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 38 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Symtab.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 39 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/Scanners.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Plex/DFA.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestCyCache.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import difflib
|
| 2 |
+
import glob
|
| 3 |
+
import gzip
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import tempfile
|
| 7 |
+
import unittest
|
| 8 |
+
|
| 9 |
+
import Cython.Build.Dependencies
|
| 10 |
+
import Cython.Utils
|
| 11 |
+
from Cython.TestUtils import CythonTest
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestCyCache(CythonTest):
|
| 15 |
+
|
| 16 |
+
def setUp(self):
|
| 17 |
+
CythonTest.setUp(self)
|
| 18 |
+
self.temp_dir = tempfile.mkdtemp(
|
| 19 |
+
prefix='cycache-test',
|
| 20 |
+
dir='TEST_TMP' if os.path.isdir('TEST_TMP') else None)
|
| 21 |
+
self.src_dir = tempfile.mkdtemp(prefix='src', dir=self.temp_dir)
|
| 22 |
+
self.cache_dir = tempfile.mkdtemp(prefix='cache', dir=self.temp_dir)
|
| 23 |
+
|
| 24 |
+
def cache_files(self, file_glob):
|
| 25 |
+
return glob.glob(os.path.join(self.cache_dir, file_glob))
|
| 26 |
+
|
| 27 |
+
def fresh_cythonize(self, *args, **kwargs):
|
| 28 |
+
Cython.Utils.clear_function_caches()
|
| 29 |
+
Cython.Build.Dependencies._dep_tree = None # discard method caches
|
| 30 |
+
Cython.Build.Dependencies.cythonize(*args, **kwargs)
|
| 31 |
+
|
| 32 |
+
def test_cycache_switch(self):
|
| 33 |
+
content1 = 'value = 1\n'
|
| 34 |
+
content2 = 'value = 2\n'
|
| 35 |
+
a_pyx = os.path.join(self.src_dir, 'a.pyx')
|
| 36 |
+
a_c = a_pyx[:-4] + '.c'
|
| 37 |
+
|
| 38 |
+
with open(a_pyx, 'w') as f:
|
| 39 |
+
f.write(content1)
|
| 40 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 41 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 42 |
+
self.assertEqual(1, len(self.cache_files('a.c*')))
|
| 43 |
+
with open(a_c) as f:
|
| 44 |
+
a_contents1 = f.read()
|
| 45 |
+
os.unlink(a_c)
|
| 46 |
+
|
| 47 |
+
with open(a_pyx, 'w') as f:
|
| 48 |
+
f.write(content2)
|
| 49 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 50 |
+
with open(a_c) as f:
|
| 51 |
+
a_contents2 = f.read()
|
| 52 |
+
os.unlink(a_c)
|
| 53 |
+
|
| 54 |
+
self.assertNotEqual(a_contents1, a_contents2, 'C file not changed!')
|
| 55 |
+
self.assertEqual(2, len(self.cache_files('a.c*')))
|
| 56 |
+
|
| 57 |
+
with open(a_pyx, 'w') as f:
|
| 58 |
+
f.write(content1)
|
| 59 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 60 |
+
self.assertEqual(2, len(self.cache_files('a.c*')))
|
| 61 |
+
with open(a_c) as f:
|
| 62 |
+
a_contents = f.read()
|
| 63 |
+
self.assertEqual(
|
| 64 |
+
a_contents, a_contents1,
|
| 65 |
+
msg='\n'.join(list(difflib.unified_diff(
|
| 66 |
+
a_contents.split('\n'), a_contents1.split('\n')))[:10]))
|
| 67 |
+
|
| 68 |
+
def test_cycache_uses_cache(self):
|
| 69 |
+
a_pyx = os.path.join(self.src_dir, 'a.pyx')
|
| 70 |
+
a_c = a_pyx[:-4] + '.c'
|
| 71 |
+
with open(a_pyx, 'w') as f:
|
| 72 |
+
f.write('pass')
|
| 73 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 74 |
+
a_cache = os.path.join(self.cache_dir, os.listdir(self.cache_dir)[0])
|
| 75 |
+
with gzip.GzipFile(a_cache, 'wb') as gzipfile:
|
| 76 |
+
gzipfile.write('fake stuff'.encode('ascii'))
|
| 77 |
+
os.unlink(a_c)
|
| 78 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 79 |
+
with open(a_c) as f:
|
| 80 |
+
a_contents = f.read()
|
| 81 |
+
self.assertEqual(a_contents, 'fake stuff',
|
| 82 |
+
'Unexpected contents: %s...' % a_contents[:100])
|
| 83 |
+
|
| 84 |
+
def test_multi_file_output(self):
|
| 85 |
+
a_pyx = os.path.join(self.src_dir, 'a.pyx')
|
| 86 |
+
a_c = a_pyx[:-4] + '.c'
|
| 87 |
+
a_h = a_pyx[:-4] + '.h'
|
| 88 |
+
a_api_h = a_pyx[:-4] + '_api.h'
|
| 89 |
+
with open(a_pyx, 'w') as f:
|
| 90 |
+
f.write('cdef public api int foo(int x): return x\n')
|
| 91 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 92 |
+
expected = [a_c, a_h, a_api_h]
|
| 93 |
+
for output in expected:
|
| 94 |
+
self.assertTrue(os.path.exists(output), output)
|
| 95 |
+
os.unlink(output)
|
| 96 |
+
self.fresh_cythonize(a_pyx, cache=self.cache_dir)
|
| 97 |
+
for output in expected:
|
| 98 |
+
self.assertTrue(os.path.exists(output), output)
|
| 99 |
+
|
| 100 |
+
def test_options_invalidation(self):
|
| 101 |
+
hash_pyx = os.path.join(self.src_dir, 'options.pyx')
|
| 102 |
+
hash_c = hash_pyx[:-len('.pyx')] + '.c'
|
| 103 |
+
|
| 104 |
+
with open(hash_pyx, 'w') as f:
|
| 105 |
+
f.write('pass')
|
| 106 |
+
self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False)
|
| 107 |
+
self.assertEqual(1, len(self.cache_files('options.c*')))
|
| 108 |
+
|
| 109 |
+
os.unlink(hash_c)
|
| 110 |
+
self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=True)
|
| 111 |
+
self.assertEqual(2, len(self.cache_files('options.c*')))
|
| 112 |
+
|
| 113 |
+
os.unlink(hash_c)
|
| 114 |
+
self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=False)
|
| 115 |
+
self.assertEqual(2, len(self.cache_files('options.c*')))
|
| 116 |
+
|
| 117 |
+
os.unlink(hash_c)
|
| 118 |
+
self.fresh_cythonize(hash_pyx, cache=self.cache_dir, cplus=False, show_version=True)
|
| 119 |
+
self.assertEqual(2, len(self.cache_files('options.c*')))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestDependencies.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import os.path
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
| 5 |
+
import unittest
|
| 6 |
+
from io import open
|
| 7 |
+
from os.path import join as pjoin
|
| 8 |
+
|
| 9 |
+
from ..Dependencies import extended_iglob
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@contextlib.contextmanager
|
| 13 |
+
def writable_file(dir_path, filename):
|
| 14 |
+
with open(pjoin(dir_path, filename), "w", encoding="utf8") as f:
|
| 15 |
+
yield f
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestGlobbing(unittest.TestCase):
|
| 19 |
+
@classmethod
|
| 20 |
+
def setUpClass(cls):
|
| 21 |
+
cls._orig_dir = os.getcwd()
|
| 22 |
+
if sys.version_info[0] < 3:
|
| 23 |
+
temp_path = cls._tmpdir = tempfile.mkdtemp()
|
| 24 |
+
else:
|
| 25 |
+
cls._tmpdir = tempfile.TemporaryDirectory()
|
| 26 |
+
temp_path = cls._tmpdir.name
|
| 27 |
+
os.chdir(temp_path)
|
| 28 |
+
|
| 29 |
+
for dir1 in "abcd":
|
| 30 |
+
for dir1x in [dir1, dir1 + 'x']:
|
| 31 |
+
for dir2 in "xyz":
|
| 32 |
+
dir_path = pjoin(dir1x, dir2)
|
| 33 |
+
os.makedirs(dir_path)
|
| 34 |
+
with writable_file(dir_path, "file2_pyx.pyx") as f:
|
| 35 |
+
f.write(u'""" PYX """')
|
| 36 |
+
with writable_file(dir_path, "file2_py.py") as f:
|
| 37 |
+
f.write(u'""" PY """')
|
| 38 |
+
|
| 39 |
+
with writable_file(dir1x, "file1_pyx.pyx") as f:
|
| 40 |
+
f.write(u'""" PYX """')
|
| 41 |
+
with writable_file(dir1x, "file1_py.py") as f:
|
| 42 |
+
f.write(u'""" PY """')
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def tearDownClass(cls):
|
| 46 |
+
os.chdir(cls._orig_dir)
|
| 47 |
+
if sys.version_info[0] < 3:
|
| 48 |
+
import shutil
|
| 49 |
+
shutil.rmtree(cls._tmpdir)
|
| 50 |
+
else:
|
| 51 |
+
cls._tmpdir.cleanup()
|
| 52 |
+
|
| 53 |
+
def files_equal(self, pattern, expected_files):
|
| 54 |
+
expected_files = sorted(expected_files)
|
| 55 |
+
# It's the users's choice whether '/' will appear on Windows.
|
| 56 |
+
matched_files = sorted(path.replace('/', os.sep) for path in extended_iglob(pattern))
|
| 57 |
+
self.assertListEqual(matched_files, expected_files) # /
|
| 58 |
+
|
| 59 |
+
# Special case for Windows: also support '\' in patterns.
|
| 60 |
+
if os.sep == '\\' and '/' in pattern:
|
| 61 |
+
matched_files = sorted(extended_iglob(pattern.replace('/', '\\')))
|
| 62 |
+
self.assertListEqual(matched_files, expected_files) # \
|
| 63 |
+
|
| 64 |
+
def test_extended_iglob_simple(self):
|
| 65 |
+
ax_files = [pjoin("a", "x", "file2_pyx.pyx"), pjoin("a", "x", "file2_py.py")]
|
| 66 |
+
self.files_equal("a/x/*", ax_files)
|
| 67 |
+
self.files_equal("a/x/*.c12", [])
|
| 68 |
+
self.files_equal("a/x/*.{py,pyx,c12}", ax_files)
|
| 69 |
+
self.files_equal("a/x/*.{py,pyx}", ax_files)
|
| 70 |
+
self.files_equal("a/x/*.{pyx}", ax_files[:1])
|
| 71 |
+
self.files_equal("a/x/*.pyx", ax_files[:1])
|
| 72 |
+
self.files_equal("a/x/*.{py}", ax_files[1:])
|
| 73 |
+
self.files_equal("a/x/*.py", ax_files[1:])
|
| 74 |
+
|
| 75 |
+
def test_extended_iglob_simple_star(self):
|
| 76 |
+
for basedir in "ad":
|
| 77 |
+
files = [
|
| 78 |
+
pjoin(basedir, dirname, filename)
|
| 79 |
+
for dirname in "xyz"
|
| 80 |
+
for filename in ["file2_pyx.pyx", "file2_py.py"]
|
| 81 |
+
]
|
| 82 |
+
self.files_equal(basedir + "/*/*", files)
|
| 83 |
+
self.files_equal(basedir + "/*/*.c12", [])
|
| 84 |
+
self.files_equal(basedir + "/*/*.{py,pyx,c12}", files)
|
| 85 |
+
self.files_equal(basedir + "/*/*.{py,pyx}", files)
|
| 86 |
+
self.files_equal(basedir + "/*/*.{pyx}", files[::2])
|
| 87 |
+
self.files_equal(basedir + "/*/*.pyx", files[::2])
|
| 88 |
+
self.files_equal(basedir + "/*/*.{py}", files[1::2])
|
| 89 |
+
self.files_equal(basedir + "/*/*.py", files[1::2])
|
| 90 |
+
|
| 91 |
+
for subdir in "xy*":
|
| 92 |
+
files = [
|
| 93 |
+
pjoin(basedir, dirname, filename)
|
| 94 |
+
for dirname in "xyz"
|
| 95 |
+
if subdir in ('*', dirname)
|
| 96 |
+
for filename in ["file2_pyx.pyx", "file2_py.py"]
|
| 97 |
+
]
|
| 98 |
+
path = basedir + '/' + subdir + '/'
|
| 99 |
+
self.files_equal(path + "*", files)
|
| 100 |
+
self.files_equal(path + "*.{py,pyx}", files)
|
| 101 |
+
self.files_equal(path + "*.{pyx}", files[::2])
|
| 102 |
+
self.files_equal(path + "*.pyx", files[::2])
|
| 103 |
+
self.files_equal(path + "*.{py}", files[1::2])
|
| 104 |
+
self.files_equal(path + "*.py", files[1::2])
|
| 105 |
+
|
| 106 |
+
def test_extended_iglob_double_star(self):
|
| 107 |
+
basedirs = os.listdir(".")
|
| 108 |
+
files = [
|
| 109 |
+
pjoin(basedir, dirname, filename)
|
| 110 |
+
for basedir in basedirs
|
| 111 |
+
for dirname in "xyz"
|
| 112 |
+
for filename in ["file2_pyx.pyx", "file2_py.py"]
|
| 113 |
+
]
|
| 114 |
+
all_files = [
|
| 115 |
+
pjoin(basedir, filename)
|
| 116 |
+
for basedir in basedirs
|
| 117 |
+
for filename in ["file1_pyx.pyx", "file1_py.py"]
|
| 118 |
+
] + files
|
| 119 |
+
self.files_equal("*/*/*", files)
|
| 120 |
+
self.files_equal("*/*/**/*", files)
|
| 121 |
+
self.files_equal("*/**/*.*", all_files)
|
| 122 |
+
self.files_equal("**/*.*", all_files)
|
| 123 |
+
self.files_equal("*/**/*.c12", [])
|
| 124 |
+
self.files_equal("**/*.c12", [])
|
| 125 |
+
self.files_equal("*/*/*.{py,pyx,c12}", files)
|
| 126 |
+
self.files_equal("*/*/**/*.{py,pyx,c12}", files)
|
| 127 |
+
self.files_equal("*/**/*/*.{py,pyx,c12}", files)
|
| 128 |
+
self.files_equal("**/*/*/*.{py,pyx,c12}", files)
|
| 129 |
+
self.files_equal("**/*.{py,pyx,c12}", all_files)
|
| 130 |
+
self.files_equal("*/*/*.{py,pyx}", files)
|
| 131 |
+
self.files_equal("**/*/*/*.{py,pyx}", files)
|
| 132 |
+
self.files_equal("*/**/*/*.{py,pyx}", files)
|
| 133 |
+
self.files_equal("**/*.{py,pyx}", all_files)
|
| 134 |
+
self.files_equal("*/*/*.{pyx}", files[::2])
|
| 135 |
+
self.files_equal("**/*.{pyx}", all_files[::2])
|
| 136 |
+
self.files_equal("*/**/*/*.pyx", files[::2])
|
| 137 |
+
self.files_equal("*/*/*.pyx", files[::2])
|
| 138 |
+
self.files_equal("**/*.pyx", all_files[::2])
|
| 139 |
+
self.files_equal("*/*/*.{py}", files[1::2])
|
| 140 |
+
self.files_equal("**/*.{py}", all_files[1::2])
|
| 141 |
+
self.files_equal("*/*/*.py", files[1::2])
|
| 142 |
+
self.files_equal("**/*.py", all_files[1::2])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestIpythonMagic.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# tag: ipython
|
| 3 |
+
|
| 4 |
+
"""Tests for the Cython magics extension."""
|
| 5 |
+
|
| 6 |
+
from __future__ import absolute_import
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import io
|
| 10 |
+
import sys
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
from unittest import skipIf
|
| 13 |
+
|
| 14 |
+
from Cython.Build import IpythonMagic
|
| 15 |
+
from Cython.TestUtils import CythonTest
|
| 16 |
+
from Cython.Compiler.Annotate import AnnotationCCodeWriter
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import IPython.testing.globalipapp
|
| 20 |
+
except ImportError:
|
| 21 |
+
# Disable tests and fake helpers for initialisation below.
|
| 22 |
+
def skip_if_not_installed(_):
|
| 23 |
+
return None
|
| 24 |
+
else:
|
| 25 |
+
def skip_if_not_installed(c):
|
| 26 |
+
return c
|
| 27 |
+
|
| 28 |
+
# not using IPython's decorators here because they depend on "nose"
|
| 29 |
+
skip_win32 = skipIf(sys.platform == 'win32', "Skip on Windows")
|
| 30 |
+
skip_py27 = skipIf(sys.version_info[:2] == (2,7), "Disabled in Py2.7")
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
# disable IPython history thread before it gets started to avoid having to clean it up
|
| 34 |
+
from IPython.core.history import HistoryManager
|
| 35 |
+
HistoryManager.enabled = False
|
| 36 |
+
except ImportError:
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@contextmanager
|
| 41 |
+
def capture_output():
|
| 42 |
+
backup = sys.stdout, sys.stderr
|
| 43 |
+
try:
|
| 44 |
+
replacement = [
|
| 45 |
+
io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding),
|
| 46 |
+
io.TextIOWrapper(io.BytesIO(), encoding=sys.stderr.encoding),
|
| 47 |
+
]
|
| 48 |
+
sys.stdout, sys.stderr = replacement
|
| 49 |
+
output = []
|
| 50 |
+
yield output
|
| 51 |
+
finally:
|
| 52 |
+
sys.stdout, sys.stderr = backup
|
| 53 |
+
for wrapper in replacement:
|
| 54 |
+
wrapper.seek(0) # rewind
|
| 55 |
+
output.append(wrapper.read())
|
| 56 |
+
wrapper.close()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
code = u"""\
|
| 60 |
+
def f(x):
|
| 61 |
+
return 2*x
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
cython3_code = u"""\
|
| 65 |
+
def f(int x):
|
| 66 |
+
return 2 / x
|
| 67 |
+
|
| 68 |
+
def call(x):
|
| 69 |
+
return f(*(x,))
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
pgo_cython3_code = cython3_code + u"""\
|
| 73 |
+
def main():
|
| 74 |
+
for _ in range(100): call(5)
|
| 75 |
+
main()
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
compile_error_code = u'''\
|
| 79 |
+
cdef extern from *:
|
| 80 |
+
"""
|
| 81 |
+
xxx a=1;
|
| 82 |
+
"""
|
| 83 |
+
int a;
|
| 84 |
+
def doit():
|
| 85 |
+
return a
|
| 86 |
+
'''
|
| 87 |
+
|
| 88 |
+
compile_warning_code = u'''\
|
| 89 |
+
cdef extern from *:
|
| 90 |
+
"""
|
| 91 |
+
#pragma message ( "CWarning" )
|
| 92 |
+
int a = 42;
|
| 93 |
+
"""
|
| 94 |
+
int a;
|
| 95 |
+
def doit():
|
| 96 |
+
return a
|
| 97 |
+
'''
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@skip_if_not_installed
|
| 101 |
+
class TestIPythonMagic(CythonTest):
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def setUpClass(cls):
|
| 105 |
+
CythonTest.setUpClass()
|
| 106 |
+
cls._ip = IPython.testing.globalipapp.get_ipython()
|
| 107 |
+
|
| 108 |
+
def setUp(self):
|
| 109 |
+
CythonTest.setUp(self)
|
| 110 |
+
self._ip.extension_manager.load_extension('cython')
|
| 111 |
+
|
| 112 |
+
def test_cython_inline(self):
|
| 113 |
+
ip = self._ip
|
| 114 |
+
ip.ex('a=10; b=20')
|
| 115 |
+
result = ip.run_cell_magic('cython_inline', '', 'return a+b')
|
| 116 |
+
self.assertEqual(result, 30)
|
| 117 |
+
|
| 118 |
+
@skip_win32
|
| 119 |
+
def test_cython_pyximport(self):
|
| 120 |
+
ip = self._ip
|
| 121 |
+
module_name = '_test_cython_pyximport'
|
| 122 |
+
ip.run_cell_magic('cython_pyximport', module_name, code)
|
| 123 |
+
ip.ex('g = f(10)')
|
| 124 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 125 |
+
ip.run_cell_magic('cython_pyximport', module_name, code)
|
| 126 |
+
ip.ex('h = f(-10)')
|
| 127 |
+
self.assertEqual(ip.user_ns['h'], -20.0)
|
| 128 |
+
try:
|
| 129 |
+
os.remove(module_name + '.pyx')
|
| 130 |
+
except OSError:
|
| 131 |
+
pass
|
| 132 |
+
|
| 133 |
+
def test_cython(self):
|
| 134 |
+
ip = self._ip
|
| 135 |
+
ip.run_cell_magic('cython', '', code)
|
| 136 |
+
ip.ex('g = f(10)')
|
| 137 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 138 |
+
|
| 139 |
+
def test_cython_name(self):
|
| 140 |
+
# The Cython module named 'mymodule' defines the function f.
|
| 141 |
+
ip = self._ip
|
| 142 |
+
ip.run_cell_magic('cython', '--name=mymodule', code)
|
| 143 |
+
# This module can now be imported in the interactive namespace.
|
| 144 |
+
ip.ex('import mymodule; g = mymodule.f(10)')
|
| 145 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 146 |
+
|
| 147 |
+
def test_cython_language_level(self):
|
| 148 |
+
# The Cython cell defines the functions f() and call().
|
| 149 |
+
ip = self._ip
|
| 150 |
+
ip.run_cell_magic('cython', '', cython3_code)
|
| 151 |
+
ip.ex('g = f(10); h = call(10)')
|
| 152 |
+
if sys.version_info[0] < 3:
|
| 153 |
+
self.assertEqual(ip.user_ns['g'], 2 // 10)
|
| 154 |
+
self.assertEqual(ip.user_ns['h'], 2 // 10)
|
| 155 |
+
else:
|
| 156 |
+
self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
|
| 157 |
+
self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)
|
| 158 |
+
|
| 159 |
+
def test_cython3(self):
|
| 160 |
+
# The Cython cell defines the functions f() and call().
|
| 161 |
+
ip = self._ip
|
| 162 |
+
ip.run_cell_magic('cython', '-3', cython3_code)
|
| 163 |
+
ip.ex('g = f(10); h = call(10)')
|
| 164 |
+
self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
|
| 165 |
+
self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)
|
| 166 |
+
|
| 167 |
+
def test_cython2(self):
|
| 168 |
+
# The Cython cell defines the functions f() and call().
|
| 169 |
+
ip = self._ip
|
| 170 |
+
ip.run_cell_magic('cython', '-2', cython3_code)
|
| 171 |
+
ip.ex('g = f(10); h = call(10)')
|
| 172 |
+
self.assertEqual(ip.user_ns['g'], 2 // 10)
|
| 173 |
+
self.assertEqual(ip.user_ns['h'], 2 // 10)
|
| 174 |
+
|
| 175 |
+
def test_cython_compile_error_shown(self):
|
| 176 |
+
ip = self._ip
|
| 177 |
+
with capture_output() as out:
|
| 178 |
+
ip.run_cell_magic('cython', '-3', compile_error_code)
|
| 179 |
+
captured_out, captured_err = out
|
| 180 |
+
|
| 181 |
+
# it could be that c-level output is captured by distutil-extension
|
| 182 |
+
# (and not by us) and is printed to stdout:
|
| 183 |
+
captured_all = captured_out + "\n" + captured_err
|
| 184 |
+
self.assertTrue("error" in captured_all, msg="error in " + captured_all)
|
| 185 |
+
|
| 186 |
+
def test_cython_link_error_shown(self):
|
| 187 |
+
ip = self._ip
|
| 188 |
+
with capture_output() as out:
|
| 189 |
+
ip.run_cell_magic('cython', '-3 -l=xxxxxxxx', code)
|
| 190 |
+
captured_out, captured_err = out
|
| 191 |
+
|
| 192 |
+
# it could be that c-level output is captured by distutil-extension
|
| 193 |
+
# (and not by us) and is printed to stdout:
|
| 194 |
+
captured_all = captured_out + "\n!" + captured_err
|
| 195 |
+
self.assertTrue("error" in captured_all, msg="error in " + captured_all)
|
| 196 |
+
|
| 197 |
+
def test_cython_warning_shown(self):
|
| 198 |
+
ip = self._ip
|
| 199 |
+
with capture_output() as out:
|
| 200 |
+
# force rebuild, otherwise no warning as after the first success
|
| 201 |
+
# no build step is performed
|
| 202 |
+
ip.run_cell_magic('cython', '-3 -f', compile_warning_code)
|
| 203 |
+
captured_out, captured_err = out
|
| 204 |
+
|
| 205 |
+
# check that warning was printed to stdout even if build hasn't failed
|
| 206 |
+
self.assertTrue("CWarning" in captured_out)
|
| 207 |
+
|
| 208 |
+
@skip_py27 # Not strictly broken in Py2.7 but currently fails in CI due to C compiler issues.
|
| 209 |
+
@skip_win32
|
| 210 |
+
def test_cython3_pgo(self):
|
| 211 |
+
# The Cython cell defines the functions f() and call().
|
| 212 |
+
ip = self._ip
|
| 213 |
+
ip.run_cell_magic('cython', '-3 --pgo', pgo_cython3_code)
|
| 214 |
+
ip.ex('g = f(10); h = call(10); main()')
|
| 215 |
+
self.assertEqual(ip.user_ns['g'], 2.0 / 10.0)
|
| 216 |
+
self.assertEqual(ip.user_ns['h'], 2.0 / 10.0)
|
| 217 |
+
|
| 218 |
+
@skip_win32
|
| 219 |
+
def test_extlibs(self):
|
| 220 |
+
ip = self._ip
|
| 221 |
+
code = u"""
|
| 222 |
+
from libc.math cimport sin
|
| 223 |
+
x = sin(0.0)
|
| 224 |
+
"""
|
| 225 |
+
ip.user_ns['x'] = 1
|
| 226 |
+
ip.run_cell_magic('cython', '-l m', code)
|
| 227 |
+
self.assertEqual(ip.user_ns['x'], 0)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def test_cython_verbose(self):
|
| 231 |
+
ip = self._ip
|
| 232 |
+
ip.run_cell_magic('cython', '--verbose', code)
|
| 233 |
+
ip.ex('g = f(10)')
|
| 234 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 235 |
+
|
| 236 |
+
def test_cython_verbose_thresholds(self):
|
| 237 |
+
@contextmanager
|
| 238 |
+
def mock_distutils():
|
| 239 |
+
class MockLog:
|
| 240 |
+
DEBUG = 1
|
| 241 |
+
INFO = 2
|
| 242 |
+
thresholds = [INFO]
|
| 243 |
+
|
| 244 |
+
def set_threshold(self, val):
|
| 245 |
+
self.thresholds.append(val)
|
| 246 |
+
return self.thresholds[-2]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
new_log = MockLog()
|
| 250 |
+
old_log = IpythonMagic.distutils.log
|
| 251 |
+
try:
|
| 252 |
+
IpythonMagic.distutils.log = new_log
|
| 253 |
+
yield new_log
|
| 254 |
+
finally:
|
| 255 |
+
IpythonMagic.distutils.log = old_log
|
| 256 |
+
|
| 257 |
+
ip = self._ip
|
| 258 |
+
with mock_distutils() as verbose_log:
|
| 259 |
+
ip.run_cell_magic('cython', '--verbose', code)
|
| 260 |
+
ip.ex('g = f(10)')
|
| 261 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 262 |
+
self.assertEqual([verbose_log.INFO, verbose_log.DEBUG, verbose_log.INFO],
|
| 263 |
+
verbose_log.thresholds)
|
| 264 |
+
|
| 265 |
+
with mock_distutils() as normal_log:
|
| 266 |
+
ip.run_cell_magic('cython', '', code)
|
| 267 |
+
ip.ex('g = f(10)')
|
| 268 |
+
self.assertEqual(ip.user_ns['g'], 20.0)
|
| 269 |
+
self.assertEqual([normal_log.INFO], normal_log.thresholds)
|
| 270 |
+
|
| 271 |
+
def test_cython_no_annotate(self):
|
| 272 |
+
ip = self._ip
|
| 273 |
+
html = ip.run_cell_magic('cython', '', code)
|
| 274 |
+
self.assertTrue(html is None)
|
| 275 |
+
|
| 276 |
+
def test_cython_annotate(self):
|
| 277 |
+
ip = self._ip
|
| 278 |
+
html = ip.run_cell_magic('cython', '--annotate', code)
|
| 279 |
+
# somewhat brittle way to differentiate between annotated htmls
|
| 280 |
+
# with/without complete source code:
|
| 281 |
+
self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data)
|
| 282 |
+
|
| 283 |
+
def test_cython_annotate_default(self):
|
| 284 |
+
ip = self._ip
|
| 285 |
+
html = ip.run_cell_magic('cython', '-a', code)
|
| 286 |
+
# somewhat brittle way to differentiate between annotated htmls
|
| 287 |
+
# with/without complete source code:
|
| 288 |
+
self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE not in html.data)
|
| 289 |
+
|
| 290 |
+
def test_cython_annotate_complete_c_code(self):
|
| 291 |
+
ip = self._ip
|
| 292 |
+
html = ip.run_cell_magic('cython', '--annotate-fullc', code)
|
| 293 |
+
# somewhat brittle way to differentiate between annotated htmls
|
| 294 |
+
# with/without complete source code:
|
| 295 |
+
self.assertTrue(AnnotationCCodeWriter.COMPLETE_CODE_TITLE in html.data)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/TestStripLiterals.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Cython.Build.Dependencies import strip_string_literals
|
| 2 |
+
|
| 3 |
+
from Cython.TestUtils import CythonTest
|
| 4 |
+
|
| 5 |
+
class TestStripLiterals(CythonTest):
|
| 6 |
+
|
| 7 |
+
def t(self, before, expected):
|
| 8 |
+
actual, literals = strip_string_literals(before, prefix="_L")
|
| 9 |
+
self.assertEqual(expected, actual)
|
| 10 |
+
for key, value in literals.items():
|
| 11 |
+
actual = actual.replace(key, value)
|
| 12 |
+
self.assertEqual(before, actual)
|
| 13 |
+
|
| 14 |
+
def test_empty(self):
|
| 15 |
+
self.t("", "")
|
| 16 |
+
|
| 17 |
+
def test_single_quote(self):
|
| 18 |
+
self.t("'x'", "'_L1_'")
|
| 19 |
+
|
| 20 |
+
def test_double_quote(self):
|
| 21 |
+
self.t('"x"', '"_L1_"')
|
| 22 |
+
|
| 23 |
+
def test_nested_quotes(self):
|
| 24 |
+
self.t(""" '"' "'" """, """ '_L1_' "_L2_" """)
|
| 25 |
+
|
| 26 |
+
def test_triple_quote(self):
|
| 27 |
+
self.t(" '''a\n''' ", " '''_L1_''' ")
|
| 28 |
+
|
| 29 |
+
def test_backslash(self):
|
| 30 |
+
self.t(r"'a\'b'", "'_L1_'")
|
| 31 |
+
self.t(r"'a\\'", "'_L1_'")
|
| 32 |
+
self.t(r"'a\\\'b'", "'_L1_'")
|
| 33 |
+
|
| 34 |
+
def test_unicode(self):
|
| 35 |
+
self.t("u'abc'", "u'_L1_'")
|
| 36 |
+
|
| 37 |
+
def test_raw(self):
|
| 38 |
+
self.t(r"r'abc\\'", "r'_L1_'")
|
| 39 |
+
|
| 40 |
+
def test_raw_unicode(self):
|
| 41 |
+
self.t(r"ru'abc\\'", "ru'_L1_'")
|
| 42 |
+
|
| 43 |
+
def test_comment(self):
|
| 44 |
+
self.t("abc # foo", "abc #_L1_")
|
| 45 |
+
|
| 46 |
+
def test_comment_and_quote(self):
|
| 47 |
+
self.t("abc # 'x'", "abc #_L1_")
|
| 48 |
+
self.t("'abc#'", "'_L1_'")
|
| 49 |
+
|
| 50 |
+
def test_include(self):
|
| 51 |
+
self.t("include 'a.pxi' # something here",
|
| 52 |
+
"include '_L1_' #_L2_")
|
| 53 |
+
|
| 54 |
+
def test_extern(self):
|
| 55 |
+
self.t("cdef extern from 'a.h': # comment",
|
| 56 |
+
"cdef extern from '_L1_': #_L2_")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestCythonizeArgsParser.cpython-311.pyc
ADDED
|
Binary file (41.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestDependencies.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestInline.cpython-311.pyc
ADDED
|
Binary file (7.92 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestIpythonMagic.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestRecythonize.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Build/Tests/__pycache__/TestStripLiterals.cpython-311.pyc
ADDED
|
Binary file (4.57 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Buffer.py
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
from .Visitor import CythonTransform
|
| 4 |
+
from .ModuleNode import ModuleNode
|
| 5 |
+
from .Errors import CompileError
|
| 6 |
+
from .UtilityCode import CythonUtilityCode
|
| 7 |
+
from .Code import UtilityCode, TempitaUtilityCode
|
| 8 |
+
|
| 9 |
+
from . import Options
|
| 10 |
+
from . import Interpreter
|
| 11 |
+
from . import PyrexTypes
|
| 12 |
+
from . import Naming
|
| 13 |
+
from . import Symtab
|
| 14 |
+
|
| 15 |
+
def dedent(text, reindent=0):
|
| 16 |
+
from textwrap import dedent
|
| 17 |
+
text = dedent(text)
|
| 18 |
+
if reindent > 0:
|
| 19 |
+
indent = " " * reindent
|
| 20 |
+
text = '\n'.join([indent + x for x in text.split('\n')])
|
| 21 |
+
return text
|
| 22 |
+
|
| 23 |
+
class IntroduceBufferAuxiliaryVars(CythonTransform):
|
| 24 |
+
|
| 25 |
+
#
|
| 26 |
+
# Entry point
|
| 27 |
+
#
|
| 28 |
+
|
| 29 |
+
buffers_exists = False
|
| 30 |
+
using_memoryview = False
|
| 31 |
+
|
| 32 |
+
def __call__(self, node):
|
| 33 |
+
assert isinstance(node, ModuleNode)
|
| 34 |
+
self.max_ndim = 0
|
| 35 |
+
result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
|
| 36 |
+
if self.buffers_exists:
|
| 37 |
+
use_bufstruct_declare_code(node.scope)
|
| 38 |
+
use_py2_buffer_functions(node.scope)
|
| 39 |
+
|
| 40 |
+
return result
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
#
|
| 44 |
+
# Basic operations for transforms
|
| 45 |
+
#
|
| 46 |
+
def handle_scope(self, node, scope):
|
| 47 |
+
# For all buffers, insert extra variables in the scope.
|
| 48 |
+
# The variables are also accessible from the buffer_info
|
| 49 |
+
# on the buffer entry
|
| 50 |
+
scope_items = scope.entries.items()
|
| 51 |
+
bufvars = [entry for name, entry in scope_items if entry.type.is_buffer]
|
| 52 |
+
if len(bufvars) > 0:
|
| 53 |
+
bufvars.sort(key=lambda entry: entry.name)
|
| 54 |
+
self.buffers_exists = True
|
| 55 |
+
|
| 56 |
+
memviewslicevars = [entry for name, entry in scope_items if entry.type.is_memoryviewslice]
|
| 57 |
+
if len(memviewslicevars) > 0:
|
| 58 |
+
self.buffers_exists = True
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
for (name, entry) in scope_items:
|
| 62 |
+
if name == 'memoryview' and isinstance(entry.utility_code_definition, CythonUtilityCode):
|
| 63 |
+
self.using_memoryview = True
|
| 64 |
+
break
|
| 65 |
+
del scope_items
|
| 66 |
+
|
| 67 |
+
if isinstance(node, ModuleNode) and len(bufvars) > 0:
|
| 68 |
+
# for now...note that pos is wrong
|
| 69 |
+
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
|
| 70 |
+
for entry in bufvars:
|
| 71 |
+
if entry.type.dtype.is_ptr:
|
| 72 |
+
raise CompileError(node.pos, "Buffers with pointer types not yet supported.")
|
| 73 |
+
|
| 74 |
+
name = entry.name
|
| 75 |
+
buftype = entry.type
|
| 76 |
+
if buftype.ndim > Options.buffer_max_dims:
|
| 77 |
+
raise CompileError(node.pos,
|
| 78 |
+
"Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims)
|
| 79 |
+
if buftype.ndim > self.max_ndim:
|
| 80 |
+
self.max_ndim = buftype.ndim
|
| 81 |
+
|
| 82 |
+
# Declare auxiliary vars
|
| 83 |
+
def decvar(type, prefix):
|
| 84 |
+
cname = scope.mangle(prefix, name)
|
| 85 |
+
aux_var = scope.declare_var(name=None, cname=cname,
|
| 86 |
+
type=type, pos=node.pos)
|
| 87 |
+
if entry.is_arg:
|
| 88 |
+
aux_var.used = True # otherwise, NameNode will mark whether it is used
|
| 89 |
+
|
| 90 |
+
return aux_var
|
| 91 |
+
|
| 92 |
+
auxvars = ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix),
|
| 93 |
+
(PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix))
|
| 94 |
+
pybuffernd, rcbuffer = [decvar(type, prefix) for (type, prefix) in auxvars]
|
| 95 |
+
|
| 96 |
+
entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer)
|
| 97 |
+
|
| 98 |
+
scope.buffer_entries = bufvars
|
| 99 |
+
self.scope = scope
|
| 100 |
+
|
| 101 |
+
def visit_ModuleNode(self, node):
|
| 102 |
+
self.handle_scope(node, node.scope)
|
| 103 |
+
self.visitchildren(node)
|
| 104 |
+
return node
|
| 105 |
+
|
| 106 |
+
def visit_FuncDefNode(self, node):
|
| 107 |
+
self.handle_scope(node, node.local_scope)
|
| 108 |
+
self.visitchildren(node)
|
| 109 |
+
return node
|
| 110 |
+
|
| 111 |
+
#
|
| 112 |
+
# Analysis
|
| 113 |
+
#
|
| 114 |
+
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
|
| 115 |
+
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
|
| 116 |
+
buffer_positional_options_count = 1 # anything beyond this needs keyword argument
|
| 117 |
+
|
| 118 |
+
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
|
| 119 |
+
ERR_BUF_TOO_MANY = 'Too many buffer options'
|
| 120 |
+
ERR_BUF_DUP = '"%s" buffer option already supplied'
|
| 121 |
+
ERR_BUF_MISSING = '"%s" missing'
|
| 122 |
+
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
|
| 123 |
+
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
|
| 124 |
+
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
|
| 125 |
+
ERR_BUF_BOOL = '"%s" must be a boolean'
|
| 126 |
+
|
| 127 |
+
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
|
| 128 |
+
"""
|
| 129 |
+
Must be called during type analysis, as analyse is called
|
| 130 |
+
on the dtype argument.
|
| 131 |
+
|
| 132 |
+
posargs and dictargs should consist of a list and a dict
|
| 133 |
+
of tuples (value, pos). Defaults should be a dict of values.
|
| 134 |
+
|
| 135 |
+
Returns a dict containing all the options a buffer can have and
|
| 136 |
+
its value (with the positions stripped).
|
| 137 |
+
"""
|
| 138 |
+
if defaults is None:
|
| 139 |
+
defaults = buffer_defaults
|
| 140 |
+
|
| 141 |
+
posargs, dictargs = Interpreter.interpret_compiletime_options(
|
| 142 |
+
posargs, dictargs, type_env=env, type_args=(0, 'dtype'))
|
| 143 |
+
|
| 144 |
+
if len(posargs) > buffer_positional_options_count:
|
| 145 |
+
raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)
|
| 146 |
+
|
| 147 |
+
options = {}
|
| 148 |
+
for name, (value, pos) in dictargs.items():
|
| 149 |
+
if name not in buffer_options:
|
| 150 |
+
raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
|
| 151 |
+
options[name] = value
|
| 152 |
+
|
| 153 |
+
for name, (value, pos) in zip(buffer_options, posargs):
|
| 154 |
+
if name not in buffer_options:
|
| 155 |
+
raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
|
| 156 |
+
if name in options:
|
| 157 |
+
raise CompileError(pos, ERR_BUF_DUP % name)
|
| 158 |
+
options[name] = value
|
| 159 |
+
|
| 160 |
+
# Check that they are all there and copy defaults
|
| 161 |
+
for name in buffer_options:
|
| 162 |
+
if name not in options:
|
| 163 |
+
try:
|
| 164 |
+
options[name] = defaults[name]
|
| 165 |
+
except KeyError:
|
| 166 |
+
if need_complete:
|
| 167 |
+
raise CompileError(globalpos, ERR_BUF_MISSING % name)
|
| 168 |
+
|
| 169 |
+
dtype = options.get("dtype")
|
| 170 |
+
if dtype and dtype.is_extension_type:
|
| 171 |
+
raise CompileError(globalpos, ERR_BUF_DTYPE)
|
| 172 |
+
|
| 173 |
+
ndim = options.get("ndim")
|
| 174 |
+
if ndim and (not isinstance(ndim, int) or ndim < 0):
|
| 175 |
+
raise CompileError(globalpos, ERR_BUF_NDIM)
|
| 176 |
+
|
| 177 |
+
mode = options.get("mode")
|
| 178 |
+
if mode and not (mode in ('full', 'strided', 'c', 'fortran')):
|
| 179 |
+
raise CompileError(globalpos, ERR_BUF_MODE)
|
| 180 |
+
|
| 181 |
+
def assert_bool(name):
|
| 182 |
+
x = options.get(name)
|
| 183 |
+
if not isinstance(x, bool):
|
| 184 |
+
raise CompileError(globalpos, ERR_BUF_BOOL % name)
|
| 185 |
+
|
| 186 |
+
assert_bool('negative_indices')
|
| 187 |
+
assert_bool('cast')
|
| 188 |
+
|
| 189 |
+
return options
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
#
|
| 193 |
+
# Code generation
|
| 194 |
+
#
|
| 195 |
+
|
| 196 |
+
class BufferEntry(object):
|
| 197 |
+
def __init__(self, entry):
|
| 198 |
+
self.entry = entry
|
| 199 |
+
self.type = entry.type
|
| 200 |
+
self.cname = entry.buffer_aux.buflocal_nd_var.cname
|
| 201 |
+
self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
|
| 202 |
+
self.buf_ptr_type = entry.type.buffer_ptr_type
|
| 203 |
+
self.init_attributes()
|
| 204 |
+
|
| 205 |
+
def init_attributes(self):
|
| 206 |
+
self.shape = self.get_buf_shapevars()
|
| 207 |
+
self.strides = self.get_buf_stridevars()
|
| 208 |
+
self.suboffsets = self.get_buf_suboffsetvars()
|
| 209 |
+
|
| 210 |
+
def get_buf_suboffsetvars(self):
|
| 211 |
+
return self._for_all_ndim("%s.diminfo[%d].suboffsets")
|
| 212 |
+
|
| 213 |
+
def get_buf_stridevars(self):
|
| 214 |
+
return self._for_all_ndim("%s.diminfo[%d].strides")
|
| 215 |
+
|
| 216 |
+
def get_buf_shapevars(self):
|
| 217 |
+
return self._for_all_ndim("%s.diminfo[%d].shape")
|
| 218 |
+
|
| 219 |
+
def _for_all_ndim(self, s):
|
| 220 |
+
return [s % (self.cname, i) for i in range(self.type.ndim)]
|
| 221 |
+
|
| 222 |
+
def generate_buffer_lookup_code(self, code, index_cnames):
|
| 223 |
+
# Create buffer lookup and return it
|
| 224 |
+
# This is done via utility macros/inline functions, which vary
|
| 225 |
+
# according to the access mode used.
|
| 226 |
+
params = []
|
| 227 |
+
nd = self.type.ndim
|
| 228 |
+
mode = self.type.mode
|
| 229 |
+
if mode == 'full':
|
| 230 |
+
for i, s, o in zip(index_cnames,
|
| 231 |
+
self.get_buf_stridevars(),
|
| 232 |
+
self.get_buf_suboffsetvars()):
|
| 233 |
+
params.append(i)
|
| 234 |
+
params.append(s)
|
| 235 |
+
params.append(o)
|
| 236 |
+
funcname = "__Pyx_BufPtrFull%dd" % nd
|
| 237 |
+
funcgen = buf_lookup_full_code
|
| 238 |
+
else:
|
| 239 |
+
if mode == 'strided':
|
| 240 |
+
funcname = "__Pyx_BufPtrStrided%dd" % nd
|
| 241 |
+
funcgen = buf_lookup_strided_code
|
| 242 |
+
elif mode == 'c':
|
| 243 |
+
funcname = "__Pyx_BufPtrCContig%dd" % nd
|
| 244 |
+
funcgen = buf_lookup_c_code
|
| 245 |
+
elif mode == 'fortran':
|
| 246 |
+
funcname = "__Pyx_BufPtrFortranContig%dd" % nd
|
| 247 |
+
funcgen = buf_lookup_fortran_code
|
| 248 |
+
else:
|
| 249 |
+
assert False
|
| 250 |
+
for i, s in zip(index_cnames, self.get_buf_stridevars()):
|
| 251 |
+
params.append(i)
|
| 252 |
+
params.append(s)
|
| 253 |
+
|
| 254 |
+
# Make sure the utility code is available
|
| 255 |
+
if funcname not in code.globalstate.utility_codes:
|
| 256 |
+
code.globalstate.utility_codes.add(funcname)
|
| 257 |
+
protocode = code.globalstate['utility_code_proto']
|
| 258 |
+
defcode = code.globalstate['utility_code_def']
|
| 259 |
+
funcgen(protocode, defcode, name=funcname, nd=nd)
|
| 260 |
+
|
| 261 |
+
buf_ptr_type_code = self.buf_ptr_type.empty_declaration_code()
|
| 262 |
+
ptrcode = "%s(%s, %s, %s)" % (funcname, buf_ptr_type_code, self.buf_ptr,
|
| 263 |
+
", ".join(params))
|
| 264 |
+
return ptrcode
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def get_flags(buffer_aux, buffer_type):
|
| 268 |
+
flags = 'PyBUF_FORMAT'
|
| 269 |
+
mode = buffer_type.mode
|
| 270 |
+
if mode == 'full':
|
| 271 |
+
flags += '| PyBUF_INDIRECT'
|
| 272 |
+
elif mode == 'strided':
|
| 273 |
+
flags += '| PyBUF_STRIDES'
|
| 274 |
+
elif mode == 'c':
|
| 275 |
+
flags += '| PyBUF_C_CONTIGUOUS'
|
| 276 |
+
elif mode == 'fortran':
|
| 277 |
+
flags += '| PyBUF_F_CONTIGUOUS'
|
| 278 |
+
else:
|
| 279 |
+
assert False
|
| 280 |
+
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
|
| 281 |
+
return flags
|
| 282 |
+
|
| 283 |
+
def used_buffer_aux_vars(entry):
|
| 284 |
+
buffer_aux = entry.buffer_aux
|
| 285 |
+
buffer_aux.buflocal_nd_var.used = True
|
| 286 |
+
buffer_aux.rcbuf_var.used = True
|
| 287 |
+
|
| 288 |
+
def put_unpack_buffer_aux_into_scope(buf_entry, code):
|
| 289 |
+
# Generate code to copy the needed struct info into local
|
| 290 |
+
# variables.
|
| 291 |
+
buffer_aux, mode = buf_entry.buffer_aux, buf_entry.type.mode
|
| 292 |
+
pybuffernd_struct = buffer_aux.buflocal_nd_var.cname
|
| 293 |
+
|
| 294 |
+
fldnames = ['strides', 'shape']
|
| 295 |
+
if mode == 'full':
|
| 296 |
+
fldnames.append('suboffsets')
|
| 297 |
+
|
| 298 |
+
ln = []
|
| 299 |
+
for i in range(buf_entry.type.ndim):
|
| 300 |
+
for fldname in fldnames:
|
| 301 |
+
ln.append("%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];" % (
|
| 302 |
+
pybuffernd_struct, i, fldname,
|
| 303 |
+
pybuffernd_struct, fldname, i,
|
| 304 |
+
))
|
| 305 |
+
code.putln(' '.join(ln))
|
| 306 |
+
|
| 307 |
+
def put_init_vars(entry, code):
|
| 308 |
+
bufaux = entry.buffer_aux
|
| 309 |
+
pybuffernd_struct = bufaux.buflocal_nd_var.cname
|
| 310 |
+
pybuffer_struct = bufaux.rcbuf_var.cname
|
| 311 |
+
# init pybuffer_struct
|
| 312 |
+
code.putln("%s.pybuffer.buf = NULL;" % pybuffer_struct)
|
| 313 |
+
code.putln("%s.refcount = 0;" % pybuffer_struct)
|
| 314 |
+
# init the buffer object
|
| 315 |
+
# code.put_init_var_to_py_none(entry)
|
| 316 |
+
# init the pybuffernd_struct
|
| 317 |
+
code.putln("%s.data = NULL;" % pybuffernd_struct)
|
| 318 |
+
code.putln("%s.rcbuffer = &%s;" % (pybuffernd_struct, pybuffer_struct))
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def put_acquire_arg_buffer(entry, code, pos):
|
| 322 |
+
buffer_aux = entry.buffer_aux
|
| 323 |
+
getbuffer = get_getbuffer_call(code, entry.cname, buffer_aux, entry.type)
|
| 324 |
+
|
| 325 |
+
# Acquire any new buffer
|
| 326 |
+
code.putln("{")
|
| 327 |
+
code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
|
| 328 |
+
code.putln(code.error_goto_if("%s == -1" % getbuffer, pos))
|
| 329 |
+
code.putln("}")
|
| 330 |
+
# An exception raised in arg parsing cannot be caught, so no
|
| 331 |
+
# need to care about the buffer then.
|
| 332 |
+
put_unpack_buffer_aux_into_scope(entry, code)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def put_release_buffer_code(code, entry):
|
| 336 |
+
code.globalstate.use_utility_code(acquire_utility_code)
|
| 337 |
+
code.putln("__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);" % entry.buffer_aux.buflocal_nd_var.cname)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
|
| 341 |
+
ndim = buffer_type.ndim
|
| 342 |
+
cast = int(buffer_type.cast)
|
| 343 |
+
flags = get_flags(buffer_aux, buffer_type)
|
| 344 |
+
pybuffernd_struct = buffer_aux.buflocal_nd_var.cname
|
| 345 |
+
|
| 346 |
+
dtype_typeinfo = get_type_information_cname(code, buffer_type.dtype)
|
| 347 |
+
|
| 348 |
+
code.globalstate.use_utility_code(acquire_utility_code)
|
| 349 |
+
return ("__Pyx_GetBufferAndValidate(&%(pybuffernd_struct)s.rcbuffer->pybuffer, "
|
| 350 |
+
"(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
|
| 351 |
+
"%(cast)d, __pyx_stack)" % locals())
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def put_assign_to_buffer(lhs_cname, rhs_cname, buf_entry,
|
| 355 |
+
is_initialized, pos, code):
|
| 356 |
+
"""
|
| 357 |
+
Generate code for reassigning a buffer variables. This only deals with getting
|
| 358 |
+
the buffer auxiliary structure and variables set up correctly, the assignment
|
| 359 |
+
itself and refcounting is the responsibility of the caller.
|
| 360 |
+
|
| 361 |
+
However, the assignment operation may throw an exception so that the reassignment
|
| 362 |
+
never happens.
|
| 363 |
+
|
| 364 |
+
Depending on the circumstances there are two possible outcomes:
|
| 365 |
+
- Old buffer released, new acquired, rhs assigned to lhs
|
| 366 |
+
- Old buffer released, new acquired which fails, reaqcuire old lhs buffer
|
| 367 |
+
(which may or may not succeed).
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type
|
| 371 |
+
pybuffernd_struct = buffer_aux.buflocal_nd_var.cname
|
| 372 |
+
flags = get_flags(buffer_aux, buffer_type)
|
| 373 |
+
|
| 374 |
+
code.putln("{") # Set up necessary stack for getbuffer
|
| 375 |
+
code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())
|
| 376 |
+
|
| 377 |
+
getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type) # fill in object below
|
| 378 |
+
|
| 379 |
+
if is_initialized:
|
| 380 |
+
# Release any existing buffer
|
| 381 |
+
code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct)
|
| 382 |
+
# Acquire
|
| 383 |
+
retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
|
| 384 |
+
code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
|
| 385 |
+
code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname)))
|
| 386 |
+
# If acquisition failed, attempt to reacquire the old buffer
|
| 387 |
+
# before raising the exception. A failure of reacquisition
|
| 388 |
+
# will cause the reacquisition exception to be reported, one
|
| 389 |
+
# can consider working around this later.
|
| 390 |
+
exc_temps = tuple(code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False)
|
| 391 |
+
for _ in range(3))
|
| 392 |
+
code.putln('PyErr_Fetch(&%s, &%s, &%s);' % exc_temps)
|
| 393 |
+
code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
|
| 394 |
+
code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % exc_temps) # Do not refnanny these!
|
| 395 |
+
code.globalstate.use_utility_code(raise_buffer_fallback_code)
|
| 396 |
+
code.putln('__Pyx_RaiseBufferFallbackError();')
|
| 397 |
+
code.putln('} else {')
|
| 398 |
+
code.putln('PyErr_Restore(%s, %s, %s);' % exc_temps)
|
| 399 |
+
code.putln('}')
|
| 400 |
+
code.putln('%s = %s = %s = 0;' % exc_temps)
|
| 401 |
+
for t in exc_temps:
|
| 402 |
+
code.funcstate.release_temp(t)
|
| 403 |
+
code.putln('}')
|
| 404 |
+
# Unpack indices
|
| 405 |
+
put_unpack_buffer_aux_into_scope(buf_entry, code)
|
| 406 |
+
code.putln(code.error_goto_if_neg(retcode_cname, pos))
|
| 407 |
+
code.funcstate.release_temp(retcode_cname)
|
| 408 |
+
else:
|
| 409 |
+
# Our entry had no previous value, so set to None when acquisition fails.
|
| 410 |
+
# In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
|
| 411 |
+
# so it suffices to set the buf field to NULL.
|
| 412 |
+
code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
|
| 413 |
+
code.putln('%s = %s; __Pyx_INCREF(Py_None); %s.rcbuffer->pybuffer.buf = NULL;' %
|
| 414 |
+
(lhs_cname,
|
| 415 |
+
PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
|
| 416 |
+
pybuffernd_struct))
|
| 417 |
+
code.putln(code.error_goto(pos))
|
| 418 |
+
code.put('} else {')
|
| 419 |
+
# Unpack indices
|
| 420 |
+
put_unpack_buffer_aux_into_scope(buf_entry, code)
|
| 421 |
+
code.putln('}')
|
| 422 |
+
|
| 423 |
+
code.putln("}") # Release stack
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
|
| 427 |
+
pos, code, negative_indices, in_nogil_context):
|
| 428 |
+
"""
|
| 429 |
+
Generates code to process indices and calculate an offset into
|
| 430 |
+
a buffer. Returns a C string which gives a pointer which can be
|
| 431 |
+
read from or written to at will (it is an expression so caller should
|
| 432 |
+
store it in a temporary if it is used more than once).
|
| 433 |
+
|
| 434 |
+
As the bounds checking can have any number of combinations of unsigned
|
| 435 |
+
arguments, smart optimizations etc. we insert it directly in the function
|
| 436 |
+
body. The lookup however is delegated to a inline function that is instantiated
|
| 437 |
+
once per ndim (lookup with suboffsets tend to get quite complicated).
|
| 438 |
+
|
| 439 |
+
entry is a BufferEntry
|
| 440 |
+
"""
|
| 441 |
+
negative_indices = directives['wraparound'] and negative_indices
|
| 442 |
+
|
| 443 |
+
if directives['boundscheck']:
|
| 444 |
+
# Check bounds and fix negative indices.
|
| 445 |
+
# We allocate a temporary which is initialized to -1, meaning OK (!).
|
| 446 |
+
# If an error occurs, the temp is set to the index dimension the
|
| 447 |
+
# error is occurring at.
|
| 448 |
+
failed_dim_temp = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
|
| 449 |
+
code.putln("%s = -1;" % failed_dim_temp)
|
| 450 |
+
for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames, entry.get_buf_shapevars())):
|
| 451 |
+
if signed != 0:
|
| 452 |
+
# not unsigned, deal with negative index
|
| 453 |
+
code.putln("if (%s < 0) {" % cname)
|
| 454 |
+
if negative_indices:
|
| 455 |
+
code.putln("%s += %s;" % (cname, shape))
|
| 456 |
+
code.putln("if (%s) %s = %d;" % (
|
| 457 |
+
code.unlikely("%s < 0" % cname),
|
| 458 |
+
failed_dim_temp, dim))
|
| 459 |
+
else:
|
| 460 |
+
code.putln("%s = %d;" % (failed_dim_temp, dim))
|
| 461 |
+
code.put("} else ")
|
| 462 |
+
# check bounds in positive direction
|
| 463 |
+
if signed != 0:
|
| 464 |
+
cast = ""
|
| 465 |
+
else:
|
| 466 |
+
cast = "(size_t)"
|
| 467 |
+
code.putln("if (%s) %s = %d;" % (
|
| 468 |
+
code.unlikely("%s >= %s%s" % (cname, cast, shape)),
|
| 469 |
+
failed_dim_temp, dim))
|
| 470 |
+
|
| 471 |
+
if in_nogil_context:
|
| 472 |
+
code.globalstate.use_utility_code(raise_indexerror_nogil)
|
| 473 |
+
func = '__Pyx_RaiseBufferIndexErrorNogil'
|
| 474 |
+
else:
|
| 475 |
+
code.globalstate.use_utility_code(raise_indexerror_code)
|
| 476 |
+
func = '__Pyx_RaiseBufferIndexError'
|
| 477 |
+
|
| 478 |
+
code.putln("if (%s) {" % code.unlikely("%s != -1" % failed_dim_temp))
|
| 479 |
+
code.putln('%s(%s);' % (func, failed_dim_temp))
|
| 480 |
+
code.putln(code.error_goto(pos))
|
| 481 |
+
code.putln('}')
|
| 482 |
+
code.funcstate.release_temp(failed_dim_temp)
|
| 483 |
+
elif negative_indices:
|
| 484 |
+
# Only fix negative indices.
|
| 485 |
+
for signed, cname, shape in zip(index_signeds, index_cnames, entry.get_buf_shapevars()):
|
| 486 |
+
if signed != 0:
|
| 487 |
+
code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape))
|
| 488 |
+
|
| 489 |
+
return entry.generate_buffer_lookup_code(code, index_cnames)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def use_bufstruct_declare_code(env):
|
| 493 |
+
env.use_utility_code(buffer_struct_declare_code)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def buf_lookup_full_code(proto, defin, name, nd):
|
| 497 |
+
"""
|
| 498 |
+
Generates a buffer lookup function for the right number
|
| 499 |
+
of dimensions. The function gives back a void* at the right location.
|
| 500 |
+
"""
|
| 501 |
+
# _i_ndex, _s_tride, sub_o_ffset
|
| 502 |
+
macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
|
| 503 |
+
proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))
|
| 504 |
+
|
| 505 |
+
funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
|
| 506 |
+
proto.putln("static CYTHON_INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
|
| 507 |
+
defin.putln(dedent("""
|
| 508 |
+
static CYTHON_INLINE void* %s_imp(void* buf, %s) {
|
| 509 |
+
char* ptr = (char*)buf;
|
| 510 |
+
""") % (name, funcargs) + "".join([dedent("""\
|
| 511 |
+
ptr += s%d * i%d;
|
| 512 |
+
if (o%d >= 0) ptr = *((char**)ptr) + o%d;
|
| 513 |
+
""") % (i, i, i, i) for i in range(nd)]
|
| 514 |
+
) + "\nreturn ptr;\n}")
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def buf_lookup_strided_code(proto, defin, name, nd):
|
| 518 |
+
"""
|
| 519 |
+
Generates a buffer lookup function for the right number
|
| 520 |
+
of dimensions. The function gives back a void* at the right location.
|
| 521 |
+
"""
|
| 522 |
+
# _i_ndex, _s_tride
|
| 523 |
+
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
|
| 524 |
+
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
|
| 525 |
+
proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (name, args, offset))
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def buf_lookup_c_code(proto, defin, name, nd):
|
| 529 |
+
"""
|
| 530 |
+
Similar to strided lookup, but can assume that the last dimension
|
| 531 |
+
doesn't need a multiplication as long as.
|
| 532 |
+
Still we keep the same signature for now.
|
| 533 |
+
"""
|
| 534 |
+
if nd == 1:
|
| 535 |
+
proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
|
| 536 |
+
else:
|
| 537 |
+
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
|
| 538 |
+
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd - 1)])
|
| 539 |
+
proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, nd - 1))
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def buf_lookup_fortran_code(proto, defin, name, nd):
|
| 543 |
+
"""
|
| 544 |
+
Like C lookup, but the first index is optimized instead.
|
| 545 |
+
"""
|
| 546 |
+
if nd == 1:
|
| 547 |
+
proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
|
| 548 |
+
else:
|
| 549 |
+
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
|
| 550 |
+
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(1, nd)])
|
| 551 |
+
proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, 0))
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def use_py2_buffer_functions(env):
|
| 555 |
+
env.use_utility_code(GetAndReleaseBufferUtilityCode())
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class GetAndReleaseBufferUtilityCode(object):
|
| 559 |
+
# Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
|
| 560 |
+
# For >= 2.6 we do double mode -- use the new buffer interface on objects
|
| 561 |
+
# which has the right tp_flags set, but emulation otherwise.
|
| 562 |
+
|
| 563 |
+
requires = None
|
| 564 |
+
is_cython_utility = False
|
| 565 |
+
|
| 566 |
+
def __init__(self):
|
| 567 |
+
pass
|
| 568 |
+
|
| 569 |
+
def __eq__(self, other):
|
| 570 |
+
return isinstance(other, GetAndReleaseBufferUtilityCode)
|
| 571 |
+
|
| 572 |
+
def __hash__(self):
|
| 573 |
+
return 24342342
|
| 574 |
+
|
| 575 |
+
def get_tree(self, **kwargs): pass
|
| 576 |
+
|
| 577 |
+
def put_code(self, output):
|
| 578 |
+
code = output['utility_code_def']
|
| 579 |
+
proto_code = output['utility_code_proto']
|
| 580 |
+
env = output.module_node.scope
|
| 581 |
+
cython_scope = env.context.cython_scope
|
| 582 |
+
|
| 583 |
+
# Search all types for __getbuffer__ overloads
|
| 584 |
+
types = []
|
| 585 |
+
visited_scopes = set()
|
| 586 |
+
def find_buffer_types(scope):
|
| 587 |
+
if scope in visited_scopes:
|
| 588 |
+
return
|
| 589 |
+
visited_scopes.add(scope)
|
| 590 |
+
for m in scope.cimported_modules:
|
| 591 |
+
find_buffer_types(m)
|
| 592 |
+
for e in scope.type_entries:
|
| 593 |
+
if isinstance(e.utility_code_definition, CythonUtilityCode):
|
| 594 |
+
continue
|
| 595 |
+
t = e.type
|
| 596 |
+
if t.is_extension_type:
|
| 597 |
+
if scope is cython_scope and not e.used:
|
| 598 |
+
continue
|
| 599 |
+
release = get = None
|
| 600 |
+
for x in t.scope.pyfunc_entries:
|
| 601 |
+
if x.name == u"__getbuffer__": get = x.func_cname
|
| 602 |
+
elif x.name == u"__releasebuffer__": release = x.func_cname
|
| 603 |
+
if get:
|
| 604 |
+
types.append((t.typeptr_cname, get, release))
|
| 605 |
+
|
| 606 |
+
find_buffer_types(env)
|
| 607 |
+
|
| 608 |
+
util_code = TempitaUtilityCode.load(
|
| 609 |
+
"GetAndReleaseBuffer", from_file="Buffer.c",
|
| 610 |
+
context=dict(types=types))
|
| 611 |
+
|
| 612 |
+
proto = util_code.format_code(util_code.proto)
|
| 613 |
+
impl = util_code.format_code(
|
| 614 |
+
util_code.inject_string_constants(util_code.impl, output)[1])
|
| 615 |
+
|
| 616 |
+
proto_code.putln(proto)
|
| 617 |
+
code.putln(impl)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def mangle_dtype_name(dtype):
|
| 621 |
+
# Use prefixes to separate user defined types from builtins
|
| 622 |
+
# (consider "typedef float unsigned_int")
|
| 623 |
+
if dtype.is_pyobject:
|
| 624 |
+
return "object"
|
| 625 |
+
elif dtype.is_ptr:
|
| 626 |
+
return "ptr"
|
| 627 |
+
else:
|
| 628 |
+
if dtype.is_typedef or dtype.is_struct_or_union:
|
| 629 |
+
prefix = "nn_"
|
| 630 |
+
else:
|
| 631 |
+
prefix = ""
|
| 632 |
+
return prefix + dtype.specialization_name()
|
| 633 |
+
|
| 634 |
+
def get_type_information_cname(code, dtype, maxdepth=None):
|
| 635 |
+
"""
|
| 636 |
+
Output the run-time type information (__Pyx_TypeInfo) for given dtype,
|
| 637 |
+
and return the name of the type info struct.
|
| 638 |
+
|
| 639 |
+
Structs with two floats of the same size are encoded as complex numbers.
|
| 640 |
+
One can separate between complex numbers declared as struct or with native
|
| 641 |
+
encoding by inspecting to see if the fields field of the type is
|
| 642 |
+
filled in.
|
| 643 |
+
"""
|
| 644 |
+
namesuffix = mangle_dtype_name(dtype)
|
| 645 |
+
name = "__Pyx_TypeInfo_%s" % namesuffix
|
| 646 |
+
structinfo_name = "__Pyx_StructFields_%s" % namesuffix
|
| 647 |
+
|
| 648 |
+
if dtype.is_error: return "<error>"
|
| 649 |
+
|
| 650 |
+
# It's critical that walking the type info doesn't use more stack
|
| 651 |
+
# depth than dtype.struct_nesting_depth() returns, so use an assertion for this
|
| 652 |
+
if maxdepth is None: maxdepth = dtype.struct_nesting_depth()
|
| 653 |
+
if maxdepth <= 0:
|
| 654 |
+
assert False
|
| 655 |
+
|
| 656 |
+
if name not in code.globalstate.utility_codes:
|
| 657 |
+
code.globalstate.utility_codes.add(name)
|
| 658 |
+
typecode = code.globalstate['typeinfo']
|
| 659 |
+
|
| 660 |
+
arraysizes = []
|
| 661 |
+
if dtype.is_array:
|
| 662 |
+
while dtype.is_array:
|
| 663 |
+
arraysizes.append(dtype.size)
|
| 664 |
+
dtype = dtype.base_type
|
| 665 |
+
|
| 666 |
+
complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
|
| 667 |
+
|
| 668 |
+
declcode = dtype.empty_declaration_code()
|
| 669 |
+
if dtype.is_simple_buffer_dtype():
|
| 670 |
+
structinfo_name = "NULL"
|
| 671 |
+
elif dtype.is_struct:
|
| 672 |
+
struct_scope = dtype.scope
|
| 673 |
+
if dtype.is_cv_qualified:
|
| 674 |
+
struct_scope = struct_scope.base_type_scope
|
| 675 |
+
# Must pre-call all used types in order not to recurse during utility code writing.
|
| 676 |
+
fields = struct_scope.var_entries
|
| 677 |
+
assert len(fields) > 0
|
| 678 |
+
types = [get_type_information_cname(code, f.type, maxdepth - 1)
|
| 679 |
+
for f in fields]
|
| 680 |
+
typecode.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
|
| 681 |
+
|
| 682 |
+
if dtype.is_cv_qualified:
|
| 683 |
+
# roughly speaking, remove "const" from struct_type
|
| 684 |
+
struct_type = dtype.cv_base_type.empty_declaration_code()
|
| 685 |
+
else:
|
| 686 |
+
struct_type = dtype.empty_declaration_code()
|
| 687 |
+
|
| 688 |
+
for f, typeinfo in zip(fields, types):
|
| 689 |
+
typecode.putln(' {&%s, "%s", offsetof(%s, %s)},' %
|
| 690 |
+
(typeinfo, f.name, struct_type, f.cname), safe=True)
|
| 691 |
+
|
| 692 |
+
typecode.putln(' {NULL, NULL, 0}', safe=True)
|
| 693 |
+
typecode.putln("};", safe=True)
|
| 694 |
+
else:
|
| 695 |
+
assert False
|
| 696 |
+
|
| 697 |
+
rep = str(dtype)
|
| 698 |
+
|
| 699 |
+
flags = "0"
|
| 700 |
+
is_unsigned = "0"
|
| 701 |
+
if dtype is PyrexTypes.c_char_type:
|
| 702 |
+
is_unsigned = "__PYX_IS_UNSIGNED(%s)" % declcode
|
| 703 |
+
typegroup = "'H'"
|
| 704 |
+
elif dtype.is_int:
|
| 705 |
+
is_unsigned = "__PYX_IS_UNSIGNED(%s)" % declcode
|
| 706 |
+
typegroup = "%s ? 'U' : 'I'" % is_unsigned
|
| 707 |
+
elif complex_possible or dtype.is_complex:
|
| 708 |
+
typegroup = "'C'"
|
| 709 |
+
elif dtype.is_float:
|
| 710 |
+
typegroup = "'R'"
|
| 711 |
+
elif dtype.is_struct:
|
| 712 |
+
typegroup = "'S'"
|
| 713 |
+
if dtype.packed:
|
| 714 |
+
flags = "__PYX_BUF_FLAGS_PACKED_STRUCT"
|
| 715 |
+
elif dtype.is_pyobject:
|
| 716 |
+
typegroup = "'O'"
|
| 717 |
+
else:
|
| 718 |
+
assert False, dtype
|
| 719 |
+
|
| 720 |
+
typeinfo = ('static __Pyx_TypeInfo %s = '
|
| 721 |
+
'{ "%s", %s, sizeof(%s), { %s }, %s, %s, %s, %s };')
|
| 722 |
+
tup = (name, rep, structinfo_name, declcode,
|
| 723 |
+
', '.join([str(x) for x in arraysizes]) or '0', len(arraysizes),
|
| 724 |
+
typegroup, is_unsigned, flags)
|
| 725 |
+
typecode.putln(typeinfo % tup, safe=True)
|
| 726 |
+
|
| 727 |
+
return name
|
| 728 |
+
|
| 729 |
+
def load_buffer_utility(util_code_name, context=None, **kwargs):
|
| 730 |
+
if context is None:
|
| 731 |
+
return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)
|
| 732 |
+
else:
|
| 733 |
+
return TempitaUtilityCode.load(util_code_name, "Buffer.c", context=context, **kwargs)
|
| 734 |
+
|
| 735 |
+
context = dict(max_dims=Options.buffer_max_dims)
|
| 736 |
+
buffer_struct_declare_code = load_buffer_utility("BufferStructDeclare", context=context)
|
| 737 |
+
buffer_formats_declare_code = load_buffer_utility("BufferFormatStructs")
|
| 738 |
+
|
| 739 |
+
# Utility function to set the right exception
|
| 740 |
+
# The caller should immediately goto_error
|
| 741 |
+
raise_indexerror_code = load_buffer_utility("BufferIndexError")
|
| 742 |
+
raise_indexerror_nogil = load_buffer_utility("BufferIndexErrorNogil")
|
| 743 |
+
raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError")
|
| 744 |
+
|
| 745 |
+
acquire_utility_code = load_buffer_utility("BufferGetAndValidate", context=context)
|
| 746 |
+
buffer_format_check_code = load_buffer_utility("BufferFormatCheck", context=context)
|
| 747 |
+
|
| 748 |
+
# See utility code BufferFormatFromTypeInfo
|
| 749 |
+
_typeinfo_to_format_code = load_buffer_utility("TypeInfoToFormat")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Code.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FusedNode.cpython-311-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02498cb7e330a1a7ccae2b142938c3c3c01d80751d06f9ade63bac39f2ab681a
|
| 3 |
+
size 517064
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Parsing.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# empty file
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Lexicon.cpython-311.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Naming.cpython-311.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Options.cpython-311.pyc
ADDED
|
Binary file (24.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Pythran.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/Scanning.cpython-311.pyc
ADDED
|
Binary file (29.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/UtilityCode.cpython-311.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/networkx/generators/atlas.dat.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73fc416df0164923607751cb759f4ae81deb5f6550bf25be59c86de3b747e41d
|
| 3 |
+
size 8887
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Wrappers to call pyproject.toml-based build backend hooks.
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
from ._impl import (
|
| 5 |
+
BackendInvalid,
|
| 6 |
+
BackendUnavailable,
|
| 7 |
+
BuildBackendHookCaller,
|
| 8 |
+
HookMissing,
|
| 9 |
+
UnsupportedOperation,
|
| 10 |
+
default_subprocess_runner,
|
| 11 |
+
quiet_subprocess_runner,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__version__ = '1.0.0'
|
| 15 |
+
__all__ = [
|
| 16 |
+
'BackendUnavailable',
|
| 17 |
+
'BackendInvalid',
|
| 18 |
+
'HookMissing',
|
| 19 |
+
'UnsupportedOperation',
|
| 20 |
+
'default_subprocess_runner',
|
| 21 |
+
'quiet_subprocess_runner',
|
| 22 |
+
'BuildBackendHookCaller',
|
| 23 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/__pycache__/_compat.cpython-311.pyc
ADDED
|
Binary file (425 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This is a subpackage because the directory is on sys.path for _in_process.py
|
| 2 |
+
|
| 3 |
+
The subpackage should stay as empty as possible to avoid shadowing modules that
|
| 4 |
+
the backend might import.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import importlib.resources as resources
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
resources.files
|
| 11 |
+
except AttributeError:
|
| 12 |
+
# Python 3.8 compatibility
|
| 13 |
+
def _in_proc_script_path():
|
| 14 |
+
return resources.path(__package__, '_in_process.py')
|
| 15 |
+
else:
|
| 16 |
+
def _in_proc_script_path():
|
| 17 |
+
return resources.as_file(
|
| 18 |
+
resources.files(__package__).joinpath('_in_process.py'))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc
ADDED
|
Binary file (7.73 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc
ADDED
|
Binary file (27 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc
ADDED
|
Binary file (46.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc
ADDED
|
Binary file (41.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc
ADDED
|
Binary file (22.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc
ADDED
|
Binary file (2.18 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc
ADDED
|
Binary file (1.65 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .lazy_ir import (
|
| 2 |
+
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
|
| 3 |
+
GenLazyIR as GenLazyIR,
|
| 4 |
+
GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
|
| 5 |
+
GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
|
| 6 |
+
)
|
| 7 |
+
from .native_functions import (
|
| 8 |
+
compute_native_function_declaration as compute_native_function_declaration,
|
| 9 |
+
)
|
| 10 |
+
from .register_dispatch_key import (
|
| 11 |
+
gen_registration_headers as gen_registration_headers,
|
| 12 |
+
gen_registration_helpers as gen_registration_helpers,
|
| 13 |
+
RegisterDispatchKey as RegisterDispatchKey,
|
| 14 |
+
)
|
| 15 |
+
from .ufunc import (
|
| 16 |
+
compute_ufunc_cpu as compute_ufunc_cpu,
|
| 17 |
+
compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
|
| 18 |
+
compute_ufunc_cuda as compute_ufunc_cuda,
|
| 19 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc
ADDED
|
Binary file (3.54 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc
ADDED
|
Binary file (28.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torchgen.api.dispatcher as dispatcher
|
| 7 |
+
from torchgen.api.lazy import (
|
| 8 |
+
getValueT,
|
| 9 |
+
isValueType,
|
| 10 |
+
LazyArgument,
|
| 11 |
+
LazyIrProperties,
|
| 12 |
+
LazyIrSchema,
|
| 13 |
+
tensorListValueT,
|
| 14 |
+
)
|
| 15 |
+
from torchgen.api.translate import translate
|
| 16 |
+
from torchgen.api.types import (
|
| 17 |
+
BaseCType,
|
| 18 |
+
Binding,
|
| 19 |
+
deviceT,
|
| 20 |
+
DispatcherSignature,
|
| 21 |
+
kernel_signature,
|
| 22 |
+
NativeSignature,
|
| 23 |
+
OptionalCType,
|
| 24 |
+
VectorCType,
|
| 25 |
+
)
|
| 26 |
+
from torchgen.context import method_with_native_function
|
| 27 |
+
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
|
| 28 |
+
from torchgen.model import (
|
| 29 |
+
Argument,
|
| 30 |
+
BackendIndex,
|
| 31 |
+
BackendMetadata,
|
| 32 |
+
BaseTy,
|
| 33 |
+
BaseType,
|
| 34 |
+
FunctionSchema,
|
| 35 |
+
ListType,
|
| 36 |
+
NativeFunction,
|
| 37 |
+
NativeFunctionsGroup,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
|
| 42 |
+
"""
|
| 43 |
+
Given a LazyArgument,
|
| 44 |
+
generate a c++ string for materializing an rvalue of that arg for passing into
|
| 45 |
+
a lazy Node constructor.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
# TODO: Matching on CType seems wrong; should be matching on Type
|
| 49 |
+
if isValueType(arg.lazy_type):
|
| 50 |
+
if isinstance(arg.lazy_type, BaseCType):
|
| 51 |
+
if arg.is_wrapped_scalar:
|
| 52 |
+
return f"node_{arg.name}"
|
| 53 |
+
elif arg.lazy_type.type is tensorListValueT:
|
| 54 |
+
return f"lazy_{arg.name}_tensorlist"
|
| 55 |
+
elif arg.is_symint_or_list:
|
| 56 |
+
return f"GetSymIntValue({arg.name})"
|
| 57 |
+
return f"lazy_{arg.name}->GetIrValue()"
|
| 58 |
+
elif isinstance(arg.lazy_type, OptionalCType):
|
| 59 |
+
if arg.is_symint_or_list:
|
| 60 |
+
# TODO: I don't understand when you should put lazy_ in the name
|
| 61 |
+
# or not
|
| 62 |
+
return f"{arg.name} ? c10::make_optional(GetSymIntValue(*{arg.name})) : c10::nullopt"
|
| 63 |
+
elif arg.is_wrapped_scalar:
|
| 64 |
+
return f"node_{arg.name}"
|
| 65 |
+
return (
|
| 66 |
+
f"lazy_{arg.name} ? "
|
| 67 |
+
f"c10::make_optional(lazy_{arg.name}->GetIrValue()) : "
|
| 68 |
+
"c10::nullopt"
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
raise AssertionError(
|
| 72 |
+
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
# NB: this is here because right now we aren't treating SymInt[] as a
|
| 76 |
+
# value type; when we do this needs to move above
|
| 77 |
+
# NB: we cannot test arg.lazy_type as we've already specified it is an
|
| 78 |
+
# int64_t and so we cannot distinguish between SymInt and int64_t
|
| 79 |
+
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
|
| 80 |
+
BaseTy.SymInt
|
| 81 |
+
):
|
| 82 |
+
if arg.symint:
|
| 83 |
+
return f"GetSymIntArrayRefValue({arg.name})"
|
| 84 |
+
else:
|
| 85 |
+
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
|
| 86 |
+
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
|
| 87 |
+
arg.lazy_type.elem, BaseCType
|
| 88 |
+
):
|
| 89 |
+
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
|
| 90 |
+
elif (
|
| 91 |
+
isinstance(arg.lazy_type, OptionalCType)
|
| 92 |
+
and isinstance(arg.lazy_type.elem, VectorCType)
|
| 93 |
+
and isinstance(arg.lazy_type.elem.elem, BaseCType)
|
| 94 |
+
):
|
| 95 |
+
return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
|
| 96 |
+
else:
|
| 97 |
+
return f"{arg.name}"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def node_ctor_inputs(schema: LazyIrSchema) -> str:
|
| 101 |
+
"""
|
| 102 |
+
Produce a formatted string with the arguments as passed into the constructor of a node class.
|
| 103 |
+
"""
|
| 104 |
+
node_ctor_values = [
|
| 105 |
+
node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
|
| 106 |
+
]
|
| 107 |
+
return ", ".join(node_ctor_values)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def gen_fallback_code(
|
| 111 |
+
schema: LazyIrSchema,
|
| 112 |
+
sig: Union[DispatcherSignature, NativeSignature],
|
| 113 |
+
overload_name: str,
|
| 114 |
+
) -> str:
|
| 115 |
+
"""
|
| 116 |
+
Generate code that falls back to eager conditioned on a predicate
|
| 117 |
+
"""
|
| 118 |
+
dispatcher_sig = DispatcherSignature.from_schema(schema.func)
|
| 119 |
+
exprs = translate(sig.arguments(), dispatcher_sig.arguments())
|
| 120 |
+
fallback_args = ",\n ".join([a.expr for a in exprs])
|
| 121 |
+
if len(overload_name):
|
| 122 |
+
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
|
| 123 |
+
else:
|
| 124 |
+
aten_op_str = f"ATEN_OP({schema.aten_name})"
|
| 125 |
+
return f"""
|
| 126 |
+
if (force_eager_fallback({aten_symbol(schema)})) {{
|
| 127 |
+
return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call(
|
| 128 |
+
{fallback_args}
|
| 129 |
+
);
|
| 130 |
+
}}
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def aten_symbol(schema: LazyIrSchema) -> str:
|
| 135 |
+
missing_interned_strings = {
|
| 136 |
+
"sigmoid_backward",
|
| 137 |
+
}
|
| 138 |
+
if schema.aten_name in missing_interned_strings:
|
| 139 |
+
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
|
| 140 |
+
|
| 141 |
+
if not schema.aten_name.startswith("at::"):
|
| 142 |
+
return f"at::aten::{schema.aten_name}"
|
| 143 |
+
else:
|
| 144 |
+
return schema.aten_name
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# converts all tensor-like arguments to meta tensors. Returns:
|
| 148 |
+
# (1) a string containing all of the logic that does the conversions.
|
| 149 |
+
# (2) a context, to be used by translate(), with all of the relevant bindings.
|
| 150 |
+
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
|
| 151 |
+
context: List[Binding] = []
|
| 152 |
+
unwrapped_tensor_args: List[str] = []
|
| 153 |
+
for arg in sig.arguments():
|
| 154 |
+
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
|
| 155 |
+
unwrapped_name = f"{arg.name}_meta"
|
| 156 |
+
unwrapped_tensor_args.append(
|
| 157 |
+
f"auto {unwrapped_name} = to_meta({arg.name});"
|
| 158 |
+
)
|
| 159 |
+
context.append(arg.with_name(unwrapped_name))
|
| 160 |
+
else:
|
| 161 |
+
context.append(arg)
|
| 162 |
+
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
|
| 163 |
+
return unwrap_tensor_args_str, context
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass(frozen=True)
|
| 167 |
+
class GenLazyIR(ABC):
|
| 168 |
+
backend_index: BackendIndex
|
| 169 |
+
backend_name: str
|
| 170 |
+
node_base: str
|
| 171 |
+
use_lazy_shape: bool
|
| 172 |
+
|
| 173 |
+
@method_with_native_function
|
| 174 |
+
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
| 175 |
+
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
|
| 176 |
+
metadata = self.backend_index.get_kernel(
|
| 177 |
+
f.functional if isinstance(f, NativeFunctionsGroup) else f
|
| 178 |
+
)
|
| 179 |
+
schema = LazyIrSchema(
|
| 180 |
+
func, symint=metadata is not None and metadata.supports_symint()
|
| 181 |
+
)
|
| 182 |
+
return self.gen(schema)
|
| 183 |
+
|
| 184 |
+
# there is no lowering functionality generated unless this IR base class is subclassed and
|
| 185 |
+
# implemented as a backend-specific node
|
| 186 |
+
def lowering_function(self, schema: LazyIrSchema) -> str:
|
| 187 |
+
return ""
|
| 188 |
+
|
| 189 |
+
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 190 |
+
return ""
|
| 191 |
+
|
| 192 |
+
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 193 |
+
return f"""bool CanBeReused({node_ctor_args}) const {{
|
| 194 |
+
return false;
|
| 195 |
+
}}"""
|
| 196 |
+
|
| 197 |
+
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
|
| 198 |
+
value_args = schema.filtered_args(values=True, scalars=False)
|
| 199 |
+
# backends can customize the way the node base class constructor is called,
|
| 200 |
+
# as long as all of its arguments can be generated from information available from the schema
|
| 201 |
+
base_ctor_value_args_list = []
|
| 202 |
+
for arg in value_args:
|
| 203 |
+
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
|
| 204 |
+
base_ctor_value_args_list.append(f"{arg.name}")
|
| 205 |
+
elif isinstance(arg.lazy_type, OptionalCType):
|
| 206 |
+
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
|
| 207 |
+
else:
|
| 208 |
+
raise AssertionError(
|
| 209 |
+
f"Unsupported type ({arg.lazy_type}) - add support if necessary"
|
| 210 |
+
)
|
| 211 |
+
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
|
| 212 |
+
|
| 213 |
+
scalar_args = schema.filtered_args(values=False, scalars=True)
|
| 214 |
+
|
| 215 |
+
# Shape construction.
|
| 216 |
+
# Conditionally build shape depending on specified shape property
|
| 217 |
+
if schema.properties.ShapePrecompute:
|
| 218 |
+
shape_ctor_arg = "std::move(shapes),"
|
| 219 |
+
elif schema.properties.ShapeCompute:
|
| 220 |
+
shape_args = [a.name for a in value_args]
|
| 221 |
+
shape_args.extend(a.name for a in scalar_args)
|
| 222 |
+
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
|
| 223 |
+
elif schema.properties.ShapeCache:
|
| 224 |
+
shape_args = [f"operand({i})" for i in range(len(value_args))]
|
| 225 |
+
shape_args.extend(a.name for a in scalar_args)
|
| 226 |
+
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
|
| 227 |
+
else:
|
| 228 |
+
shape_ctor_arg = ""
|
| 229 |
+
|
| 230 |
+
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
|
| 231 |
+
|
| 232 |
+
return f"""{self.node_base}(
|
| 233 |
+
{schema.node_name}::ClassOpKind(),
|
| 234 |
+
OpList{{{base_ctor_value_args}}},
|
| 235 |
+
{shape_ctor_arg}
|
| 236 |
+
/* num_outputs */ {len(schema.returns)},
|
| 237 |
+
torch::lazy::MHash({scalar_hashes}))"""
|
| 238 |
+
|
| 239 |
+
    def gen(self, schema: LazyIrSchema) -> List[str]:
        """Emit the C++ IR-node class declaration for one lazy op schema.

        Returns a single-element list containing the full class body text.
        The class wraps the op's value (IR operand) and scalar arguments,
        and provides Create/CanBeReused/Lower hooks via the subclass hooks
        create_function / can_be_reused_function / lowering_function.
        """
        # Op kind defaults to the ATen symbol unless the schema overrides it.
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)

        # ctor takes all args by const-ref; the "reuse" signature is captured
        # BEFORE the optional shapes parameter is appended below, so
        # Create/CanBeReused do not receive shapes.
        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n        ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if len(scalar_initializers):
            scalar_initializers = f",\n        {scalar_initializers}"
        # Member declarations; string_view members are stored as owning strings.
        scalar_decls = "\n  ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"c10::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )
        # Optional value operands get a has_{name} presence bit (1-bit bitfield).
        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n  ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n    ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )
        # ToString() lines for every scalar member; optionals print "null"
        # when absent, and generators print a fixed placeholder.
        members_to_string = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                value = f"{arg.name}.value()"
                if arg.is_generator:
                    value = '"torch.Generator()"'
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
      ss << ", {arg.name}=" << {value};
    }} else {{
      ss << ", {arg.name}=null";
    }}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n    ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args})
      : {self.node_base_ctor_call(schema)}{scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {self.create_function(schema, reuse_ctor_args)}

  {self.can_be_reused_function(schema, reuse_ctor_args)}

  {self.lowering_function(schema)}

  {scalar_decls}
  {has_optional_decls}
}};

""",
        ]
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    """GenLazyIR specialization for the TorchScript (TS) lazy backend.

    Overrides the three hook methods to emit TS-specific Lower/Create/
    CanBeReused member functions, gated on the schema's LazyIrProperties.
    """

    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Emit the Lower() override (decl-only, full body, or nothing)."""
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        if schema.properties.LowerDeclOnly:
            # Declaration only; the backend supplies the definition by hand.
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
            """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Emit the static Create() factory (decl-only, body, or nothing)."""
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>(data);
  }}"""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Emit CanBeReused(): member-wise comparison against a candidate node.

        Value operands compare via operand(i++) (optionals tolerate null via
        nullable_operand / kNullValue); scalar members compare by value, with
        optionals matching when both are empty or both hold equal values.
        """
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparison_str = " &&\n        ".join(value_comparison)

        return f"""{signature} {{
    size_t i = 0;
    return ({value_comparison_str});
  }}"""
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
    """Generates the C++ kernel definition that a lazy backend registers
    for each native function: it wraps inputs as lazy tensors, builds (or
    reuses) the IR node, and converts the result back to at::Tensor.

    The string-valued fields are backend customization points (function and
    namespace names spliced into the generated C++).
    """

    class_method_name: str            # C++ class that owns the generated kernels
    backend_index: BackendIndex
    tensor_class: str
    gen_forced_fallback_code: bool    # emit eager-fallback preamble if True
    backend_namespace: str
    get_tensorlist: str
    get_tensor_or_wrap_number: str
    try_get_tensor: str
    metrics_counter: str
    create_tensor: str
    create_from_first_tensor: bool    # XLA-style instance-method creation
    create_aten_from_ltc_tensor: str
    tuple_aten_from_ltc_tensors: str
    lazy_tensor_ptr: str
    get_device_fn: str

    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit one declaration per value argument, wrapping it as a lazy
        tensor / IR value appropriate to its lazy_type."""
        value_args = schema.filtered_args(values=True, scalars=False)
        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
        lazy_tensor_decls: List[str] = []
        for arg in value_args:
            if arg.is_wrapped_scalar:
                # Scalars promoted to tensors become IR values directly.
                if isinstance(arg.lazy_type, OptionalCType):
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = {arg.name} ?
                c10::make_optional(torch::lazy::LazyGraphExecutor::Get()->
                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
                c10::nullopt;"""
                    )
                else:
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
                    )
            elif arg.is_symint_or_list:
                continue  # values are extracted in isValueType
            elif isinstance(arg.lazy_type, BaseCType):
                if arg.lazy_type.type is tensorListValueT:
                    lazy_tensor_decls.append(
                        f"auto lazy_{arg.name}_tensorlist = "
                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                    )
                else:
                    lazy_tensor_decls.append(
                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                    )
            elif isinstance(arg.lazy_type, OptionalCType):
                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                # until we encounter a real world example.
                lazy_tensor_decls.append(
                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                )
            else:
                raise AssertionError(
                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                )
        return ("\n        ").join(lazy_tensor_decls)

    def force_eager_fallback(
        self,
        func: NativeFunction,
        schema: LazyIrSchema,
        metadata: BackendMetadata,
        sig: Union[DispatcherSignature, NativeSignature],
    ) -> str:
        """Optionally emit code that forces eager-mode fallback for this op."""
        if self.gen_forced_fallback_code:
            return gen_fallback_code(
                schema, sig, overload_name=func.func.name.overload_name
            )
        return ""

    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit the backend's per-op metrics counter statement."""
        return f"{self.metrics_counter};"

    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code computing the common device from tensor args and any
        optional Device scalar args; asserts one source exists."""
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        optional_device = OptionalCType(BaseCType(deviceT))
        optional_devices = [
            a.name for a in scalar_args if a.lazy_type == optional_device
        ]
        assert (
            len(value_types_names) > 0 or len(optional_devices) > 0
        ), "Expected at least one Value or Device type"
        get_device_str = (
            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
        )
        return f"""auto common_device = {get_device_str};
        TORCH_INTERNAL_ASSERT(common_device);
        """

    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code that computes `shapes` for the IR node: via a meta-device
        call when the op is structured or a view_copy op, otherwise via a
        hand-written compute_shape_* function (see ComputeShapeSignature)."""
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint() and metadata.supports_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, func, symint=metadata.supports_symint()
            )
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str

    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code that reuses a cached IR node or builds + caches a new one
        (shape inference only runs on the cache-miss path)."""
        node_ctor_input_str = node_ctor_inputs(schema)
        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
        if (!node) {{
            {self.shape_inference(func, schema)}
            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
            CacheNode(node);
        }}
        """

    def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
        """Return the C++ expression (sans call args) that creates a lazy tensor."""
        # xla uses an instance method for tensor creation, for the time being
        if self.create_from_first_tensor:
            # TODO(whc) remove this if XLA switches to using static method for creation
            assert (
                first_tensor_name is not None
            ), "Requires first tensor to create lazy tensor"
            return f"{first_tensor_name}.{self.create_tensor}"
        return f"{self.backend_namespace}::{self.create_tensor}"

    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code converting the IR node back to the ATen return value:
        single tensor, tuple of tensors, or (for in-place/out ops) the
        mutated input itself."""
        returns_length = len(schema.returns)
        value_args = schema.filtered_args(values=True, scalars=False)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

        if returns_length > 1:
            assert (
                len(value_types_names) > 0
            ), "Code below assumes there is at least one tensor arg"
            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
        }}
        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

        if schema.name.name.inplace or func.func.is_out_fn():
            assert returns_length == 1, (
                "We assumed there was no such case where an op is an in-place variant "
                f"and has tuple outputs, but got tuple of len {returns_length}."
            )
            # In-place/out ops mutate the first tensor arg and return it by ref.
            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

        bridge_str += """
        return result;"""
        return bridge_str

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> List[str]:
        """Assemble the complete kernel definition for one native function."""
        sig = kernel_signature(func, self.backend_index)
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
        return [
            f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {self.force_eager_fallback(func, schema, metadata, sig)}
        {self.metrics(func, schema)}
        {self.get_device(func, schema)}
        {self.lazy_tensor_decls(func, schema)}
        {self.build_ir_node(func, schema)}
        {self.return_aten_tensor(func, schema)}
    }}\n
    """
        ]
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class ComputeShapeSignature:
    """
    Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
    """

    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
        # The lazy schema determines which arguments flow into the shape call;
        # the dispatcher bindings determine the declared parameter list.
        schema = LazyIrSchema(f.func, symint=symint)
        bindings = dispatcher.arguments(f.func, symint=symint)
        self.__schema = schema
        self.__dispatch_args = ", ".join(binding.decl() for binding in bindings)
        self.__call_args = ", ".join(
            arg.name for arg in schema.filtered_args(generator=True)
        )
        self.__kernel_name = kernel_name

    def __decl_suffix(self) -> str:
        # Parameter list with full declarations, for the header.
        return f"{self.__kernel_name}({self.__dispatch_args})"

    def __call_suffix(self) -> str:
        # Bare argument names, for the call site.
        return f"{self.__kernel_name}({self.__call_args})"

    @property
    def shape_decl(self) -> str:
        """C++ declaration of the compute_shape_* function."""
        return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"

    @property
    def shape_call(self) -> str:
        """C++ call expression invoking the compute_shape_* function."""
        return f"torch::lazy::compute_shape_{self.__call_suffix()}"
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
    """Emits the forward declaration of a hand-written compute_shape_*
    function for ops that cannot use meta-tensor shape inference.

    Structured ops and view_copy ops are skipped: they derive shapes from
    meta kernels instead (see Note [Generated LTC Shape Functions]).
    """

    backend_index: BackendIndex
    tensor_class: str

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> List[str]:
        sig = kernel_signature(f, self.backend_index)
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # See Note [Generated LTC Shape Functions]
        is_view_copy_op = "view_copy" in f.tags
        is_structured = f.structured or f.structured_delegate is not None
        if is_structured or is_view_copy_op:
            # Shape comes from a meta kernel; no manual declaration needed.
            return []
        shape_sig = ComputeShapeSignature(
            metadata.kernel, f, symint=metadata.supports_symint()
        )
        # Fix: the original wrapped the single declaration in a pointless
        # "\n".join([...]) of a one-element list; return it directly.
        return [f"{shape_sig.shape_decl};"]
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def generate_non_native_lazy_ir_nodes(
    non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
    """Generate the non-native lazy IR node classes."""

    def _render(op: Dict[str, Any]) -> str:
        # Non-native IRs start from these default properties; each entry may
        # enable additional ones via its "properties" list.
        props = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for prop_name in op.get("properties", []):
            setattr(props, prop_name, True)

        # non-native is assumed to want symint bindings if you wrote symint
        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), props, symint=True)
        schema.opkind = op.get("opkind")
        return gen_lazy_ir.gen(schema)[0]

    return [_render(op) for op in non_native]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchgen.api.lazy import LazyArgument, LazyIrSchema
|
| 2 |
+
from torchgen.api.types import OptionalCType
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Emit the body of a TS-backend Lower() method for one lazy op.

    Builds C++ that collects positional `arguments` and keyword `kwarguments`
    for torch::lazy::LowerTSBuiltin, then checks the output arity.
    """
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.
    emplace_arguments = []

    def get_value(arg: LazyArgument) -> str:
        # Optional operands consume an operand slot only when present; the
        # generated node stores a has_{name} presence bit for them.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    for arg in schema.positional_args:
        if arg.is_lazy_value:
            emplace_arguments.append(get_value(arg))
            continue
        # Non-value positionals are passed as named scalar arguments.
        emplace_arguments.append(f'"{arg.name}", {arg.name}')

    emplace_arguments_str = "\n    ".join(
        [f"arguments.emplace_back({a});" for a in emplace_arguments]
    )
    emplace_kwarg_values = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ]
    emplace_kwarg_scalars = [
        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
    ]
    emplace_kwarguments = "\n    ".join(
        [
            f"kwarguments.emplace_back({a});"
            for a in emplace_kwarg_values + emplace_kwarg_scalars
        ]
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torchgen.api.meta as meta
|
| 4 |
+
import torchgen.api.structured as structured
|
| 5 |
+
from torchgen.api.types import kernel_signature
|
| 6 |
+
|
| 7 |
+
from torchgen.context import with_native_function_and_index
|
| 8 |
+
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
|
| 9 |
+
from torchgen.utils import mapMaybe
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    """Declare the unstructured native kernel for `f`, or None if the backend
    has no kernel for it (or only a legacy:: kernel, which is skipped)."""
    sig = kernel_signature(f, backend_index)
    metadata = backend_index.get_kernel(f)
    # Fix: collapse the two early-exit paths into one guard and drop the
    # redundant `else` after `return` — behavior is unchanged.
    if metadata is None or "legacy::" in metadata.kernel:
        return None
    # External backends get `static` (file-local) declarations; in-tree
    # kernels are exported with TORCH_API.
    prefix = "static" if backend_index.external else "TORCH_API"
    return f"{prefix} {sig.decl(name=metadata.kernel)};"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Declare the structured-kernel class for this functions group, or
    return an empty list when the backend defines no kernel for it."""
    meta_class_name = meta.name(g)
    impl_args = structured.impl_arguments(g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    # In-tree kernels are exported; external backends use a plain struct.
    prefix = "" if backend_index.external else "TORCH_API "
    arg_decls = ", ".join(binding.decl() for binding in impl_args)
    return [
        f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_class_name} {{
void impl({arg_decls});
}};
"""
    ]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Generates NativeFunctions.h, a list of forward declarations of all
|
| 43 |
+
# actual kernel definitions we keep in aten/src/ATen/native/
|
| 44 |
+
# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Forward-declare the backend's kernel(s) for a function or group."""
    metadata = backend_index.get_kernel(g)

    # Single function: one unstructured declaration (or nothing).
    if not isinstance(g, NativeFunctionsGroup):
        decl = gen_unstructured(g, backend_index)
        return [] if decl is None else [decl]

    # Group without a structured kernel: declare each member unstructured.
    if metadata is None or not metadata.structured:
        return list(
            mapMaybe(lambda fn: gen_unstructured(fn, backend_index), g.functions())
        )

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py
ADDED
|
@@ -0,0 +1,989 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import textwrap
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Literal, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torchgen.api.cpp as cpp
|
| 7 |
+
import torchgen.api.meta as meta
|
| 8 |
+
import torchgen.api.structured as structured
|
| 9 |
+
from torchgen.api.translate import translate
|
| 10 |
+
from torchgen.api.types import (
|
| 11 |
+
BaseCType,
|
| 12 |
+
Binding,
|
| 13 |
+
ConstRefCType,
|
| 14 |
+
CppSignature,
|
| 15 |
+
CppSignatureGroup,
|
| 16 |
+
DispatcherSignature,
|
| 17 |
+
Expr,
|
| 18 |
+
kernel_signature,
|
| 19 |
+
MutRefCType,
|
| 20 |
+
NamedCType,
|
| 21 |
+
NativeSignature,
|
| 22 |
+
tensorT,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from torchgen.context import method_with_native_function, native_function_manager
|
| 26 |
+
from torchgen.model import (
|
| 27 |
+
Argument,
|
| 28 |
+
BackendIndex,
|
| 29 |
+
DeviceCheckType,
|
| 30 |
+
DispatchKey,
|
| 31 |
+
gets_generated_out_inplace_wrapper,
|
| 32 |
+
is_cuda_dispatch_key,
|
| 33 |
+
NativeFunction,
|
| 34 |
+
NativeFunctionsGroup,
|
| 35 |
+
SchemaKind,
|
| 36 |
+
TensorOptionsArguments,
|
| 37 |
+
)
|
| 38 |
+
from torchgen.selective_build.selector import SelectiveBuilder
|
| 39 |
+
from torchgen.utils import assert_never, mapMaybe, Target
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def gen_registration_headers(
    backend_index: BackendIndex,
    per_operator_headers: bool,
    rocm: bool,
) -> List[str]:
    """Return the ``#include`` lines emitted at the top of Register{Key}.cpp.

    The base header depends on whether per-operator headers are enabled;
    the backend-specific EmptyTensor (or factory-op) headers are appended
    based on the dispatch key and the ROCm flag.
    """
    key = backend_index.dispatch_key
    includes = [
        "#include <ATen/ops/as_strided_native.h>"
        if per_operator_headers
        else "#include <ATen/NativeFunctions.h>"
    ]

    if key in (DispatchKey.CPU, DispatchKey.Meta):
        includes.append("#include <ATen/EmptyTensor.h>")
    elif key == DispatchKey.CUDA:
        # ROCm builds use the hipified header under ATen/hip.
        includes.append(
            "#include <ATen/hip/EmptyTensor.h>"
            if rocm
            else "#include <ATen/cuda/EmptyTensor.h>"
        )
    elif key == DispatchKey.MPS:
        includes.append("#include <ATen/mps/EmptyTensor.h>")
    elif per_operator_headers:
        # No dedicated EmptyTensor header for this backend: pull in just the
        # factory/copy operators the generated code calls.
        includes.extend(
            [
                "#include <ATen/ops/empty.h>",
                "#include <ATen/ops/empty_strided.h>",
                "#include <ATen/ops/_copy_from_and_resize.h>",
                "#include <ATen/ops/_copy_from.h>",
            ]
        )
    else:
        includes.append("#include <ATen/Functions.h>")

    return includes
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def gen_empty_impl_names(
    backend_index: BackendIndex,
) -> Tuple[Optional[str], Optional[str]]:
    """Return the (empty, empty_strided) C++ factory-function names for a backend.

    Backends with a dedicated at::detail implementation get the fast
    non-dispatching form; quantized/composite backends fall back to the
    dispatched ``at::empty`` / ``at::empty_strided``. Returns ``(None, None)``
    for backends with no known factory helpers.
    """
    key = backend_index.dispatch_key
    if key in (
        DispatchKey.Meta,
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.MPS,
    ):
        # e.g. DispatchKey.CPU -> "at::detail::empty_cpu"
        suffix = str(key).lower()
        return (
            f"at::detail::empty_{suffix}",
            f"at::detail::empty_strided_{suffix}",
        )
    if key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
    ):
        return "at::empty", "at::empty_strided"
    return None, None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
    """Return the C++ ``create_out`` helper for this backend, or [] if the
    backend has no known empty/empty_strided implementation."""
    # Meta kernels must force the Meta device regardless of the requested options.
    empty_options = (
        "options.device(at::kMeta)"
        if backend_index.dispatch_key == DispatchKey.Meta
        else "options"
    )

    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_impl is None:
        return []

    helper = f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (strides.empty()) {{
      return {empty_impl}(sizes, {empty_options});
  }} else {{
      return {empty_strided_impl}(sizes, strides, {empty_options});
  }}
}}
"""
    return [helper]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
    """Return the C++ ``maybe_create_proxy`` helper for this backend.

    The helper allocates a freshly-strided proxy tensor when the provided
    ``out`` tensor's strides don't match the requested strides; it is omitted
    entirely for backends without an ``empty_strided`` implementation.
    """
    _, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_strided_impl is None:
        return []
    return [
        f"""
c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (out.strides() != strides) {{
    return {empty_strided_impl}(sizes, strides, options);
  }}
  return c10::nullopt;
}}
"""
    ]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
    """Return the C++ ``resize_out`` helper used by out= structured kernels."""
    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # The function isn't used by this key (since only functional ops have a kernel for this key),
        # so we need to not include it to avoid a defined-but-not-used error.
        return []
    helper = """
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  const bool resized = at::native::resize_output(out, sizes);
  // Only restride if a resize occurred; otherwise we ignore the (advisory)
  // strides from the meta function and directly use the output tensor's
  // preexisting strides
  if (resized) {
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      // TODO: avoid the redispatch here
      out.as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
}
"""
    return [helper]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
    """Return the C++ ``check_inplace`` helper that validates dtype/device/size
    agreement between the input and the in-place output.

    Note: ``backend_index`` is accepted for signature uniformity with the other
    helper generators; the emitted code is backend-independent.
    """
    helper = """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  //   1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  //   2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    return [helper]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
    """Concatenate all of the per-backend C++ helper definitions
    (create_out / resize_out / check_inplace / maybe_create_proxy)."""
    return (
        gen_create_out_helper(backend_index)
        + gen_resize_out_helper(backend_index)
        + gen_check_inplace_helper(backend_index)
        + gen_maybe_create_proxy_helper(backend_index)
    )
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
|
| 204 |
+
#
|
| 205 |
+
# - The primary function of this file is to register all of the
|
| 206 |
+
# implementations for the given dispatch key to the dispatcher,
|
| 207 |
+
# so they are available for use in PyTorch. If dispatch is
|
| 208 |
+
# None, we generate schema (def) registrations and catchall
|
| 209 |
+
# registrations.
|
| 210 |
+
# - The secondary function of this file is to generate a wrapper
|
| 211 |
+
# around functions. In CPUType these wrappers do nothing
|
| 212 |
+
# (and should be removed), but in other cases they handle
|
| 213 |
+
# DeviceGuard. A small extra benefit of wrappers is they
|
| 214 |
+
# are not overloaded, so they can be used in the registration
|
| 215 |
+
# API without having to disambiguate which overload you want
|
| 216 |
+
# (as would be the case if you directly registered native::
|
| 217 |
+
# functions).
|
| 218 |
+
# - The tertiary function of this file is to generate *static*
|
| 219 |
+
# cpp API bindings which can be used to bypass dispatcher
|
| 220 |
+
# directly to kernels, but with user-friendly cpp-style API
|
| 221 |
+
@dataclass(frozen=True)
class RegisterDispatchKey:
    """Generates the contents of Register{DispatchKey}.cpp for one backend.

    Instances are callable (see ``__call__``): invoked once per native
    function (or functions group), they return the generated C++ snippets
    for the configured ``target`` (declaration, definition, anonymous
    wrapper definition, or dispatcher registration).
    """

    # Index describing which kernels exist for the dispatch key being generated.
    backend_index: BackendIndex

    # Which flavor of output this instance produces.
    target: Literal[
        Target.ANONYMOUS_DEFINITION,
        Target.NAMESPACED_DEFINITION,
        Target.NAMESPACED_DECLARATION,
        Target.REGISTRATION,
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether or not to generate symint registrations or not.  External users
    # of codegen who don't care about symints can set this to false to get
    # non-SymInt codegen
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: Optional[str]

    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
    # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
    skip_dispatcher_op_registration: bool
|
| 255 |
+
|
| 256 |
+
@staticmethod
|
| 257 |
+
def gen_device_check(
|
| 258 |
+
type: DeviceCheckType, args: List[Argument], method_name: str
|
| 259 |
+
) -> str:
|
| 260 |
+
if type == DeviceCheckType.NoCheck:
|
| 261 |
+
return " // No device check\n"
|
| 262 |
+
|
| 263 |
+
device_check = "c10::optional<Device> common_device = nullopt;\n"
|
| 264 |
+
device_check += "(void)common_device; // Suppress unused variable warning\n"
|
| 265 |
+
for arg in args:
|
| 266 |
+
# Only tensor like arguments are eligible
|
| 267 |
+
if arg.type.is_tensor_like():
|
| 268 |
+
device_check += f"""
|
| 269 |
+
c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
|
| 270 |
+
return device_check
|
| 271 |
+
|
| 272 |
+
@method_with_native_function
|
| 273 |
+
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
| 274 |
+
if isinstance(f, NativeFunctionsGroup):
|
| 275 |
+
g: NativeFunctionsGroup = f
|
| 276 |
+
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
|
| 277 |
+
# gen_structured() has special logic to handle auto-generated kernels.
|
| 278 |
+
if g.structured:
|
| 279 |
+
return self.gen_structured(g)
|
| 280 |
+
else:
|
| 281 |
+
return list(
|
| 282 |
+
mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
|
| 283 |
+
)
|
| 284 |
+
elif isinstance(f, NativeFunction):
|
| 285 |
+
r = self.gen_unstructured(f)
|
| 286 |
+
return [] if r is None else [r]
|
| 287 |
+
else:
|
| 288 |
+
assert_never(f)
|
| 289 |
+
|
| 290 |
+
def wrapper_kernel_sig(
|
| 291 |
+
self, f: NativeFunction
|
| 292 |
+
) -> Union[NativeSignature, DispatcherSignature]:
|
| 293 |
+
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
|
| 294 |
+
return DispatcherSignature.from_schema(
|
| 295 |
+
f.func,
|
| 296 |
+
prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
|
| 297 |
+
symint=self.symint,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
    def gen_out_inplace_wrapper(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
    ) -> Optional[str]:
        """Generate a C++ wrapper implementing an out/inplace variant in terms
        of the group's functional kernel.

        The wrapper calls the functional kernel into a temporary, then copies
        the temporary into the user-provided output(s) via ``at::_copy_from``
        (inplace) or ``at::_copy_from_and_resize`` (out). Returns None if no
        group is available to source the functional kernel from.
        """
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = "at::_copy_from"
        elif k is SchemaKind.out:
            copy_op = "at::_copy_from_and_resize"
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        # Temporary holding the functional kernel's result before copy-back.
        func_res = f"{name}_tmp"
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            # Multiple returns come back as a tuple: copy each element out.
            updates = "\n  ".join(
                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
                for i, ret_name in enumerate(return_names)
            )
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        elif len(return_names) == 1:
            ret_name = return_names[0]
            updates = f"{copy_op}({func_res}, {ret_name});"
            returns = ret_name
        else:
            # No returns: the op writes into a single out argument
            # (possibly a tensor list) and returns void.
            assert len(f.func.arguments.out) == 1
            returns = ""
            out_arg = f.func.arguments.out[0]
            if out_arg.type.is_list_like():
                updates = f"""\
    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
        {copy_op}({func_res}[i], {out_arg.name}[i]);
    }}"""
            else:
                updates = f"{copy_op}({func_res}, {out_arg.name});"

        functional_sig = self.wrapper_kernel_sig(g.functional)
        wrapper_name = sig.name()

        return f"""\
{sig.defn(name=wrapper_name)} {{
  auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  {updates}
  return {returns};
}}
"""
|
| 350 |
+
|
| 351 |
+
    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        """Generate output for a structured operator group.

        Meta and CompositeExplicitAutogradNonFunctional kernels are always
        auto-generated for structured ops, so explicitly-specified kernels on
        those keys are rejected. Groups without structured kernel metadata for
        this backend fall back to the unstructured path; otherwise generation
        is delegated to a StructuredRegisterDispatchKey per function.
        """
        metadata = self.backend_index.get_kernel(g)
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif metadata is None or not metadata.structured:
            # No structured kernel registered for this backend: treat each
            # function in the group as an ordinary unstructured op.
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
        structured_gen = StructuredRegisterDispatchKey(
            self.backend_index,
            self.target,
            self.selector,
            self.rocm,
            self.symint,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
| 379 |
+
|
| 380 |
+
    def gen_unstructured(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
    ) -> Optional[str]:
        """Generate output for one unstructured native function, per ``self.target``.

        Returns the generated C++ snippet, or None when nothing should be
        emitted for this function (no kernel for this backend, manual
        registration, or deselected by the selective-build selector).
        """
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                # No kernel registered for this backend: we may still be able
                # to synthesize one in two special cases, else bail out.
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    # Meta inplace ops can be implemented as a trivial
                    # "return self" after a meta-ness check (see below).
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    # Namespaced entry point simply forwards to the wrapper.
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result

            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = "  // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{device_check}

{device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 561 |
+
#
|
| 562 |
+
# STRUCTURED
|
| 563 |
+
#
|
| 564 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    """RegisterDispatchKey specialization for structured operator groups.

    Generates the per-operator structured-kernel class (precomputed outputs,
    device guard, set_output overrides) and its registration.
    """

    # The structured operator group this generator instance is bound to.
    g: NativeFunctionsGroup
|
| 570 |
+
|
| 571 |
+
    def gen_class_set_output_functions(
        self, k: SchemaKind, parent_class: str, generate_super: bool
    ) -> str:
        """Generate the ``set_output_strided`` / ``set_output_raw_strided``
        overrides for the structured kernel class.

        The strided variant may create a proxy output (when the user-supplied
        out tensor has the wrong strides); the raw_strided variant never does.
        """
        if generate_super:
            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""

        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
            return f"""
void set_output_{name}(
    int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
    TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
    if (!names.empty()) {{
      namedinference::propagate_names(outputs_[output_idx], names);
    }}
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
{textwrap.indent(set_output_super, "    ")}
}}
"""

        return f"""
{gen_set_output_function("strided", maybe_create_proxy=True)}
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
"""
|
| 599 |
+
|
| 600 |
+
    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
        """Generate the body of a ``set_output_*`` override for schema kind ``k``.

        functional: allocate a fresh output; inplace: validate the existing
        output in place; out: resize the provided output. Device-guarded
        backends first pin (or verify) the guard's device.
        """
        if self.backend_index.dispatch_key in [
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            # Only these keys auto-generate functional structured kernels.
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutogradNonFunctional,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)
|
| 655 |
+
|
| 656 |
+
# returns the definition of a ctor, as well as how to construct
|
| 657 |
+
# this class to a variable named op
|
| 658 |
+
def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
|
| 659 |
+
if k is SchemaKind.functional:
|
| 660 |
+
return ""
|
| 661 |
+
elif k is SchemaKind.inplace:
|
| 662 |
+
# TODO: Make sure out argument is guaranteed to be self
|
| 663 |
+
return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
|
| 664 |
+
elif k is SchemaKind.out:
|
| 665 |
+
out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
|
| 666 |
+
out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
|
| 667 |
+
return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
|
| 668 |
+
elif k is SchemaKind.mutable or k is SchemaKind.scratch:
|
| 669 |
+
raise AssertionError(
|
| 670 |
+
f"{k} structured operators are currently not supported"
|
| 671 |
+
)
|
| 672 |
+
else:
|
| 673 |
+
assert_never(k)
|
| 674 |
+
|
| 675 |
+
    def gen_class(
        self,
        f: NativeFunction,
        k: SchemaKind,
        *,
        class_name: str,
        parent_class: str,
        generate_super: bool,
    ) -> str:
        """Generate the full C++ structured-kernel class definition.

        Assembles the ctor, set_output overrides, maybe_get_output, the
        outputs_ array, an optional proxy_outputs_ array (inplace/out only),
        and an optional device-guard field, depending on schema kind and
        dispatch key.
        """
        if k is SchemaKind.functional:
            output_type = "Tensor"
            output_value = "outputs_[output_idx]"
            proxy_field = ""
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        elif k is SchemaKind.out:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        # NOTE(review): no else branch here — for any other SchemaKind the
        # locals above stay undefined (see the type: ignore comments below);
        # presumably callers only pass functional/inplace/out. TODO confirm.

        if self.backend_index.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
            else:
                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            guard_field = "c10::OptionalDeviceGuard guard_;"
        elif self.backend_index.dispatch_key == DispatchKey.MPS:
            # TODO: Move to OptionalMPSGuard.
            guard_field = "c10::OptionalDeviceGuard guard_;"
        else:
            guard_field = ""

        indent = " " * 4
        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
        lines = (
            f"struct {class_name} final : public {parent_class} {{",
            f"{textwrap.indent(class_ctor_str, indent)}",
            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
            "    const Tensor& maybe_get_output(int64_t output_idx) override {",
            f"      return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
            "    }",
            f"    std::array<{output_type}, {len(f.func.returns)}> outputs_;",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(guard_field, indent)}",
            "};",
        )
        # Empty pieces (e.g. no ctor, no guard) are dropped from the output.
        return "\n".join(line for line in lines if line)
|
| 728 |
+
|
| 729 |
+
@method_with_native_function
def gen_one(self, f: NativeFunction) -> Optional[str]:
    """Generate code for one structured operator, specialized on ``self.target``:
    the namespaced declaration/definition of the public binding, the anonymous
    wrapper definition that drives the meta()/impl() calls, or the dispatcher
    registration line.

    Returns None when nothing should be emitted: the function was not selected
    for registration, or it is an out= overload under
    CompositeExplicitAutogradNonFunctional (no default out implementation is
    ever generated).
    """
    assert not f.manual_kernel_registration

    # Unselected functions get no registration at all.
    if (
        self.target is Target.REGISTRATION
        and not self.selector.is_native_function_selected(f)
    ):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if (
        self.backend_index.dispatch_key
        == DispatchKey.CompositeExplicitAutogradNonFunctional
        and f.func.kind() is SchemaKind.out
    ):
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )

    # Signature of the wrapper function we'll register to the dispatcher
    kern = self.backend_index.get_kernel(f)
    sig = NativeSignature(
        f.func,
        prefix=f"wrapper_{self.backend_index.dispatch_key}_",
        symint=kern is not None and kern.supports_symint(),
    )

    if self.target is Target.NAMESPACED_DECLARATION:
        # One TORCH_API declaration per (possibly symint-split) signature.
        result = ""
        for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
            result += f"TORCH_API {cpp_sig.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:

        def generate_defn(cpp_sig: CppSignature) -> str:
            # Namespaced definitions merely forward to the wrapper `sig`.
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

        result = ""
        for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
            result += generate_defn(cpp_sig)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif (
            self.backend_index.dispatch_key
            is DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            # Backend kernels name the class after the registered kernel.
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"

        if self.backend_index.device_guard:
            # Device check covers both out and positional tensor arguments.
            device_check_args = itertools.chain(
                f.func.arguments.out, f.func.arguments.flat_positional
            )
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()
                )
            )

        # Construct the op with the outputs it should write into (if any).
        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ", ".join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ", ".join(
            e.expr
            for e in translate(
                context, structured.meta_arguments(self.g), method=False
            )
        )

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            precomputed_values = [
                *self.g.out.precomputed.replace.values(),
                self.g.out.precomputed.add,
            ]
            for precomputed_elems in precomputed_values:
                for arg in precomputed_elems:
                    context.append(
                        Expr(
                            expr=f"precompute.{arg.name}",
                            type=structured.argument_type(arg, binds=arg.name),
                        )
                    )

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type

            if k is SchemaKind.out:
                expr = f"op.maybe_get_output({i})"
            else:
                expr = f"op.outputs_[{i}]"

            context.append(
                Expr(
                    expr=expr,
                    # TODO: Stop hardcoding that the output type is a Tensor. Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=NamedCType(
                        out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                    ),
                )
            )

        # With the expanded context, do the impl call (if not a meta
        # function)
        if (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding
            )
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ", ".join(
                e.expr
                for e in translate(context, out_sig.arguments(), method=False)
            )
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set. I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.impl_arguments(self.g), method=False
                )
            )
            sig_body.append(f"op.impl({impl_exprs});")

        # Go over each output, and check if there is a proxy created for it.
        # If so, copy it over to the original output.
        if k is SchemaKind.out or k is SchemaKind.inplace:
            for i in range(len(f.func.returns)):
                sig_body.append(
                    f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
                )

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            else:
                moved = ", ".join(
                    f"std::move(op.outputs_[{i}])"
                    for i in range(len(f.func.returns))
                )
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ", ".join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined] # TODO: audit

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torchgen.api.ufunc as ufunc
|
| 5 |
+
from torchgen.api.translate import translate
|
| 6 |
+
from torchgen.api.types import (
|
| 7 |
+
BaseCType,
|
| 8 |
+
Binding,
|
| 9 |
+
CType,
|
| 10 |
+
Expr,
|
| 11 |
+
NamedCType,
|
| 12 |
+
opmath_t,
|
| 13 |
+
scalar_t,
|
| 14 |
+
StructuredImplSignature,
|
| 15 |
+
VectorizedCType,
|
| 16 |
+
)
|
| 17 |
+
from torchgen.api.ufunc import UfunctorBindings
|
| 18 |
+
from torchgen.context import with_native_function
|
| 19 |
+
from torchgen.model import (
|
| 20 |
+
Argument,
|
| 21 |
+
BaseTy,
|
| 22 |
+
BaseType,
|
| 23 |
+
DispatchKey,
|
| 24 |
+
NativeFunctionsGroup,
|
| 25 |
+
ScalarType,
|
| 26 |
+
UfuncKey,
|
| 27 |
+
)
|
| 28 |
+
from torchgen.utils import OrderedSet
|
| 29 |
+
|
| 30 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 31 |
+
#
|
| 32 |
+
# CUDA STUFF
|
| 33 |
+
#
|
| 34 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 35 |
+
|
| 36 |
+
# NB: not bothering to generate dispatch stub forward declaration in header,
|
| 37 |
+
# we can just paste it whereever necessary
|
| 38 |
+
|
| 39 |
+
# TODO: use BackendIndex
|
| 40 |
+
# dispatch_key: DispatchKey # only CPU/CUDA right now
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Represents functors for implementing CUDA ufuncs.
|
| 44 |
+
# Functors are templated by scalar_t because when USERS instantiate functors
|
| 45 |
+
# they are templated. A functor looks something like this:
|
| 46 |
+
#
|
| 47 |
+
# template <typename scalar_t>
|
| 48 |
+
# struct CUDAFunctorOnSelf_add {
|
| 49 |
+
# using opmath_t = at::opmath_type<scalar_t>;
|
| 50 |
+
# opmath_t other_;
|
| 51 |
+
# opmath_t alpha_;
|
| 52 |
+
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
|
| 53 |
+
# : other_(other), alpha_(alpha) {}
|
| 54 |
+
# __device__ scalar_t operator()(scalar_t self) {
|
| 55 |
+
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
|
| 56 |
+
# }
|
| 57 |
+
# };
|
| 58 |
+
#
|
| 59 |
+
@dataclass(frozen=True)
class UfunctorSignature:
    """Describes one generated CUDA functor for a ufunc.

    The functor is templated over scalar_t; its constructor captures the
    "precomputable" arguments as fields (named with a trailing underscore)
    and operator() receives the remaining per-element arguments.
    """

    g: NativeFunctionsGroup
    # Index of the tensor operand that is lifted to a CPU scalar for this
    # specialization; None when no operand is specialized.
    scalar_tensor_idx: Optional[int]
    name: str

    def arguments(self) -> UfunctorBindings:
        # api.ufunc decides which bindings go to the ctor vs. operator().
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> List[Binding]:
        # Member fields get a trailing underscore, as is conventional.
        ctor_bindings = self.arguments().ctor
        return [binding.rename(f"{binding.name}_") for binding in ctor_bindings]

    def returns_type(self) -> CType:
        # TODO: don't hardcode; the return type should eventually be
        # inferred from tags on the native function.
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        field_decls = [f"{f.type} {f.name};" for f in self.fields()]
        return "\n".join(field_decls)

    def inline_defn_ctor(self) -> str:
        ctor_bindings = self.arguments().ctor
        params = ", ".join(b.decl() for b in ctor_bindings)
        # The parameter -> member mapping is regular enough that we do not
        # need translate here; each `x` simply initializes `x_`.
        inits = ", ".join(f"{b.name}_({b.name})" for b in ctor_bindings)
        return f"{self.name}({params}) : {inits} {{}}"

    def decl_apply(self) -> str:
        params = ", ".join(b.decl() for b in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({params}) const"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass(frozen=True)
class UfuncSignature:
    """Signature of the scalar ufunc template function itself (the function
    the generated functors forward to), computed at type ``compute_t``.
    """

    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> List[Binding]:
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        # Translate whatever is in scope (ctx) into this signature's
        # argument list and render the call expression.
        arg_exprs = [e.expr for e in translate(ctx, self.arguments())]
        return f"{self.name}({', '.join(arg_exprs)})"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# steps:
|
| 108 |
+
# 1. take the functional signature
|
| 109 |
+
# 2. use api.ufunc to convert it to template signature. this establishes
|
| 110 |
+
# the type of the template function
|
| 111 |
+
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
|
| 112 |
+
# this establish context in which we call the template signature
|
| 113 |
+
#
|
| 114 |
+
# StructuredImplSignature context
|
| 115 |
+
# ~> functor constructor sig
|
| 116 |
+
#
|
| 117 |
+
# Functor constructor context
|
| 118 |
+
# ~> functor fields sig
|
| 119 |
+
#
|
| 120 |
+
# Functor apply context (functor fields + functor apply sig)
|
| 121 |
+
# ~> template sig
|
| 122 |
+
#
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """A group qualifies for the binary scalar specializations iff its
    functional variant takes exactly two tensor-like (non-out) arguments.
    """
    tensor_like = [
        a for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    ]
    return len(tensor_like) == 2
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures for this ufunc, keyed by dtype and
    UfuncKey, plus the C++ source for every functor we must synthesize.

    Keys the user defined directly in the ufunc_inner_loop get a signature
    but no generated code (the user supplied the functor themselves).
    """
    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []
    loops = g.out.ufunc_inner_loop
    # Which tensor operand the specialization lifts to a CPU scalar:
    # "OnSelf" captures `other` (idx 1) as a scalar, "OnOther" captures
    # `self` (idx 0); the generic functor captures nothing (None).
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user already did it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries. For simplicity of
        # codegen, both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        # The functor computes at opmath_t and forwards to ufunc::<name>.
        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    # Zero-based index of the operand that may be a CPU scalar
    # (used as iter operand `scalar_idx + 1` in the generated body).
    scalar_idx: int
    # Name bound for the scalar value when constructing the functor
    # (it becomes the NamedCType name fed to translate).
    ctor_tensor: str
    # Which specialized functor handles this case.
    ufunc_key: UfuncKey


# The two binary specializations: when `self` is the scalar we use the
# functor that iterates over `other`, and vice versa.
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Generate the per-dtype body of the CUDA kernel: a chain that first
    tries each CPU-scalar specialization (removing the scalar operand from
    the iterator) and otherwise falls back to the generic CUDAFunctor.
    """
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        # TensorIterator operand indices are offset by one (output is 0).
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    # Fallback: all operands stay on device, use the generic functor.
    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
    return body
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Generate the complete CUDA-side code for a ufunc: the functor
    structs, the dispatch-stub kernel with one AT_DISPATCH_CASE per
    supported dtype, its REGISTER_DISPATCH line, and the structured impl
    that calls the kernel directly (bypassing the dynamic stub).
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 326 |
+
#
|
| 327 |
+
# CPU STUFF
|
| 328 |
+
#
|
| 329 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the DispatchStub associated with a ufunc."""

    g: NativeFunctionsGroup

    @property
    def _stem(self) -> str:
        # Shared stem from which the stub/kernel/type identifiers derive.
        return str(self.g.functional.func.name.name)

    @property
    def name(self) -> str:
        return f"{self._stem}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{self._stem}_kernel"

    @property
    def type_name(self) -> str:
        return f"{self._stem}_fn"

    def arguments(self) -> List[Binding]:
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        arg_types = ", ".join(a.type for a in self.arguments())
        return f"void(*)(TensorIteratorBase&, {arg_types})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        params = ", ".join(a.defn() for a in self.arguments())
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {params})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from a context where `this` is a TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        exprs = ", ".join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.name}(device_type(), *this, {exprs})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        exprs = ", ".join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.kernel_name}(*this, {exprs})"
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Generate the CPU-side glue for a ufunc: the stub type alias, its
    DECLARE/DEFINE_DISPATCH lines, and a structured impl that forwards to
    the stub (the actual kernels are compiled separately per CPU arch).
    """
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Generate the body of one dtype dispatch case for the CPU ufunc kernel.

    Unpacks Scalar stub arguments into ``scalar_t`` (and, when a vector loop
    exists, ``Vectorized<scalar_t>``) locals, then emits a ``cpu_kernel`` or
    ``cpu_kernel_vec`` call whose lambdas invoke the inner ufunc loops.

    Args:
        g: the native-functions group being generated.
        dtype: the scalar type this case handles (used only in the assert message).
        inner_loops: inner ufunc signatures keyed by CPUScalar/CPUVector.
        parent_ctx: bindings available from the enclosing stub signature.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    def _is_unpacked_scalar(b: Binding) -> bool:
        # Only Scalar-typed Argument bindings are unpacked into scope below;
        # everything else is captured implicitly by the lambdas.  (Previously
        # this predicate was duplicated across the two loops.)
        return not (isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar))

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope.  Ordering is load-bearing: all _s_ lines are
    # emitted before any _v_ line, so the two passes stay separate.
    body = []
    ctx = []
    unpacked = [b for b in parent_ctx if _is_unpacked_scalar(b)]
    for b in unpacked:
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in unpacked:
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        # Unpacked scalars/vectors first, then the lambda's own bindings.
        r: List[Union[Expr, Binding]] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Generate the anonymous-namespace CPU kernel plus its dispatch registration.

    Builds a per-dtype table of inner ufunc signatures, emits one
    AT_DISPATCH_CASE per dtype, and wraps them in an AT_DISPATCH_SWITCH inside
    the stub's kernel function, followed by REGISTER_DISPATCH.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError()
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                # First writer wins: because lks is ordered by precedence,
                # skipping keys already present implements the override rule.
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (220 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc
ADDED
|
Binary file (7.11 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|